generate_only_variants.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = []
  4. # ///
  5. """
  6. Generate 'only:' variant training data from high-quality expansions.
  7. Takes existing training data and creates derivative examples where the query
  8. ends with ' /only:lex', ' /only:vec', or ' /only:hyde', and the output contains
  9. ONLY that component type.
  10. Usage:
  11. uv run generate_only_variants.py data/qmd_expansion_handcrafted.jsonl
  12. uv run generate_only_variants.py data/qmd_expansion_handcrafted.jsonl -o data/qmd_only_variants.jsonl
  13. uv run generate_only_variants.py data/*.jsonl --combine # combine all inputs
  14. """
  15. import argparse
  16. import json
  17. import sys
  18. from pathlib import Path
  19. def parse_expansion(text: str) -> dict:
  20. """Parse a multi-line expansion into {lex, vec, hyde} lists."""
  21. result = {"lex": [], "vec": [], "hyde": []}
  22. for line in text.strip().split("\n"):
  23. line = line.strip()
  24. if not line:
  25. continue
  26. if line.startswith("lex:"):
  27. result["lex"].append(line[4:].strip())
  28. elif line.startswith("vec:"):
  29. result["vec"].append(line[4:].strip())
  30. elif line.startswith("hyde:"):
  31. result["hyde"].append(line[5:].strip())
  32. return result
  33. # Templated patterns to filter out from hyde (low quality)
  34. TEMPLATED_PATTERNS = [
  35. "This comprehensive guide covers",
  36. "This comprehensive guide to",
  37. "requires practice and patience",
  38. "This resource provides",
  39. "Follow the steps carefully",
  40. "covers all the essential information",
  41. "includes practical examples, best practices",
  42. ]
  43. def is_templated_hyde(hyde_text: str) -> bool:
  44. """Check if a hyde output is a low-quality templated response."""
  45. return any(pattern in hyde_text for pattern in TEMPLATED_PATTERNS)
  46. def format_output(parsed: dict, only_type: str) -> str | None:
  47. """Format output for a single type. Returns None if type is empty or low quality."""
  48. items = parsed.get(only_type, [])
  49. if not items:
  50. return None
  51. # Filter out templated hyde outputs
  52. if only_type == "hyde":
  53. filtered = [item for item in items if not is_templated_hyde(item)]
  54. if not filtered:
  55. return None
  56. items = filtered
  57. lines = []
  58. for item in items:
  59. lines.append(f"{only_type}: {item}")
  60. return "\n".join(lines)
  61. def generate_only_variants(input_query: str, output: str) -> list[dict]:
  62. """Generate all valid 'only:' variants from a single example."""
  63. variants = []
  64. parsed = parse_expansion(output)
  65. for only_type in ["lex", "vec", "hyde"]:
  66. formatted = format_output(parsed, only_type)
  67. if formatted:
  68. # Add the '/only:' suffix to the query (slash prefix)
  69. new_query = f"{input_query} /only:{only_type}"
  70. variants.append({
  71. "input": new_query,
  72. "output": formatted,
  73. "_source_type": only_type,
  74. "_source_query": input_query,
  75. })
  76. return variants
  77. def process_file(input_path: Path) -> list[dict]:
  78. """Process a single JSONL file and return all 'only:' variants."""
  79. variants = []
  80. seen_queries = set()
  81. with open(input_path) as f:
  82. for line_num, line in enumerate(f, 1):
  83. line = line.strip()
  84. if not line:
  85. continue
  86. try:
  87. data = json.loads(line)
  88. except json.JSONDecodeError as e:
  89. print(f" Warning: Skipping line {line_num} (invalid JSON): {e}", file=sys.stderr)
  90. continue
  91. # Skip metadata lines
  92. if data.get("_meta"):
  93. continue
  94. input_query = data.get("input", "")
  95. output = data.get("output", "")
  96. if not input_query or not output:
  97. continue
  98. # Skip if query already has '/only:' suffix
  99. if " /only:" in input_query.lower():
  100. continue
  101. # Skip duplicates
  102. if input_query in seen_queries:
  103. continue
  104. seen_queries.add(input_query)
  105. # Generate variants
  106. for variant in generate_only_variants(input_query, output):
  107. variants.append(variant)
  108. return variants
  109. def main():
  110. parser = argparse.ArgumentParser(
  111. description="Generate 'only:' variant training data from high-quality expansions",
  112. formatter_class=argparse.RawDescriptionHelpFormatter,
  113. )
  114. parser.add_argument(
  115. "input_files",
  116. nargs="+",
  117. help="Input JSONL files with training data",
  118. )
  119. parser.add_argument(
  120. "-o", "--output",
  121. default="data/qmd_only_variants.jsonl",
  122. help="Output JSONL file (default: data/qmd_only_variants.jsonl)",
  123. )
  124. parser.add_argument(
  125. "--combine",
  126. action="store_true",
  127. help="Combine all input files into one output",
  128. )
  129. parser.add_argument(
  130. "--stats",
  131. action="store_true",
  132. help="Print statistics about generated variants",
  133. )
  134. args = parser.parse_args()
  135. all_variants = []
  136. stats = {"lex": 0, "vec": 0, "hyde": 0}
  137. for input_file in args.input_files:
  138. input_path = Path(input_file)
  139. if not input_path.exists():
  140. print(f"Warning: {input_file} not found, skipping", file=sys.stderr)
  141. continue
  142. print(f"Processing {input_path.name}...", file=sys.stderr)
  143. variants = process_file(input_path)
  144. for v in variants:
  145. stats[v["_source_type"]] += 1
  146. if args.combine:
  147. all_variants.extend(variants)
  148. else:
  149. # Write to separate output files per input
  150. output_path = input_path.parent / f"{input_path.stem}_only.jsonl"
  151. with open(output_path, "w") as f:
  152. for variant in variants:
  153. # Remove internal fields before writing
  154. clean = {"input": variant["input"], "output": variant["output"]}
  155. f.write(json.dumps(clean) + "\n")
  156. print(f" -> {len(variants)} variants written to {output_path}", file=sys.stderr)
  157. if args.combine and all_variants:
  158. output_path = Path(args.output)
  159. output_path.parent.mkdir(parents=True, exist_ok=True)
  160. with open(output_path, "w") as f:
  161. for variant in all_variants:
  162. clean = {"input": variant["input"], "output": variant["output"]}
  163. f.write(json.dumps(clean) + "\n")
  164. print(f"\nTotal: {len(all_variants)} variants written to {output_path}", file=sys.stderr)
  165. if args.stats or args.combine:
  166. print(f"\nStats:", file=sys.stderr)
  167. print(f" lex: {stats['lex']}", file=sys.stderr)
  168. print(f" vec: {stats['vec']}", file=sys.stderr)
  169. print(f" hyde: {stats['hyde']}", file=sys.stderr)
  170. print(f" total: {sum(stats.values())}", file=sys.stderr)
  171. if __name__ == "__main__":
  172. main()