analyze_data.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = ["pydantic>=2.0"]
  5. # ///
  6. """
  7. Dataset Analysis and Quality Report Generator
  8. Analyzes training data loaded through the strict Pydantic schema for:
  9. 1. Query length distribution
  10. 2. Category diversity
  11. 3. Named entity coverage
  12. 4. Output format coverage
  13. 5. Duplicate detection
  14. """
  15. import argparse
  16. import sys
  17. from pathlib import Path
  18. from collections import Counter, defaultdict
  19. from dataclasses import dataclass
  20. sys.path.insert(0, str(Path(__file__).parent.parent))
  21. from dataset.schema import TrainingExample, OutputType, load_examples
  22. @dataclass
  23. class DatasetStats:
  24. total_examples: int = 0
  25. short_queries: int = 0
  26. medium_queries: int = 0
  27. long_queries: int = 0
  28. has_lex: int = 0
  29. has_vec: int = 0
  30. has_hyde: int = 0
  31. long_hyde_count: int = 0
  32. duplicate_queries: int = 0
  33. named_entity_queries: int = 0
  34. temporal_queries: int = 0
  35. short_keyword_queries: int = 0
  36. def categorize_query(query: str) -> str:
  37. query_lower = query.lower()
  38. words = query_lower.split()
  39. word_count = len(words)
  40. if word_count <= 2:
  41. return "short_keyword"
  42. if any(w[0].isupper() for w in query.split() if w):
  43. return "named_entity"
  44. temporal_keywords = [
  45. "latest", "recent", "new", "update", "changelog",
  46. "changed", "version", "release", "news", "2024", "2025",
  47. ]
  48. if any(kw in query_lower for kw in temporal_keywords):
  49. return "temporal"
  50. if query_lower.startswith("how "):
  51. return "how_to"
  52. if query_lower.startswith("what "):
  53. return "what_is"
  54. if any(kw in query_lower for kw in ["difference", "vs", "versus", "compare"]):
  55. return "comparison"
  56. if any(kw in query_lower for kw in ["meeting", "notes", "journal", "ideas", "thoughts"]):
  57. return "personal"
  58. return "other"
  59. def extract_named_entities(query: str) -> list:
  60. entities = []
  61. stopwords = {"the", "a", "an", "is", "are", "to", "for", "of", "in", "and", "or"}
  62. for word in query.split():
  63. if word.lower() in stopwords:
  64. continue
  65. if word and word[0].isupper() and len(word) > 1:
  66. entities.append(word)
  67. if any(c in word for c in ".+-0123456789") and len(word) > 1:
  68. entities.append(word)
  69. return entities
  70. def analyze_examples(examples: list[TrainingExample]) -> tuple[DatasetStats, dict, dict]:
  71. stats = DatasetStats()
  72. categories: Counter = Counter()
  73. seen_queries: set[str] = set()
  74. category_examples: dict[str, list[str]] = defaultdict(list)
  75. for ex in examples:
  76. stats.total_examples += 1
  77. query_lower = ex.query.lower()
  78. if query_lower in seen_queries:
  79. stats.duplicate_queries += 1
  80. else:
  81. seen_queries.add(query_lower)
  82. word_count = len(ex.query.split())
  83. if word_count <= 2:
  84. stats.short_queries += 1
  85. elif word_count <= 5:
  86. stats.medium_queries += 1
  87. else:
  88. stats.long_queries += 1
  89. category = categorize_query(ex.query)
  90. categories[category] += 1
  91. category_examples[category].append(ex.query)
  92. if extract_named_entities(ex.query):
  93. stats.named_entity_queries += 1
  94. # Use the typed OutputPair model
  95. types_present = {p.type for p in ex.output}
  96. if OutputType.lex in types_present:
  97. stats.has_lex += 1
  98. if OutputType.vec in types_present:
  99. stats.has_vec += 1
  100. if OutputType.hyde in types_present:
  101. stats.has_hyde += 1
  102. for p in ex.output:
  103. if p.type == OutputType.hyde and len(p.text) > 200:
  104. stats.long_hyde_count += 1
  105. stats.temporal_queries = categories.get("temporal", 0)
  106. stats.short_keyword_queries = categories.get("short_keyword", 0)
  107. return stats, dict(categories), dict(category_examples)
  108. def print_report(stats: DatasetStats, categories: dict, category_examples: dict):
  109. print("=" * 70)
  110. print("QMD TRAINING DATA ANALYSIS REPORT")
  111. print("=" * 70)
  112. print()
  113. total = stats.total_examples
  114. print("BASIC STATISTICS")
  115. print("-" * 40)
  116. print(f"Total examples: {total:>6}")
  117. print(f"Duplicates found: {stats.duplicate_queries:>6}")
  118. print()
  119. print("QUERY LENGTH DISTRIBUTION")
  120. print("-" * 40)
  121. print(f"Short (1-2 words): {stats.short_queries:>6} ({100 * stats.short_queries / total:5.1f}%)")
  122. print(f"Medium (3-5 words): {stats.medium_queries:>6} ({100 * stats.medium_queries / total:5.1f}%)")
  123. print(f"Long (6+ words): {stats.long_queries:>6} ({100 * stats.long_queries / total:5.1f}%)")
  124. print()
  125. print("CATEGORY DISTRIBUTION")
  126. print("-" * 40)
  127. for cat, count in sorted(categories.items(), key=lambda x: -x[1]):
  128. pct = 100 * count / total
  129. bar = "#" * int(pct / 2)
  130. print(f"{cat:20} {count:>6} ({pct:5.1f}%) {bar}")
  131. print()
  132. print("OUTPUT FORMAT COVERAGE")
  133. print("-" * 40)
  134. print(f"Has lex: {stats.has_lex:>6} ({100 * stats.has_lex / total:5.1f}%)")
  135. print(f"Has vec: {stats.has_vec:>6} ({100 * stats.has_vec / total:5.1f}%)")
  136. print(f"Has hyde: {stats.has_hyde:>6} ({100 * stats.has_hyde / total:5.1f}%)")
  137. print(f"Long hyde (>200ch): {stats.long_hyde_count:>6}")
  138. print()
  139. print("EVALUATION ALIGNMENT")
  140. print("-" * 40)
  141. print(f"Named entity queries: {stats.named_entity_queries:>6} ({100 * stats.named_entity_queries / total:5.1f}%)")
  142. print(f"Temporal/recency: {stats.temporal_queries:>6} ({100 * stats.temporal_queries / total:5.1f}%)")
  143. print(f"Short keyword queries: {stats.short_keyword_queries:>6} ({100 * stats.short_keyword_queries / total:5.1f}%)")
  144. print()
  145. print("RECOMMENDATIONS")
  146. print("-" * 40)
  147. recommendations = []
  148. if stats.short_queries / total < 0.15:
  149. recommendations.append("Short queries below 15% - add more 1-2 word keyword queries")
  150. if stats.named_entity_queries / total < 0.10:
  151. recommendations.append("Named entity queries below 10% - add more capitalized tech term queries")
  152. if stats.temporal_queries / total < 0.05:
  153. recommendations.append("Temporal queries below 5% - add more 'latest', 'recent' queries")
  154. if stats.long_hyde_count > 50:
  155. recommendations.append(f"{stats.long_hyde_count} long hyde sections - consider truncating")
  156. if stats.duplicate_queries > 0:
  157. recommendations.append(f"{stats.duplicate_queries} duplicate queries - consider deduplication")
  158. if not recommendations:
  159. print("Dataset looks good! No major issues detected.")
  160. else:
  161. for rec in recommendations:
  162. print(f" - {rec}")
  163. print()
  164. print("=" * 70)
  165. def main():
  166. parser = argparse.ArgumentParser(description="Analyze QMD training dataset")
  167. parser.add_argument(
  168. "--input",
  169. type=str,
  170. default="data/qmd_expansion_v3_structured.jsonl",
  171. help="Path to training data JSONL file",
  172. )
  173. parser.add_argument(
  174. "--show-examples",
  175. type=int,
  176. default=3,
  177. help="Number of example queries to show per category",
  178. )
  179. args = parser.parse_args()
  180. input_path = Path(args.input)
  181. if not input_path.exists():
  182. script_dir = Path(__file__).parent.parent
  183. input_path = script_dir / args.input
  184. if not input_path.exists():
  185. print(f"Error: Could not find dataset at {input_path}")
  186. return 1
  187. print(f"Analyzing: {input_path}")
  188. print()
  189. examples = load_examples(input_path)
  190. stats, categories, category_examples = analyze_examples(examples)
  191. print_report(stats, categories, category_examples)
  192. if args.show_examples > 0:
  193. print("SAMPLE QUERIES BY CATEGORY")
  194. print("-" * 40)
  195. for cat in sorted(categories.keys()):
  196. exs = category_examples.get(cat, [])
  197. if exs:
  198. print(f"\n{cat.upper()}:")
  199. for ex in exs[:args.show_examples]:
  200. print(f" - {ex}")
  201. print()
  202. return 0
  203. if __name__ == "__main__":
  204. exit(main())