prepare_data.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "transformers>=4.45.0",
  6. # "jinja2",
  7. # ]
  8. # ///
  9. """Prepare QMD query expansion data for training.
  10. See PROMPT_FORMAT.md for format specification.
  11. """
  12. import argparse
  13. import json
  14. import random
  15. from pathlib import Path
  16. from transformers import AutoTokenizer
  17. _tokenizer = None
  18. def get_tokenizer():
  19. global _tokenizer
  20. if _tokenizer is None:
  21. _tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
  22. return _tokenizer
# Short queries that need proper expansion examples, organized by category.
# Most entries are a single word (some hyphenated); the final section also
# contains two- and three-word phrases. main() generates synthetic training
# examples for every entry whose lowercase form is not already present in the
# input data.
SHORT_QUERIES = [
    # === Programming Languages & Runtimes ===
    "python",
    "typescript",
    "javascript",
    "rust",
    "golang",
    "java",
    "kotlin",
    "swift",
    "ruby",
    "php",
    "cpp",
    "c",
    "elixir",
    "scala",
    "clojure",
    "dart",
    "r",
    "node",
    "deno",
    "bun",
    # === Frontend Frameworks ===
    "react",
    "vue",
    "angular",
    "svelte",
    "solid",
    "htmx",
    "alpine",
    "nextjs",
    "nuxt",
    "jquery",
    "backbone",
    "ember",
    # === Backend Frameworks ===
    "django",
    "flask",
    "fastapi",
    "express",
    "rails",
    "spring",
    "laravel",
    "gin",
    # === Databases ===
    "postgres",
    "mysql",
    "mongodb",
    "redis",
    "elasticsearch",
    "sqlite",
    "dynamodb",
    "cassandra",
    "cockroachdb",
    "neo4j",
    "couchdb",
    # === Infrastructure & DevOps ===
    "docker",
    "kubernetes",
    "terraform",
    "ansible",
    "vagrant",
    "packer",
    "jenkins",
    "gitlab-ci",
    "github-actions",
    "circleci",
    "travis",
    "argo",
    "nginx",
    "apache",
    "caddy",
    "traefik",
    "haproxy",
    "envoy",
    # === Cloud Platforms ===
    "aws",
    "gcp",
    "azure",
    "vercel",
    "netlify",
    "heroku",
    "digitalocean",
    "cloudflare",
    "flyio",
    "render",
    # === Tools & Utilities ===
    "git",
    "linux",
    "bash",
    "zsh",
    "vim",
    "tmux",
    "curl",
    "wget",
    "ssh",
    "npm",
    "pip",
    "brew",
    "apt",
    "yum",
    "cargo",
    "gem",
    "composer",
    "maven",
    # === Security & Auth ===
    "auth",
    "oauth",
    "jwt",
    "saml",
    "ldap",
    "rbac",
    "cors",
    "csrf",
    "xss",
    "ssl",
    "tls",
    "cert",
    "encrypt",
    "hash",
    "cipher",
    # === Web Technologies ===
    "rest",
    "graphql",
    "grpc",
    "websocket",
    "sse",
    "http",
    "https",
    "html",
    "css",
    "sass",
    "less",
    "styled-components",
    "tailwind",
    # === Data & ML ===
    "pandas",
    "numpy",
    "tensorflow",
    "pytorch",
    "sklearn",
    "jupyter",
    "spark",
    "kafka",
    "airflow",
    "dbt",
    "hadoop",
    "hive",
    "presto",
    # === Testing ===
    "jest",
    "vitest",
    "pytest",
    "mocha",
    "cypress",
    "playwright",
    "selenium",
    "rspec",
    "junit",
    "testng",
    # === Build Tools ===
    "webpack",
    "vite",
    "esbuild",
    "rollup",
    "parcel",
    "turbopack",
    "babel",
    # === Monitoring & Observability ===
    "prometheus",
    "grafana",
    "datadog",
    "newrelic",
    "sentry",
    "jaeger",
    "logging",
    "metrics",
    "tracing",
    "observability",
    # === API & Integration ===
    "swagger",
    "openapi",
    "postman",
    "api",
    "webhook",
    "sdk",
    "cli",
    # === Architecture Patterns ===
    "microservices",
    "serverless",
    "monolith",
    "event-driven",
    "cqrs",
    "event-sourcing",
    "saga",
    "circuit-breaker",
    "retry",
    "idempotency",
    # === Development Concepts ===
    "config",
    "setup",
    "cache",
    "log",
    "debug",
    "deploy",
    "build",
    "lint",
    "format",
    "migrate",
    "backup",
    "restore",
    "env",
    "vars",
    "secrets",
    "rate-limit",
    "load-balance",
    "scale",
    "replicate",
    "shard",
    # === General Knowledge: Trivia ===
    "trivia",
    "quiz",
    "facts",
    "did-you-know",
    "random-facts",
    "world-records",
    # === General Knowledge: Geography ===
    "countries",
    "capitals",
    "continents",
    "oceans",
    "rivers",
    "mountains",
    "deserts",
    "islands",
    "climate",
    "population",
    "maps",
    "coordinates",
    # === General Knowledge: Philosophy ===
    "ethics",
    "metaphysics",
    "epistemology",
    "logic",
    "stoicism",
    "existentialism",
    "nihilism",
    "utilitarianism",
    "deontology",
    "virtue-ethics",
    "free-will",
    # === General Knowledge: History ===
    "ancient",
    "medieval",
    "renaissance",
    "industrial",
    "world-war",
    "cold-war",
    "revolution",
    "empire",
    "civilization",
    "archaeology",
    "timeline",
    # === General Knowledge: Science ===
    "physics",
    "chemistry",
    "biology",
    "astronomy",
    "geology",
    "ecology",
    "evolution",
    "genetics",
    "quantum",
    "relativity",
    "thermodynamics",
    # === General Knowledge: Arts & Culture ===
    "art",
    "music",
    "literature",
    "film",
    "theater",
    "dance",
    "sculpture",
    "painting",
    "photography",
    "architecture",
    "poetry",
    "novel",
    # === Common Short Phrases ===
    # Mostly two-word phrases; also one three-word phrase and a few single
    # infrastructure terms at the end.
    "env vars",
    "api keys",
    "error handling",
    "rate limiting",
    "file upload",
    "user auth",
    "db connection",
    "query params",
    "hot reload",
    "code split",
    "tree shake",
    "lazy load",
    "dependency injection",
    "event listener",
    "middleware chain",
    "route handler",
    "controller logic",
    "service layer",
    "repository pattern",
    "unit of work",
    "domain model",
    "value object",
    "aggregate root",
    "event bus",
    "message queue",
    "job scheduler",
    "web server",
    "app server",
    "proxy server",
    "load balancer",
    "cdn",
    "dns",
    "ssl-cert",
    "firewall",
    "subnet",
    "vpc",
    "gateway",
]
# Templates for generating short query expansions.
#
# Each template is a dict with three keys:
#   "lex":  three candidate strings (generate_short_example samples two)
#   "vec":  two candidate strings (both are used, in random order)
#   "hyde": a single passage string
# Every string contains the placeholder {q}, which generate_short_example
# fills in with the query via str.format(q=query).
#
# IMPORTANT: All lex lines MUST include {q} to preserve key terms
SHORT_TEMPLATES = [
    # Configuration/Setup templates
    {
        "lex": ["{q} configuration", "{q} settings", "{q} setup guide"],
        "vec": [
            "how to configure {q} in my project",
            "{q} setup and configuration tutorial",
        ],
        "hyde": "To set up {q}, first install the required dependencies. Then configure the settings in your project configuration file.",
    },
    # Tutorial/Learning templates
    {
        "lex": ["{q} tutorial", "{q} guide", "{q} basics"],
        "vec": ["beginner guide to {q}", "how to get started with {q}"],
        "hyde": "This guide covers the basics of {q}. Follow the steps below to get started with your first implementation.",
    },
    # Best practices templates
    {
        "lex": ["{q} best practices", "{q} patterns", "{q} tips"],
        "vec": ["best practices for using {q}", "recommended patterns for {q}"],
        "hyde": "When working with {q}, follow these best practices: use consistent naming, handle errors properly, and document your code.",
    },
    # Troubleshooting templates
    {
        "lex": ["{q} troubleshooting", "{q} fix", "{q} errors"],
        "vec": ["how to fix {q} errors", "troubleshooting common {q} problems"],
        "hyde": "If you encounter {q} issues, check your configuration first. Common problems include missing dependencies and incorrect settings.",
    },
    # Examples/Code templates
    {
        "lex": ["{q} examples", "{q} code samples", "{q} usage"],
        "vec": ["code examples for {q}", "practical {q} implementation examples"],
        "hyde": "Here are some practical examples of {q} in action. Each example demonstrates a common use case with working code.",
    },
    # Documentation/Reference templates
    {
        "lex": ["{q} documentation", "{q} reference", "{q} manual"],
        "vec": ["official {q} documentation", "{q} API reference guide"],
        "hyde": "The official documentation for {q} provides comprehensive information about features, configuration options, and usage examples.",
    },
    # Installation templates
    {
        "lex": ["{q} install", "{q} setup", "{q} getting started"],
        "vec": ["how to install {q} on my system", "{q} installation guide"],
        "hyde": "To install {q}, run the appropriate package manager command for your system. Verify the installation by checking the version.",
    },
    # Comparison templates
    {
        "lex": ["{q} comparison", "{q} vs alternatives", "{q} differences"],
        "vec": ["how does {q} compare to alternatives", "{q} pros and cons"],
        "hyde": "When comparing {q} to similar tools, consider factors like performance, ease of use, community support, and ecosystem compatibility.",
    },
    # Performance templates
    {
        "lex": ["{q} performance", "{q} optimization", "{q} speed"],
        "vec": ["how to optimize {q} performance", "{q} performance tuning tips"],
        "hyde": "To improve {q} performance, profile your application to identify bottlenecks. Common optimizations include caching, lazy loading, and query optimization.",
    },
    # Security templates
    {
        "lex": ["{q} security", "{q} hardening", "{q} vulnerabilities"],
        "vec": ["how to secure {q} configuration", "{q} security best practices"],
        "hyde": "Security considerations for {q} include input validation, authentication, authorization, and keeping dependencies up to date with security patches.",
    },
    # Testing templates
    {
        "lex": ["{q} testing", "{q} test suite", "{q} unit tests"],
        "vec": ["how to test {q} code", "{q} testing strategies and frameworks"],
        "hyde": "Testing {q} involves writing unit tests, integration tests, and end-to-end tests. Use appropriate testing frameworks for your language and platform.",
    },
    # Deployment templates
    {
        "lex": ["{q} deployment", "{q} production", "{q} release"],
        "vec": ["how to deploy {q} to production", "{q} production deployment guide"],
        "hyde": "Deploying {q} to production requires proper configuration, environment variables, monitoring, and rollback procedures for reliability.",
    },
    # Debugging templates
    {
        "lex": ["{q} debugging", "{q} troubleshooting", "{q} error handling"],
        "vec": ["how to debug {q} issues", "{q} debugging techniques and tools"],
        "hyde": "Debugging {q} involves using logging, breakpoints, stack traces, and specialized debugging tools to identify and fix issues efficiently.",
    },
    # Integration templates
    {
        "lex": ["{q} integration", "{q} connect", "{q} interoperability"],
        "vec": ["how to integrate {q} with other systems", "{q} integration patterns"],
        "hyde": "Integrating {q} with other systems requires understanding APIs, data formats, authentication mechanisms, and error handling strategies.",
    },
    # Migration templates
    {
        "lex": ["{q} migration", "{q} upgrade", "{q} versioning"],
        "vec": ["how to migrate to {q}", "{q} upgrade guide and breaking changes"],
        "hyde": "Migrating to {q} involves planning, testing compatibility, addressing breaking changes, and validating functionality before production deployment.",
    },
]
  448. def truncate_hyde(hyde_text: str, max_len: int = 150) -> str:
  449. """Truncate hyde to max length, ending at sentence boundary."""
  450. if len(hyde_text) <= max_len:
  451. return hyde_text
  452. truncated = hyde_text[:max_len]
  453. last_period = truncated.rfind(". ")
  454. if last_period > max_len // 2:
  455. return truncated[: last_period + 1]
  456. last_space = truncated.rfind(" ")
  457. if last_space > max_len // 2:
  458. return truncated[:last_space] + "."
  459. return truncated[: max_len - 1] + "."
  460. def clean_output(output: str) -> str:
  461. """Clean output: truncate hyde, remove invalid lines."""
  462. lines = []
  463. for line in output.strip().split("\n"):
  464. line = line.strip()
  465. if not line:
  466. continue
  467. if line.startswith("hyde:"):
  468. hyde_text = line[5:].strip()
  469. hyde_text = truncate_hyde(hyde_text)
  470. lines.append(f"hyde: {hyde_text}")
  471. elif line.startswith(("lex:", "vec:")):
  472. lines.append(line)
  473. return "\n".join(lines)
  474. def generate_short_example(query: str) -> dict:
  475. """Generate a training example for a short query."""
  476. template = random.choice(SHORT_TEMPLATES)
  477. lex_lines = random.sample(template["lex"], 2)
  478. vec_lines = random.sample(template["vec"], 2)
  479. hyde_line = template["hyde"]
  480. output_lines = []
  481. for lex in lex_lines:
  482. output_lines.append(f"lex: {lex.format(q=query)}")
  483. for vec in vec_lines:
  484. output_lines.append(f"vec: {vec.format(q=query)}")
  485. output_lines.append(f"hyde: {hyde_line.format(q=query)}")
  486. return {"input": query, "output": "\n".join(output_lines)}
  487. def format_for_training(input_text: str, output_text: str) -> dict:
  488. """Format a single example for SFT training using Qwen chat format."""
  489. tokenizer = get_tokenizer()
  490. # Use /no_think to disable thinking mode - we want direct output
  491. messages = [
  492. {
  493. "role": "user",
  494. "content": f"/no_think Expand this search query: {input_text}",
  495. },
  496. {"role": "assistant", "content": output_text},
  497. ]
  498. # Use tokenizer to generate proper chat format with special tokens
  499. text = tokenizer.apply_chat_template(
  500. messages,
  501. tokenize=False,
  502. add_generation_prompt=False,
  503. )
  504. # Strip empty tags - we don't want thinking mode
  505. # The template adds " \n\n\u5df4\u6bd4\n\n" which we remove
  506. text = text.replace(" \n\n\u5df4\u6bd4\n\n", "")
  507. return {
  508. "text": text,
  509. "messages": messages,
  510. }
  511. def main():
  512. parser = argparse.ArgumentParser(description="Prepare data for training")
  513. parser.add_argument(
  514. "--input", type=str, default="data/qmd_expansion.jsonl", help="Input JSONL file"
  515. )
  516. parser.add_argument(
  517. "--output", type=str, default="data/train", help="Output directory"
  518. )
  519. parser.add_argument(
  520. "--split", type=float, default=0.1, help="Validation split ratio"
  521. )
  522. parser.add_argument(
  523. "--add-short", type=int, default=3, help="Variations per short query to add"
  524. )
  525. args = parser.parse_args()
  526. input_path = Path(args.input)
  527. output_dir = Path(args.output)
  528. output_dir.mkdir(parents=True, exist_ok=True)
  529. if not input_path.exists():
  530. print(f"Error: Input file not found: {input_path}")
  531. exit(1)
  532. # Load and clean existing examples
  533. examples = []
  534. seen_queries = set()
  535. long_hyde_count = 0
  536. with open(input_path) as f:
  537. for line in f:
  538. if line.strip():
  539. ex = json.loads(line)
  540. # Clean output (truncate hyde, remove invalid lines)
  541. original_output = ex["output"]
  542. ex["output"] = clean_output(ex["output"])
  543. # Track hyde truncation
  544. if "hyde:" in original_output:
  545. for orig_line in original_output.split("\n"):
  546. if orig_line.strip().startswith("hyde:"):
  547. if len(orig_line) > 160:
  548. long_hyde_count += 1
  549. # Validate cleaned output
  550. has_lex = "lex:" in ex["output"]
  551. has_vec = "vec:" in ex["output"]
  552. if has_lex and has_vec:
  553. examples.append(ex)
  554. seen_queries.add(ex["input"].lower())
  555. print(f"Loaded and cleaned {len(examples)} examples")
  556. print(f"Truncated {long_hyde_count} long hyde sections")
  557. # Count existing short queries
  558. short_existing = sum(1 for ex in examples if len(ex["input"].split()) <= 2)
  559. print(f"Existing short queries (1-2 words): {short_existing}")
  560. # Generate additional short query examples
  561. new_short = []
  562. for query in SHORT_QUERIES:
  563. if query.lower() not in seen_queries:
  564. for _ in range(args.add_short):
  565. new_short.append(generate_short_example(query))
  566. seen_queries.add(query.lower())
  567. print(f"Generated {len(new_short)} new short query examples")
  568. # Combine and shuffle
  569. all_examples = examples + new_short
  570. random.shuffle(all_examples)
  571. # Format for training
  572. formatted = [format_for_training(ex["input"], ex["output"]) for ex in all_examples]
  573. # Split into train/val
  574. split_idx = int(len(formatted) * (1 - args.split))
  575. train_data = formatted[:split_idx]
  576. val_data = formatted[split_idx:]
  577. # Write train set
  578. train_path = output_dir / "train.jsonl"
  579. with open(train_path, "w") as f:
  580. for item in train_data:
  581. f.write(json.dumps(item) + "\n")
  582. # Write validation set
  583. val_path = output_dir / "val.jsonl"
  584. with open(val_path, "w") as f:
  585. for item in val_data:
  586. f.write(json.dumps(item) + "\n")
  587. # Write chat format (for TRL)
  588. chat_path = output_dir / "train_chat.jsonl"
  589. with open(chat_path, "w") as f:
  590. for item in train_data:
  591. f.write(json.dumps({"messages": item["messages"]}) + "\n")
  592. # Stats
  593. short_final = sum(1 for ex in all_examples if len(ex["input"].split()) <= 2)
  594. print(f"\n=== Summary ===")
  595. print(f"Total examples: {len(all_examples)}")
  596. print(
  597. f"Short queries: {short_final} ({100 * short_final / len(all_examples):.1f}%)"
  598. )
  599. print(f"Train: {len(train_data)}, Val: {len(val_data)}")
  600. print(f"Output: {output_dir}")
  601. # Dataset info
  602. dataset_info = {
  603. "dataset_name": "qmd-query-expansion",
  604. "train_samples": len(train_data),
  605. "val_samples": len(val_data),
  606. "short_query_pct": round(100 * short_final / len(all_examples), 1),
  607. "columns": ["prompt", "completion", "text", "messages"],
  608. }
  609. with open(output_dir / "dataset_info.json", "w") as f:
  610. json.dump(dataset_info, f, indent=2)
  611. if __name__ == "__main__":
  612. main()