| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
- #!/usr/bin/env python3
- # /// script
- # requires-python = ">=3.10"
- # dependencies = ["pydantic>=2.0"]
- # ///
- """
- Dataset Analysis and Quality Report Generator
- Analyzes training data loaded through the strict Pydantic schema for:
- 1. Query length distribution
- 2. Category diversity
- 3. Named entity coverage
- 4. Output format coverage
- 5. Duplicate detection
- """
- import argparse
- import sys
- from pathlib import Path
- from collections import Counter, defaultdict
- from dataclasses import dataclass
- sys.path.insert(0, str(Path(__file__).parent.parent))
- from dataset.schema import TrainingExample, OutputType, load_examples
@dataclass
class DatasetStats:
    """Aggregate counters describing one pass over a training dataset.

    Every counter defaults to zero. Field order is part of the positional
    constructor interface and must not be rearranged.
    """

    # Number of examples scanned.
    total_examples: int = 0

    # Query-length buckets, by whitespace-separated word count.
    short_queries: int = 0   # at most 2 words
    medium_queries: int = 0  # 3 to 5 words
    long_queries: int = 0    # 6 or more words

    # Examples containing at least one output of the given type.
    has_lex: int = 0
    has_vec: int = 0
    has_hyde: int = 0

    # Individual hyde outputs whose text is longer than 200 characters.
    long_hyde_count: int = 0

    # Queries whose lowercased text had already been seen earlier.
    duplicate_queries: int = 0

    # Queries for which at least one named entity was extracted.
    named_entity_queries: int = 0

    # Queries categorized as "temporal" / "short_keyword" respectively.
    temporal_queries: int = 0
    short_keyword_queries: int = 0
def _has_keyword(tokens: list[str], keywords: tuple[str, ...]) -> bool:
    """Return True if any token begins with one of *keywords*."""
    return any(token.startswith(keywords) for token in tokens)


def categorize_query(query: str) -> str:
    """Assign a query to one coarse category using ordered heuristics.

    The rules are evaluated in this order and the first match wins:

    1. ``short_keyword`` - two words or fewer.
    2. ``named_entity``  - any word starts with an uppercase letter.
    3. ``temporal``      - contains a recency keyword ("latest", "new", ...).
    4. ``how_to``        - starts with "how ".
    5. ``what_is``       - starts with "what ".
    6. ``comparison``    - contains "difference", "vs", "versus", "compare".
    7. ``personal``      - contains "meeting", "notes", "journal", ...
    8. ``other``         - nothing matched.

    Keywords are matched as *prefixes of whole tokens*, where tokens are the
    maximal alphanumeric runs of the lowercased query. This keeps inflected
    forms matching ("updated", "recently", "meetings") while avoiding false
    hits inside unrelated words ("renew" is not "new", "devs" is not "vs").

    Args:
        query: The raw query text.

    Returns:
        The category name as a string.
    """
    temporal_keywords = (
        "latest", "recent", "new", "update", "changelog",
        "changed", "version", "release", "news", "2024", "2025",
    )
    comparison_keywords = ("difference", "vs", "versus", "compare")
    personal_keywords = ("meeting", "notes", "journal", "ideas", "thoughts")

    raw_words = query.split()
    if len(raw_words) <= 2:
        return "short_keyword"

    # Capitalisation must be checked on the original, not the lowercased text.
    if any(word[0].isupper() for word in raw_words):
        return "named_entity"

    query_lower = query.lower()
    # Split on anything non-alphanumeric so punctuation and hyphens
    # ("pre-release", "2024-05") do not hide a keyword inside a token.
    tokens = "".join(ch if ch.isalnum() else " " for ch in query_lower).split()

    if _has_keyword(tokens, temporal_keywords):
        return "temporal"
    if query_lower.startswith("how "):
        return "how_to"
    if query_lower.startswith("what "):
        return "what_is"
    if _has_keyword(tokens, comparison_keywords):
        return "comparison"
    if _has_keyword(tokens, personal_keywords):
        return "personal"
    return "other"
