prepare_data.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "transformers>=4.45.0",
  6. # "jinja2",
  7. # ]
  8. # ///
  9. """Prepare QMD query expansion data for training.
  10. See PROMPT_FORMAT.md for format specification.
  11. """
  12. import argparse
  13. import json
  14. import random
  15. from pathlib import Path
  16. from transformers import AutoTokenizer
  17. _tokenizer = None
  18. def get_tokenizer():
  19. global _tokenizer
  20. if _tokenizer is None:
  21. _tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
  22. return _tokenizer
  23. # Short single-word queries that need proper expansion examples
  24. SHORT_QUERIES = [
  25. # Technical keywords
  26. "auth", "config", "setup", "api", "cache", "log", "test", "debug",
  27. "deploy", "build", "lint", "format", "migrate", "backup", "restore",
  28. "docker", "git", "npm", "pip", "brew", "curl", "ssh", "ssl", "tls",
  29. "cors", "csrf", "jwt", "oauth", "saml", "ldap", "rbac", "acl",
  30. "crud", "rest", "graphql", "grpc", "websocket", "sse", "http",
  31. "redis", "mongo", "postgres", "mysql", "sqlite", "elastic", "kafka",
  32. "nginx", "apache", "caddy", "traefik", "haproxy", "envoy",
  33. "react", "vue", "angular", "svelte", "solid", "htmx", "alpine",
  34. "node", "deno", "bun", "python", "rust", "golang", "java", "kotlin",
  35. "webpack", "vite", "esbuild", "rollup", "parcel", "turbopack",
  36. "jest", "vitest", "pytest", "mocha", "cypress", "playwright",
  37. # Common short phrases
  38. "env vars", "api keys", "error handling", "rate limiting",
  39. "file upload", "user auth", "db connection", "query params",
  40. "hot reload", "code split", "tree shake", "lazy load",
  41. ]
  42. # Templates for generating short query expansions
  43. # IMPORTANT: All lex lines MUST include {q} to preserve key terms
  44. SHORT_TEMPLATES = [
  45. {
  46. "lex": ["{q} configuration", "{q} settings", "{q} setup"],
  47. "vec": ["how to configure {q} in my project", "{q} setup and configuration tutorial"],
  48. "hyde": "To set up {q}, first install the required dependencies. Then configure the settings in your project configuration file.",
  49. },
  50. {
  51. "lex": ["{q} tutorial", "{q} guide", "{q} basics"],
  52. "vec": ["beginner guide to {q}", "how to get started with {q}"],
  53. "hyde": "This guide covers the basics of {q}. Follow the steps below to get started with your first implementation.",
  54. },
  55. {
  56. "lex": ["{q} best practices", "{q} patterns", "{q} tips"],
  57. "vec": ["best practices for using {q}", "recommended patterns for {q}"],
  58. "hyde": "When working with {q}, follow these best practices: use consistent naming, handle errors properly, and document your code.",
  59. },
  60. {
  61. "lex": ["{q} troubleshooting", "{q} fix", "{q} errors"],
  62. "vec": ["how to fix {q} errors", "troubleshooting common {q} problems"],
  63. "hyde": "If you encounter {q} issues, check your configuration first. Common problems include missing dependencies and incorrect settings.",
  64. },
  65. {
  66. "lex": ["{q} examples", "{q} code", "{q} usage"],
  67. "vec": ["code examples for {q}", "practical {q} implementation examples"],
  68. "hyde": "Here are some practical examples of {q} in action. Each example demonstrates a common use case with working code.",
  69. },
  70. ]
  71. def truncate_hyde(hyde_text: str, max_len: int = 150) -> str:
  72. """Truncate hyde to max length, ending at sentence boundary."""
  73. if len(hyde_text) <= max_len:
  74. return hyde_text
  75. truncated = hyde_text[:max_len]
  76. last_period = truncated.rfind(". ")
  77. if last_period > max_len // 2:
  78. return truncated[:last_period + 1]
  79. last_space = truncated.rfind(" ")
  80. if last_space > max_len // 2:
  81. return truncated[:last_space] + "."
  82. return truncated[:max_len-1] + "."
  83. def clean_output(output: str) -> str:
  84. """Clean output: truncate hyde, remove invalid lines."""
  85. lines = []
  86. for line in output.strip().split("\n"):
  87. line = line.strip()
  88. if not line:
  89. continue
  90. if line.startswith("hyde:"):
  91. hyde_text = line[5:].strip()
  92. hyde_text = truncate_hyde(hyde_text)
  93. lines.append(f"hyde: {hyde_text}")
  94. elif line.startswith(("lex:", "vec:")):
  95. lines.append(line)
  96. return "\n".join(lines)
  97. def generate_short_example(query: str) -> dict:
  98. """Generate a training example for a short query."""
  99. template = random.choice(SHORT_TEMPLATES)
  100. lex_lines = random.sample(template["lex"], 2)
  101. vec_lines = random.sample(template["vec"], 2)
  102. hyde_line = template["hyde"]
  103. output_lines = []
  104. output_lines.append(f"hyde: {hyde_line.format(q=query)}")
  105. for lex in lex_lines:
  106. output_lines.append(f"lex: {lex.format(q=query)}")
  107. for vec in vec_lines:
  108. output_lines.append(f"vec: {vec.format(q=query)}")
  109. return {"input": query, "output": "\n".join(output_lines)}
  110. def format_for_training(input_text: str, output_text: str) -> dict:
  111. """Format a single example for SFT training using Qwen chat format."""
  112. tokenizer = get_tokenizer()
  113. # Use /no_think to disable thinking mode - we want direct output
  114. messages = [
  115. {"role": "user", "content": f"/no_think Expand this search query: {input_text}"},
  116. {"role": "assistant", "content": output_text}
  117. ]
  118. # Use tokenizer to generate proper chat format with special tokens
  119. text = tokenizer.apply_chat_template(
  120. messages,
  121. tokenize=False,
  122. add_generation_prompt=False,
  123. )
  124. # Strip empty <think> tags - we don't want thinking mode
  125. # The template adds "<think>\n\n</think>\n\n" which we remove
  126. text = text.replace("<think>\n\n</think>\n\n", "")
  127. return {
  128. "text": text,
  129. "messages": messages,
  130. }
  131. def main():
  132. parser = argparse.ArgumentParser(description="Prepare data for training")
  133. parser.add_argument("--input", type=str, default="data/*.jsonl", help="Input JSONL file(s) - supports glob patterns")
  134. parser.add_argument("--output", type=str, default="data/train", help="Output directory")
  135. parser.add_argument("--split", type=float, default=0.1, help="Validation split ratio")
  136. parser.add_argument("--add-short", type=int, default=3, help="Variations per short query to add")
  137. args = parser.parse_args()
  138. output_dir = Path(args.output)
  139. output_dir.mkdir(parents=True, exist_ok=True)
  140. # Support glob patterns for input
  141. import glob
  142. if "*" in args.input:
  143. input_files = sorted(glob.glob(args.input))
  144. if not input_files:
  145. print(f"Error: No files found matching: {args.input}")
  146. exit(1)
  147. print(f"Found {len(input_files)} input files: {[Path(f).name for f in input_files]}")
  148. else:
  149. input_path = Path(args.input)
  150. if not input_path.exists():
  151. print(f"Error: Input file not found: {input_path}")
  152. exit(1)
  153. input_files = [str(input_path)]
  154. # Load and clean existing examples from all input files
  155. examples = []
  156. seen_queries = set()
  157. long_hyde_count = 0
  158. for input_file in input_files:
  159. file_count = 0
  160. with open(input_file) as f:
  161. for line in f:
  162. if line.strip():
  163. ex = json.loads(line)
  164. # Clean output (truncate hyde, remove invalid lines)
  165. original_output = ex["output"]
  166. ex["output"] = clean_output(ex["output"])
  167. # Track hyde truncation
  168. if "hyde:" in original_output:
  169. for orig_line in original_output.split("\n"):
  170. if orig_line.strip().startswith("hyde:"):
  171. if len(orig_line) > 160:
  172. long_hyde_count += 1
  173. # Validate cleaned output
  174. has_lex = "lex:" in ex["output"]
  175. has_vec = "vec:" in ex["output"]
  176. # Skip duplicates
  177. if has_lex and has_vec and ex["input"].lower() not in seen_queries:
  178. examples.append(ex)
  179. seen_queries.add(ex["input"].lower())
  180. file_count += 1
  181. print(f" {Path(input_file).name}: {file_count} examples")
  182. print(f"Loaded and cleaned {len(examples)} examples total")
  183. print(f"Truncated {long_hyde_count} long hyde sections")
  184. # Count existing short queries
  185. short_existing = sum(1 for ex in examples if len(ex["input"].split()) <= 2)
  186. print(f"Existing short queries (1-2 words): {short_existing}")
  187. # Generate additional short query examples
  188. new_short = []
  189. for query in SHORT_QUERIES:
  190. if query.lower() not in seen_queries:
  191. for _ in range(args.add_short):
  192. new_short.append(generate_short_example(query))
  193. seen_queries.add(query.lower())
  194. print(f"Generated {len(new_short)} new short query examples")
  195. # Combine and shuffle
  196. all_examples = examples + new_short
  197. random.shuffle(all_examples)
  198. # Format for training
  199. formatted = [format_for_training(ex["input"], ex["output"]) for ex in all_examples]
  200. # Split into train/val
  201. split_idx = int(len(formatted) * (1 - args.split))
  202. train_data = formatted[:split_idx]
  203. val_data = formatted[split_idx:]
  204. # Write train set
  205. train_path = output_dir / "train.jsonl"
  206. with open(train_path, "w") as f:
  207. for item in train_data:
  208. f.write(json.dumps(item) + "\n")
  209. # Write validation set
  210. val_path = output_dir / "val.jsonl"
  211. with open(val_path, "w") as f:
  212. for item in val_data:
  213. f.write(json.dumps(item) + "\n")
  214. # Write chat format (for TRL)
  215. chat_path = output_dir / "train_chat.jsonl"
  216. with open(chat_path, "w") as f:
  217. for item in train_data:
  218. f.write(json.dumps({"messages": item["messages"]}) + "\n")
  219. # Stats
  220. short_final = sum(1 for ex in all_examples if len(ex["input"].split()) <= 2)
  221. print(f"\n=== Summary ===")
  222. print(f"Total examples: {len(all_examples)}")
  223. print(f"Short queries: {short_final} ({100*short_final/len(all_examples):.1f}%)")
  224. print(f"Train: {len(train_data)}, Val: {len(val_data)}")
  225. print(f"Output: {output_dir}")
  226. # Dataset info
  227. dataset_info = {
  228. "dataset_name": "qmd-query-expansion",
  229. "train_samples": len(train_data),
  230. "val_samples": len(val_data),
  231. "short_query_pct": round(100*short_final/len(all_examples), 1),
  232. "columns": ["prompt", "completion", "text", "messages"],
  233. }
  234. with open(output_dir / "dataset_info.json", "w") as f:
  235. json.dump(dataset_info, f, indent=2)
  236. if __name__ == "__main__":
  237. main()