analyze_data.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. #!/usr/bin/env python3
  2. """
  3. Dataset Analysis and Quality Report Generator
  4. Analyzes the training data for:
  5. 1. Query length distribution
  6. 2. Category diversity
  7. 3. Named entity coverage
  8. 4. Temporal query coverage
  9. 5. Short query coverage (important for ambiguous queries)
  10. 6. Duplicate detection
  11. 7. Quality issues (long hyde, missing fields, etc.)
  12. """
  13. import json
  14. import re
  15. import sys
  16. from pathlib import Path
  17. from collections import Counter, defaultdict
  18. from dataclasses import dataclass
  19. sys.path.insert(0, str(Path(__file__).parent.parent))
  20. from dataset.schema import normalize_output_items, parse_output_text
  21. @dataclass
  22. class DatasetStats:
  23. total_examples: int = 0
  24. short_queries: int = 0 # 1-2 words
  25. medium_queries: int = 0 # 3-5 words
  26. long_queries: int = 0 # 6+ words
  27. has_lex: int = 0
  28. has_vec: int = 0
  29. has_hyde: int = 0
  30. long_hyde_count: int = 0
  31. duplicate_queries: int = 0
  32. named_entity_queries: int = 0
  33. temporal_queries: int = 0
  34. short_keyword_queries: int = 0
  35. def categorize_query(query: str) -> str:
  36. """Categorize a query by type."""
  37. query_lower = query.lower()
  38. words = query_lower.split()
  39. word_count = len(words)
  40. # Short keyword queries
  41. if word_count <= 2:
  42. return "short_keyword"
  43. # Named entity queries (capitalized words or tech terms)
  44. if any(w[0].isupper() for w in words if w):
  45. return "named_entity"
  46. # Temporal/recency queries
  47. temporal_keywords = [
  48. "latest",
  49. "recent",
  50. "new",
  51. "update",
  52. "changelog",
  53. "changed",
  54. "version",
  55. "release",
  56. "news",
  57. "2024",
  58. "2025",
  59. ]
  60. if any(kw in query_lower for kw in temporal_keywords):
  61. return "temporal"
  62. # How-to queries
  63. if query_lower.startswith("how "):
  64. return "how_to"
  65. # What is queries
  66. if query_lower.startswith("what "):
  67. return "what_is"
  68. # Difference/comparison queries
  69. if any(kw in query_lower for kw in ["difference", "vs", "versus", "compare"]):
  70. return "comparison"
  71. # Personal/journal style
  72. if any(
  73. kw in query_lower for kw in ["meeting", "notes", "journal", "ideas", "thoughts"]
  74. ):
  75. return "personal"
  76. return "other"
  77. def extract_named_entities(query: str) -> list:
  78. """Extract potential named entities from query."""
  79. entities = []
  80. words = query.split()
  81. for word in words:
  82. # Skip stopwords
  83. if word.lower() in {
  84. "the",
  85. "a",
  86. "an",
  87. "is",
  88. "are",
  89. "to",
  90. "for",
  91. "of",
  92. "in",
  93. "and",
  94. "or",
  95. }:
  96. continue
  97. # Capitalized words (potential named entities)
  98. if word and word[0].isupper() and len(word) > 1:
  99. entities.append(word)
  100. # Technology terms with version numbers or special chars
  101. if any(c in word for c in ".+-0123456789") and len(word) > 1:
  102. entities.append(word)
  103. return entities
  104. def analyze_dataset(filepath: Path) -> tuple[DatasetStats, dict, dict]:
  105. """Analyze the dataset and return statistics."""
  106. stats = DatasetStats()
  107. categories = Counter()
  108. seen_queries = set()
  109. duplicate_count = 0
  110. category_examples = defaultdict(list)
  111. with open(filepath, "r", encoding="utf-8") as f:
  112. for line_num, line in enumerate(f, 1):
  113. line = line.strip()
  114. if not line:
  115. continue
  116. try:
  117. example = json.loads(line)
  118. query = example.get("query", "") or example.get("input", "")
  119. output = example.get("output", [])
  120. if isinstance(output, str):
  121. output = parse_output_text(output)
  122. output = normalize_output_items(output)
  123. stats.total_examples += 1
  124. # Check for duplicates
  125. query_lower = query.lower()
  126. if query_lower in seen_queries:
  127. duplicate_count += 1
  128. else:
  129. seen_queries.add(query_lower)
  130. # Query length categorization
  131. word_count = len(query.split())
  132. if word_count <= 2:
  133. stats.short_queries += 1
  134. elif word_count <= 5:
  135. stats.medium_queries += 1
  136. else:
  137. stats.long_queries += 1
  138. # Category detection
  139. category = categorize_query(query)
  140. categories[category] += 1
  141. category_examples[category].append(query)
  142. # Named entity detection
  143. if extract_named_entities(query):
  144. stats.named_entity_queries += 1
  145. # Output analysis
  146. has_lex = any(o[0] == "lex" for o in output)
  147. has_vec = any(o[0] == "vec" for o in output)
  148. has_hyde = any(o[0] == "hyde" for o in output)
  149. if has_lex:
  150. stats.has_lex += 1
  151. if has_vec:
  152. stats.has_vec += 1
  153. if has_hyde:
  154. stats.has_hyde += 1
  155. # Check hyde length
  156. for kind, text in output:
  157. if kind == "hyde" and len(text) > 200:
  158. stats.long_hyde_count += 1
  159. except json.JSONDecodeError:
  160. print(f"Warning: Could not parse line {line_num}")
  161. stats.duplicate_queries = duplicate_count
  162. stats.temporal_queries = categories.get("temporal", 0)
  163. stats.short_keyword_queries = categories.get("short_keyword", 0)
  164. return stats, dict(categories), dict(category_examples)
  165. def print_report(stats: DatasetStats, categories: dict, category_examples: dict):
  166. """Print a comprehensive analysis report."""
  167. print("=" * 70)
  168. print("QMD TRAINING DATA ANALYSIS REPORT")
  169. print("=" * 70)
  170. print()
  171. # Basic statistics
  172. print("📊 BASIC STATISTICS")
  173. print("-" * 40)
  174. print(f"Total examples: {stats.total_examples:>6}")
  175. print(f"Duplicates found: {stats.duplicate_queries:>6}")
  176. print()
  177. # Query length distribution
  178. print("📝 QUERY LENGTH DISTRIBUTION")
  179. print("-" * 40)
  180. total = stats.total_examples
  181. print(
  182. f"Short (1-2 words): {stats.short_queries:>6} ({100 * stats.short_queries / total:5.1f}%)"
  183. )
  184. print(
  185. f"Medium (3-5 words): {stats.medium_queries:>6} ({100 * stats.medium_queries / total:5.1f}%)"
  186. )
  187. print(
  188. f"Long (6+ words): {stats.long_queries:>6} ({100 * stats.long_queries / total:5.1f}%)"
  189. )
  190. print()
  191. # Category distribution
  192. print("🏷️ CATEGORY DISTRIBUTION")
  193. print("-" * 40)
  194. for cat, count in sorted(categories.items(), key=lambda x: -x[1]):
  195. pct = 100 * count / total
  196. bar = "█" * int(pct / 2)
  197. print(f"{cat:20} {count:>6} ({pct:5.1f}%) {bar}")
  198. print()
  199. # Output format coverage
  200. print("✅ OUTPUT FORMAT COVERAGE")
  201. print("-" * 40)
  202. print(
  203. f"Has lex: {stats.has_lex:>6} ({100 * stats.has_lex / total:5.1f}%)"
  204. )
  205. print(
  206. f"Has vec: {stats.has_vec:>6} ({100 * stats.has_vec / total:5.1f}%)"
  207. )
  208. print(
  209. f"Has hyde: {stats.has_hyde:>6} ({100 * stats.has_hyde / total:5.1f}%)"
  210. )
  211. print(f"Long hyde (>200ch): {stats.long_hyde_count:>6}")
  212. print()
  213. # Critical metrics for evals
  214. print("🎯 EVALUATION ALIGNMENT")
  215. print("-" * 40)
  216. print(
  217. f"Named entity queries: {stats.named_entity_queries:>6} ({100 * stats.named_entity_queries / total:5.1f}%)"
  218. )
  219. print(
  220. f"Temporal/recency: {stats.temporal_queries:>6} ({100 * stats.temporal_queries / total:5.1f}%)"
  221. )
  222. print(
  223. f"Short keyword queries: {stats.short_keyword_queries:>6} ({100 * stats.short_keyword_queries / total:5.1f}%)"
  224. )
  225. print()
  226. # Recommendations
  227. print("💡 RECOMMENDATIONS")
  228. print("-" * 40)
  229. recommendations = []
  230. if stats.short_queries / total < 0.15:
  231. recommendations.append(
  232. "⚠️ Short queries below 15% - add more 1-2 word keyword queries"
  233. )
  234. if stats.named_entity_queries / total < 0.10:
  235. recommendations.append(
  236. "⚠️ Named entity queries below 10% - add more capitalized tech term queries"
  237. )
  238. if stats.temporal_queries / total < 0.05:
  239. recommendations.append(
  240. "⚠️ Temporal queries below 5% - add more 'latest', 'recent' queries"
  241. )
  242. if stats.long_hyde_count > 50:
  243. recommendations.append(
  244. f"⚠️ {stats.long_hyde_count} long hyde sections - consider truncating"
  245. )
  246. if stats.duplicate_queries > 0:
  247. recommendations.append(
  248. f"⚠️ {stats.duplicate_queries} duplicate queries - consider deduplication"
  249. )
  250. if categories.get("short_keyword", 0) < 100:
  251. recommendations.append(
  252. "⚠️ Need more short keyword examples for ambiguous query training"
  253. )
  254. if not recommendations:
  255. print("✅ Dataset looks good! No major issues detected.")
  256. else:
  257. for rec in recommendations:
  258. print(rec)
  259. print()
  260. print("=" * 70)
  261. def main():
  262. """Main entry point."""
  263. import argparse
  264. parser = argparse.ArgumentParser(description="Analyze QMD training dataset")
  265. parser.add_argument(
  266. "--input",
  267. type=str,
  268. default="data/qmd_expansion_v2.jsonl",
  269. help="Path to training data JSONL file",
  270. )
  271. parser.add_argument(
  272. "--show-examples",
  273. type=int,
  274. default=3,
  275. help="Number of example queries to show per category",
  276. )
  277. args = parser.parse_args()
  278. input_path = Path(args.input)
  279. if not input_path.exists():
  280. # Try relative to script directory
  281. script_dir = Path(__file__).parent.parent
  282. input_path = script_dir / args.input
  283. if not input_path.exists():
  284. print(f"Error: Could not find dataset at {input_path}")
  285. print("Please run from finetune directory or specify correct path")
  286. return 1
  287. print(f"Analyzing: {input_path}")
  288. print()
  289. stats, categories, category_examples = analyze_dataset(input_path)
  290. print_report(stats, categories, category_examples)
  291. # Show examples if requested
  292. if args.show_examples > 0:
  293. print("📋 SAMPLE QUERIES BY CATEGORY")
  294. print("-" * 40)
  295. for cat in sorted(categories.keys()):
  296. examples = category_examples.get(cat, [])
  297. if examples:
  298. print(f"\n{cat.upper()}:")
  299. for ex in examples[: args.show_examples]:
  300. print(f" • {ex}")
  301. print()
  302. return 0
  303. if __name__ == "__main__":
  304. exit(main())