| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369 |
- #!/usr/bin/env python3
- """
- Dataset Analysis and Quality Report Generator
- Analyzes the training data for:
- 1. Query length distribution
- 2. Category diversity
- 3. Named entity coverage
- 4. Temporal query coverage
- 5. Short query coverage (important for ambiguous queries)
- 6. Duplicate detection
- 7. Quality issues (long hyde, missing fields, etc.)
- """
- import json
- import re
- import sys
- from pathlib import Path
- from collections import Counter, defaultdict
- from dataclasses import dataclass
- sys.path.insert(0, str(Path(__file__).parent.parent))
- from dataset.schema import normalize_output_items, parse_output_text
- @dataclass
- class DatasetStats:
- total_examples: int = 0
- short_queries: int = 0 # 1-2 words
- medium_queries: int = 0 # 3-5 words
- long_queries: int = 0 # 6+ words
- has_lex: int = 0
- has_vec: int = 0
- has_hyde: int = 0
- long_hyde_count: int = 0
- duplicate_queries: int = 0
- named_entity_queries: int = 0
- temporal_queries: int = 0
- short_keyword_queries: int = 0
- def categorize_query(query: str) -> str:
- """Categorize a query by type."""
- query_lower = query.lower()
- words = query_lower.split()
- word_count = len(words)
- # Short keyword queries
- if word_count <= 2:
- return "short_keyword"
- # Named entity queries (capitalized words or tech terms)
- if any(w[0].isupper() for w in words if w):
- return "named_entity"
- # Temporal/recency queries
- temporal_keywords = [
- "latest",
- "recent",
- "new",
- "update",
- "changelog",
- "changed",
- "version",
- "release",
- "news",
- "2024",
- "2025",
- ]
- if any(kw in query_lower for kw in temporal_keywords):
- return "temporal"
- # How-to queries
- if query_lower.startswith("how "):
- return "how_to"
- # What is queries
- if query_lower.startswith("what "):
- return "what_is"
- # Difference/comparison queries
- if any(kw in query_lower for kw in ["difference", "vs", "versus", "compare"]):
- return "comparison"
- # Personal/journal style
- if any(
- kw in query_lower for kw in ["meeting", "notes", "journal", "ideas", "thoughts"]
- ):
- return "personal"
- return "other"
- def extract_named_entities(query: str) -> list:
- """Extract potential named entities from query."""
- entities = []
- words = query.split()
- for word in words:
- # Skip stopwords
- if word.lower() in {
- "the",
- "a",
- "an",
- "is",
- "are",
- "to",
- "for",
- "of",
- "in",
- "and",
- "or",
- }:
- continue
- # Capitalized words (potential named entities)
- if word and word[0].isupper() and len(word) > 1:
- entities.append(word)
- # Technology terms with version numbers or special chars
- if any(c in word for c in ".+-0123456789") and len(word) > 1:
- entities.append(word)
- return entities
- def analyze_dataset(filepath: Path) -> tuple[DatasetStats, dict, dict]:
- """Analyze the dataset and return statistics."""
- stats = DatasetStats()
- categories = Counter()
- seen_queries = set()
- duplicate_count = 0
- category_examples = defaultdict(list)
- with open(filepath, "r", encoding="utf-8") as f:
- for line_num, line in enumerate(f, 1):
- line = line.strip()
- if not line:
- continue
- try:
- example = json.loads(line)
- query = example.get("query", "") or example.get("input", "")
- output = example.get("output", [])
- if isinstance(output, str):
- output = parse_output_text(output)
- output = normalize_output_items(output)
- stats.total_examples += 1
- # Check for duplicates
- query_lower = query.lower()
- if query_lower in seen_queries:
- duplicate_count += 1
- else:
- seen_queries.add(query_lower)
- # Query length categorization
- word_count = len(query.split())
- if word_count <= 2:
- stats.short_queries += 1
- elif word_count <= 5:
- stats.medium_queries += 1
- else:
- stats.long_queries += 1
- # Category detection
- category = categorize_query(query)
- categories[category] += 1
- category_examples[category].append(query)
- # Named entity detection
- if extract_named_entities(query):
- stats.named_entity_queries += 1
- # Output analysis
- has_lex = any(o[0] == "lex" for o in output)
- has_vec = any(o[0] == "vec" for o in output)
- has_hyde = any(o[0] == "hyde" for o in output)
- if has_lex:
- stats.has_lex += 1
- if has_vec:
- stats.has_vec += 1
- if has_hyde:
- stats.has_hyde += 1
- # Check hyde length
- for kind, text in output:
- if kind == "hyde" and len(text) > 200:
- stats.long_hyde_count += 1
- except json.JSONDecodeError:
- print(f"Warning: Could not parse line {line_num}")
- stats.duplicate_queries = duplicate_count
- stats.temporal_queries = categories.get("temporal", 0)
- stats.short_keyword_queries = categories.get("short_keyword", 0)
- return stats, dict(categories), dict(category_examples)
- def print_report(stats: DatasetStats, categories: dict, category_examples: dict):
- """Print a comprehensive analysis report."""
- print("=" * 70)
- print("QMD TRAINING DATA ANALYSIS REPORT")
- print("=" * 70)
- print()
- # Basic statistics
- print("📊 BASIC STATISTICS")
- print("-" * 40)
- print(f"Total examples: {stats.total_examples:>6}")
- print(f"Duplicates found: {stats.duplicate_queries:>6}")
- print()
- # Query length distribution
- print("📝 QUERY LENGTH DISTRIBUTION")
- print("-" * 40)
- total = stats.total_examples
- print(
- f"Short (1-2 words): {stats.short_queries:>6} ({100 * stats.short_queries / total:5.1f}%)"
- )
- print(
- f"Medium (3-5 words): {stats.medium_queries:>6} ({100 * stats.medium_queries / total:5.1f}%)"
- )
- print(
- f"Long (6+ words): {stats.long_queries:>6} ({100 * stats.long_queries / total:5.1f}%)"
- )
- print()
- # Category distribution
- print("🏷️ CATEGORY DISTRIBUTION")
- print("-" * 40)
- for cat, count in sorted(categories.items(), key=lambda x: -x[1]):
- pct = 100 * count / total
- bar = "█" * int(pct / 2)
- print(f"{cat:20} {count:>6} ({pct:5.1f}%) {bar}")
- print()
- # Output format coverage
- print("✅ OUTPUT FORMAT COVERAGE")
- print("-" * 40)
- print(
- f"Has lex: {stats.has_lex:>6} ({100 * stats.has_lex / total:5.1f}%)"
- )
- print(
- f"Has vec: {stats.has_vec:>6} ({100 * stats.has_vec / total:5.1f}%)"
- )
- print(
- f"Has hyde: {stats.has_hyde:>6} ({100 * stats.has_hyde / total:5.1f}%)"
- )
- print(f"Long hyde (>200ch): {stats.long_hyde_count:>6}")
- print()
- # Critical metrics for evals
- print("🎯 EVALUATION ALIGNMENT")
- print("-" * 40)
- print(
- f"Named entity queries: {stats.named_entity_queries:>6} ({100 * stats.named_entity_queries / total:5.1f}%)"
- )
- print(
- f"Temporal/recency: {stats.temporal_queries:>6} ({100 * stats.temporal_queries / total:5.1f}%)"
- )
- print(
- f"Short keyword queries: {stats.short_keyword_queries:>6} ({100 * stats.short_keyword_queries / total:5.1f}%)"
- )
- print()
- # Recommendations
- print("💡 RECOMMENDATIONS")
- print("-" * 40)
- recommendations = []
- if stats.short_queries / total < 0.15:
- recommendations.append(
- "⚠️ Short queries below 15% - add more 1-2 word keyword queries"
- )
- if stats.named_entity_queries / total < 0.10:
- recommendations.append(
- "⚠️ Named entity queries below 10% - add more capitalized tech term queries"
- )
- if stats.temporal_queries / total < 0.05:
- recommendations.append(
- "⚠️ Temporal queries below 5% - add more 'latest', 'recent' queries"
- )
- if stats.long_hyde_count > 50:
- recommendations.append(
- f"⚠️ {stats.long_hyde_count} long hyde sections - consider truncating"
- )
- if stats.duplicate_queries > 0:
- recommendations.append(
- f"⚠️ {stats.duplicate_queries} duplicate queries - consider deduplication"
- )
- if categories.get("short_keyword", 0) < 100:
- recommendations.append(
- "⚠️ Need more short keyword examples for ambiguous query training"
- )
- if not recommendations:
- print("✅ Dataset looks good! No major issues detected.")
- else:
- for rec in recommendations:
- print(rec)
- print()
- print("=" * 70)
- def main():
- """Main entry point."""
- import argparse
- parser = argparse.ArgumentParser(description="Analyze QMD training dataset")
- parser.add_argument(
- "--input",
- type=str,
- default="data/qmd_expansion_v2.jsonl",
- help="Path to training data JSONL file",
- )
- parser.add_argument(
- "--show-examples",
- type=int,
- default=3,
- help="Number of example queries to show per category",
- )
- args = parser.parse_args()
- input_path = Path(args.input)
- if not input_path.exists():
- # Try relative to script directory
- script_dir = Path(__file__).parent.parent
- input_path = script_dir / args.input
- if not input_path.exists():
- print(f"Error: Could not find dataset at {input_path}")
- print("Please run from finetune directory or specify correct path")
- return 1
- print(f"Analyzing: {input_path}")
- print()
- stats, categories, category_examples = analyze_dataset(input_path)
- print_report(stats, categories, category_examples)
- # Show examples if requested
- if args.show_examples > 0:
- print("📋 SAMPLE QUERIES BY CATEGORY")
- print("-" * 40)
- for cat in sorted(categories.keys()):
- examples = category_examples.get(cat, [])
- if examples:
- print(f"\n{cat.upper()}:")
- for ex in examples[: args.show_examples]:
- print(f" • {ex}")
- print()
- return 0
- if __name__ == "__main__":
- exit(main())
|