def extract_named_entities(query: str) -> list[str]:
    """Return the words of *query* that look like named entities or tech terms.

    A whitespace-separated word qualifies when it is longer than one
    character, is not a stopword (case-insensitive), and either starts with
    an uppercase letter or contains a digit or one of ``.``, ``+``, ``-``
    (e.g. "Node.js", "c++", "python3").

    Each qualifying word is reported once per occurrence, in input order.
    A word that satisfies both criteria (such as "Node.js") is not listed
    twice.

    Args:
        query: The raw query text.

    Returns:
        The matching words, unmodified, in the order they appear.
    """
    stopwords = {"the", "a", "an", "is", "are", "to", "for", "of", "in", "and", "or"}
    technical_chars = ".+-0123456789"

    entities: list[str] = []
    for word in query.split():
        if len(word) <= 1 or word.lower() in stopwords:
            continue
        # One combined test so a word matching both rules is appended once.
        if word[0].isupper() or any(ch in technical_chars for ch in word):
            entities.append(word)
    return entities
def analyze_examples(examples: list[TrainingExample]) -> tuple[DatasetStats, dict, dict]:
    """Scan the examples once and collect dataset statistics.

    Args:
        examples: Examples to analyze. Each is expected to expose ``query``
            (a string) and ``output`` (an iterable of items with ``type``
            and ``text`` attributes).

    Returns:
        A 3-tuple ``(stats, categories, category_examples)`` where
        ``categories`` maps category name to query count and
        ``category_examples`` maps category name to the list of queries
        assigned to it, in input order.
    """
    stats = DatasetStats()
    category_counts: Counter = Counter()
    queries_by_category: dict[str, list[str]] = defaultdict(list)
    seen_lowercase: set[str] = set()

    for example in examples:
        query = example.query
        stats.total_examples += 1

        # Duplicates are detected case-insensitively; the first occurrence
        # is not counted, every later repeat is.
        normalized = query.lower()
        if normalized in seen_lowercase:
            stats.duplicate_queries += 1
        seen_lowercase.add(normalized)

        # Length bucket by whitespace-separated word count.
        n_words = len(query.split())
        if n_words > 5:
            stats.long_queries += 1
        elif n_words > 2:
            stats.medium_queries += 1
        else:
            stats.short_queries += 1

        label = categorize_query(query)
        category_counts[label] += 1
        queries_by_category[label].append(query)

        if extract_named_entities(query):
            stats.named_entity_queries += 1

        # Per-example coverage: each output type counts once per example.
        output_types = {pair.type for pair in example.output}
        if OutputType.lex in output_types:
            stats.has_lex += 1
        if OutputType.vec in output_types:
            stats.has_vec += 1
        if OutputType.hyde in output_types:
            stats.has_hyde += 1

        # Long hyde passages are counted per output, not per example.
        stats.long_hyde_count += sum(
            1
            for pair in example.output
            if pair.type == OutputType.hyde and len(pair.text) > 200
        )

    # Counter lookups yield 0 for categories that never occurred.
    stats.temporal_queries = category_counts["temporal"]
    stats.short_keyword_queries = category_counts["short_keyword"]

    return stats, dict(category_counts), dict(queries_by_category)
def print_report(stats: DatasetStats, categories: dict, category_examples: dict):
    """Print the dataset quality report to stdout.

    Args:
        stats: Aggregate counters for the dataset.
        categories: Mapping of category name to the number of queries in it.
        category_examples: Accepted for interface compatibility; this
            function does not read it.

    An empty dataset (``stats.total_examples == 0``) is reported with 0.0%
    shares and a single "dataset is empty" recommendation instead of
    raising ZeroDivisionError.
    """
    total = stats.total_examples
    rule = "-" * 40
    banner = "=" * 70

    def pct(count: int) -> float:
        # Percentage of the dataset; 0.0 when there is nothing to divide by.
        return 100 * count / total if total else 0.0

    def share(count: int) -> float:
        # Fraction of the dataset; 0.0 when there is nothing to divide by.
        return count / total if total else 0.0

    print(banner)
    print("QMD TRAINING DATA ANALYSIS REPORT")
    print(banner)
    print()

    print("BASIC STATISTICS")
    print(rule)
    print(f"Total examples: {total:>6}")
    print(f"Duplicates found: {stats.duplicate_queries:>6}")
    print()

    print("QUERY LENGTH DISTRIBUTION")
    print(rule)
    print(f"Short (1-2 words): {stats.short_queries:>6} ({pct(stats.short_queries):5.1f}%)")
    print(f"Medium (3-5 words): {stats.medium_queries:>6} ({pct(stats.medium_queries):5.1f}%)")
    print(f"Long (6+ words): {stats.long_queries:>6} ({pct(stats.long_queries):5.1f}%)")
    print()

    print("CATEGORY DISTRIBUTION")
    print(rule)
    # Most frequent category first; each '#' in the bar stands for 2%.
    for cat, count in sorted(categories.items(), key=lambda item: -item[1]):
        cat_pct = pct(count)
        bar = "#" * int(cat_pct / 2)
        print(f"{cat:20} {count:>6} ({cat_pct:5.1f}%) {bar}")
    print()

    print("OUTPUT FORMAT COVERAGE")
    print(rule)
    print(f"Has lex: {stats.has_lex:>6} ({pct(stats.has_lex):5.1f}%)")
    print(f"Has vec: {stats.has_vec:>6} ({pct(stats.has_vec):5.1f}%)")
    print(f"Has hyde: {stats.has_hyde:>6} ({pct(stats.has_hyde):5.1f}%)")
    print(f"Long hyde (>200ch): {stats.long_hyde_count:>6}")
    print()

    print("EVALUATION ALIGNMENT")
    print(rule)
    print(f"Named entity queries: {stats.named_entity_queries:>6} ({pct(stats.named_entity_queries):5.1f}%)")
    print(f"Temporal/recency: {stats.temporal_queries:>6} ({pct(stats.temporal_queries):5.1f}%)")
    print(f"Short keyword queries: {stats.short_keyword_queries:>6} ({pct(stats.short_keyword_queries):5.1f}%)")
    print()

    print("RECOMMENDATIONS")
    print(rule)
    recommendations = []
    if total == 0:
        # Ratio-based advice is meaningless without data; say so explicitly.
        recommendations.append("Dataset is empty - no examples to analyze")
    else:
        if share(stats.short_queries) < 0.15:
            recommendations.append("Short queries below 15% - add more 1-2 word keyword queries")
        if share(stats.named_entity_queries) < 0.10:
            recommendations.append("Named entity queries below 10% - add more capitalized tech term queries")
        if share(stats.temporal_queries) < 0.05:
            recommendations.append("Temporal queries below 5% - add more 'latest', 'recent' queries")
    if stats.long_hyde_count > 50:
        recommendations.append(f"{stats.long_hyde_count} long hyde sections - consider truncating")
    if stats.duplicate_queries > 0:
        recommendations.append(f"{stats.duplicate_queries} duplicate queries - consider deduplication")

    if not recommendations:
        print("Dataset looks good! No major issues detected.")
    else:
        for rec in recommendations:
            print(f" - {rec}")
    print()
    print(banner)
def main() -> int:
    """Command-line entry point: load a dataset, analyze it, print the report.

    The ``--input`` path is tried as given first; if it does not exist it is
    retried relative to the directory two levels above this script file.

    Returns:
        Process exit status: 0 on success, 1 if the dataset file could not
        be found.
    """
    parser = argparse.ArgumentParser(description="Analyze QMD training dataset")
    parser.add_argument(
        "--input",
        type=str,
        default="data/qmd_expansion_v3_structured.jsonl",
        help="Path to training data JSONL file",
    )
    parser.add_argument(
        "--show-examples",
        type=int,
        default=3,
        help="Number of example queries to show per category",
    )
    args = parser.parse_args()

    input_path = Path(args.input)
    if not input_path.exists():
        # Fall back to resolving the path against the parent of the
        # directory that contains this script.
        input_path = Path(__file__).parent.parent / args.input
        if not input_path.exists():
            # Diagnostics go to stderr so they never mix with report output.
            print(f"Error: Could not find dataset at {input_path}", file=sys.stderr)
            return 1

    print(f"Analyzing: {input_path}")
    print()

    examples = load_examples(input_path)
    stats, categories, category_examples = analyze_examples(examples)
    print_report(stats, categories, category_examples)

    if args.show_examples > 0:
        print("SAMPLE QUERIES BY CATEGORY")
        print("-" * 40)
        for cat in sorted(categories):
            samples = category_examples.get(cat, [])
            if samples:
                print(f"\n{cat.upper()}:")
                for sample in samples[:args.show_examples]:
                    print(f" - {sample}")
        print()

    return 0
if __name__ == "__main__":
    # Use sys.exit rather than the bare exit() helper: exit() is installed by
    # the site module for interactive use and is missing under `python -S`
    # or in frozen/embedded interpreters.
    sys.exit(main())
|