generate_data.py 16 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714
  1. #!/usr/bin/env python3
  2. """Generate synthetic training data for QMD query expansion using Claude API."""
  3. import argparse
  4. import json
  5. import os
  6. import random
  7. from pathlib import Path
  8. from dataset.schema import normalize_output_items, parse_output_text
  9. try:
  10. import anthropic
  11. except ImportError:
  12. print("Install anthropic: pip install anthropic")
  13. exit(1)
  14. # Sample query templates for diverse training data - organized by category
  15. QUERY_TEMPLATES = [
  16. # === Technical documentation (35% of queries) ===
  17. "how to {action} {technology}",
  18. "{technology} {concept} example",
  19. "configure {technology} for {use_case}",
  20. "{error_type} error in {technology}",
  21. "best practices for {concept}",
  22. "{technology} vs {technology2}",
  23. "{action} {technology} {use_case}",
  24. "setup {technology} {use_case}",
  25. "{technology} tutorial for beginners",
  26. "{technology} documentation",
  27. "{technology} {error_type} troubleshooting",
  28. "{concept} in {technology}",
  29. "migrate from {technology} to {technology2}",
  30. "{action} {concept} {technology}",
  31. # === Personal notes / journals (15% of queries) ===
  32. "meeting notes {topic}",
  33. "ideas for {project}",
  34. "{date} journal entry",
  35. "thoughts on {topic}",
  36. "{project} {topic} notes",
  37. "{topic} meeting {date}",
  38. "reflect on {topic}",
  39. "brainstorm {project}",
  40. # === Research / learning (20% of queries) ===
  41. "what is {concept}",
  42. "difference between {thing1} and {thing2}",
  43. "{topic} tutorial",
  44. "learn {skill}",
  45. "understand {concept}",
  46. "explain {concept}",
  47. "{topic} fundamentals",
  48. "intro to {skill}",
  49. "{thing1} or {thing2}",
  50. "when to use {concept}",
  51. # === Short / keyword queries (15% of queries) ===
  52. "{keyword}",
  53. "{keyword} {modifier}",
  54. "{keyword} {action}",
  55. "{keyword} {use_case}",
  56. "{technology} {keyword}",
  57. "{concept} {keyword}",
  58. # === Temporal / recency queries (10% of queries) ===
  59. "latest {topic}",
  60. "recent {concept} changes",
  61. "new {technology} features",
  62. "{topic} update {date}",
  63. "what changed in {technology}",
  64. "{technology} changelog {date}",
  65. "{topic} news {date}",
  66. # === Named entities / specific topics (5% of queries) ===
  67. "{named_entity} {topic}",
  68. "{person} {concept}",
  69. "{organization} {use_case}",
  70. "{product} {action}",
  71. ]
  72. # Category weights for balanced sampling
  73. TEMPLATE_CATEGORIES = {
  74. "technical": list(range(0, 14)), # 0-13
  75. "personal": list(range(14, 22)), # 14-21
  76. "research": list(range(22, 31)), # 22-30
  77. "short": list(range(31, 36)), # 31-35
  78. "temporal": list(range(36, 42)), # 36-41
  79. "entities": list(range(42, 46)), # 42-45
  80. }
  81. ACTIONS = [
  82. "install",
  83. "configure",
  84. "setup",
  85. "debug",
  86. "deploy",
  87. "test",
  88. "optimize",
  89. "migrate",
  90. "build",
  91. "run",
  92. "lint",
  93. "format",
  94. "backup",
  95. "restore",
  96. "update",
  97. "rollback",
  98. "monitor",
  99. "scale",
  100. "secure",
  101. "integrate",
  102. "automate",
  103. "refactor",
  104. "initialize",
  105. ]
  106. TECHNOLOGIES = [
  107. # Languages
  108. "python",
  109. "typescript",
  110. "javascript",
  111. "rust",
  112. "golang",
  113. "java",
  114. "kotlin",
  115. "swift",
  116. "ruby",
  117. "php",
  118. "cpp",
  119. "c",
  120. "elixir",
  121. "scala",
  122. "clojure",
  123. "dart",
  124. # Frameworks/Frontend
  125. "react",
  126. "vue",
  127. "angular",
  128. "svelte",
  129. "solid",
  130. "htmx",
  131. "alpine",
  132. "nextjs",
  133. "nuxt",
  134. # Backend
  135. "django",
  136. "flask",
  137. "fastapi",
  138. "express",
  139. "rails",
  140. "spring",
  141. "laravel",
  142. # Infrastructure
  143. "docker",
  144. "kubernetes",
  145. "terraform",
  146. "ansible",
  147. "jenkins",
  148. "github-actions",
  149. # Databases
  150. "postgres",
  151. "mysql",
  152. "mongodb",
  153. "redis",
  154. "elasticsearch",
  155. "sqlite",
  156. "dynamodb",
  157. "cassandra",
  158. "cockroachdb",
  159. "supabase",
  160. "firebase",
  161. # Tools
  162. "git",
  163. "nginx",
  164. "apache",
  165. "linux",
  166. "aws",
  167. "gcp",
  168. "azure",
  169. "vercel",
  170. "netlify",
  171. # Data/ML
  172. "pandas",
  173. "numpy",
  174. "tensorflow",
  175. "pytorch",
  176. "scikit-learn",
  177. "jupyter",
  178. "spark",
  179. "kafka",
  180. "airflow",
  181. "dbt",
  182. ]
  183. TECHNOLOGIES_2 = [
  184. "docker",
  185. "kubernetes",
  186. "postgres",
  187. "mysql",
  188. "redis",
  189. "mongodb",
  190. "aws",
  191. "gcp",
  192. "react",
  193. "vue",
  194. "angular",
  195. "python",
  196. "javascript",
  197. "typescript",
  198. "github-actions",
  199. "gitlab-ci",
  200. "jenkins",
  201. "terraform",
  202. "ansible",
  203. ]
  204. CONCEPTS = [
  205. "authentication",
  206. "caching",
  207. "logging",
  208. "testing",
  209. "deployment",
  210. "API",
  211. "database",
  212. "security",
  213. "monitoring",
  214. "performance",
  215. "scalability",
  216. "reliability",
  217. "observability",
  218. "microservices",
  219. "serverless",
  220. "virtualization",
  221. "containerization",
  222. "orchestration",
  223. "CI/CD",
  224. "version control",
  225. "dependency injection",
  226. "event sourcing",
  227. "CQRS",
  228. "load balancing",
  229. "rate limiting",
  230. "circuit breaker",
  231. "retry logic",
  232. "idempotency",
  233. ]
  234. USE_CASES = [
  235. "production",
  236. "development",
  237. "CI/CD",
  238. "local",
  239. "cloud",
  240. "staging",
  241. "testing",
  242. "microservices",
  243. "serverless",
  244. "hybrid",
  245. "multi-tenant",
  246. "high-availability",
  247. "real-time",
  248. "batch processing",
  249. "stream processing",
  250. "data pipeline",
  251. ]
  252. ERROR_TYPES = [
  253. "connection",
  254. "timeout",
  255. "permission",
  256. "memory",
  257. "syntax",
  258. "runtime",
  259. "configuration",
  260. "dependency",
  261. "network",
  262. "authentication",
  263. "authorization",
  264. "validation",
  265. "concurrency",
  266. "deadlock",
  267. "resource",
  268. "quota",
  269. ]
  270. TOPICS = [
  271. "productivity",
  272. "workflow",
  273. "architecture",
  274. "design",
  275. "performance",
  276. "security",
  277. "scalability",
  278. "reliability",
  279. "observability",
  280. "maintainability",
  281. "testing",
  282. "documentation",
  283. "refactoring",
  284. "debugging",
  285. "optimization",
  286. "best practices",
  287. "patterns",
  288. "anti-patterns",
  289. "trade-offs",
  290. "decision making",
  291. ]
  292. KEYWORDS = [
  293. "auth",
  294. "config",
  295. "setup",
  296. "api",
  297. "cache",
  298. "log",
  299. "test",
  300. "debug",
  301. "env",
  302. "vars",
  303. "secrets",
  304. "tokens",
  305. "headers",
  306. "params",
  307. "query",
  308. "body",
  309. "route",
  310. "middleware",
  311. "handler",
  312. "controller",
  313. "model",
  314. "view",
  315. "template",
  316. "migration",
  317. "seed",
  318. "fixture",
  319. "mock",
  320. "stub",
  321. "spy",
  322. "fake",
  323. "build",
  324. "bundle",
  325. "compile",
  326. "transpile",
  327. "minify",
  328. "optimize",
  329. "deploy",
  330. "release",
  331. "rollback",
  332. "promote",
  333. "freeze",
  334. "thaw",
  335. "pull",
  336. "push",
  337. "commit",
  338. "merge",
  339. "rebase",
  340. "cherry-pick",
  341. "stash",
  342. "up",
  343. "down",
  344. "scale",
  345. "restart",
  346. "reload",
  347. "refresh",
  348. "flush",
  349. "cron",
  350. "queue",
  351. "job",
  352. "worker",
  353. "scheduler",
  354. "trigger",
  355. "webhook",
  356. "alert",
  357. "metric",
  358. "trace",
  359. "span",
  360. "event",
  361. "incident",
  362. "oncall",
  363. ]
  364. MODIFIERS = [
  365. "best",
  366. "fast",
  367. "simple",
  368. "advanced",
  369. "secure",
  370. "quick",
  371. "easy",
  372. "proper",
  373. "correct",
  374. "safe",
  375. "efficient",
  376. "reliable",
  377. "robust",
  378. "latest",
  379. "recent",
  380. "new",
  381. "old",
  382. "legacy",
  383. "modern",
  384. "local",
  385. "remote",
  386. "global",
  387. "shared",
  388. "private",
  389. "public",
  390. ]
  391. NAMED_ENTITIES = [
  392. "React",
  393. "Vue",
  394. "Angular",
  395. "Docker",
  396. "Kubernetes",
  397. "AWS",
  398. "GCP",
  399. "GitHub",
  400. "GitLab",
  401. "Vercel",
  402. "Netlify",
  403. "Supabase",
  404. "Firebase",
  405. "Stripe",
  406. "Twilio",
  407. "SendGrid",
  408. "Datadog",
  409. "PagerDuty",
  410. "Sentry",
  411. "Terraform",
  412. "Ansible",
  413. "Jenkins",
  414. "CircleCI",
  415. "TravisCI",
  416. ]
  417. PERSONS = [
  418. "Kent Beck",
  419. "Martin Fowler",
  420. "Robert Martin",
  421. "Dave Thomas",
  422. "Guido van Rossum",
  423. "Brendan Eich",
  424. "Ryan Dahl",
  425. "Anders Hejlsberg",
  426. "Linus Torvalds",
  427. "DHH",
  428. "Yukihiro Matsumoto",
  429. "Rich Hickey",
  430. ]
  431. ORGANIZATIONS = [
  432. "Google",
  433. "Microsoft",
  434. "Amazon",
  435. "Meta",
  436. "Apple",
  437. "Netflix",
  438. "Spotify",
  439. "Stripe",
  440. "Shopify",
  441. "Airbnb",
  442. "Uber",
  443. "Lyft",
  444. "Slack",
  445. "Discord",
  446. ]
  447. PRODUCTS = [
  448. "VS Code",
  449. "IntelliJ",
  450. "PyCharm",
  451. "WebStorm",
  452. "DataGrip",
  453. "Postman",
  454. "Insomnia",
  455. "TablePlus",
  456. "Docker Desktop",
  457. "Lens",
  458. "Figma",
  459. "Sketch",
  460. "Notion",
  461. "Linear",
  462. "Jira",
  463. "Trello",
  464. ]
# System prompt passed as `system=` on every generate_expansion() request.
# It specifies the "lex:" / "vec:" / "hyde:" line format whose prefixes
# validate_output() checks.
SYSTEM_PROMPT = """You are a search query optimization expert for a markdown document search system called QMD.
Your task is to transform user queries into retrieval-optimized outputs with THREE distinct types:
1. **lex** lines: Keyword variations optimized for BM25 full-text search
- Short, keyword-focused
- Good for exact term matching
- 1-3 lines
2. **vec** lines: Semantic reformulations for vector/embedding search
- Complete phrases or questions
- Capture semantic meaning
- 1-3 lines
3. **hyde** line: A hypothetical document passage (HyDE technique)
- A realistic passage that would answer the query
- Contains domain-specific terminology
- Written as if it's FROM a document, not ABOUT the query
- MAX 1 line
Output format (STRICT - follow exactly):
```
hyde: A passage that would appear in a document answering this query.
lex: keyword1
lex: keyword2
vec: semantic query reformulation
```
Rules:
- Each line must start with "lex:", "vec:", or "hyde:"
- No blank lines
- No repetition between lines
- hyde should be a realistic document excerpt, not a question
- Stay focused on the original query intent"""
# Per-query user message. {query} is filled in via str.format() in
# generate_expansion().
USER_PROMPT_TEMPLATE = """Generate query expansion outputs for this search query:
Query: {query}
Respond with ONLY the lex/vec/hyde lines, nothing else."""
  496. # Category weights - BALANCED approach
  497. # Tech at 15% (reasonable for QMD's technical document use case)
  498. CATEGORY_WEIGHTS = {
  499. "technical": 0.15, # 15% - Technical documentation
  500. "personal": 0.10, # 10% - Personal notes, journals
  501. "research": 0.10, # 10% - Research and learning
  502. "short": 0.15, # 15% - Short keyword queries
  503. "temporal": 0.10, # 10% - Temporal/recency queries (2025/2026)
  504. "entities": 0.05, # 5% - Named entity queries
  505. "health": 0.10, # 10% - Health & wellness
  506. "finance": 0.10, # 10% - Finance & business
  507. "lifestyle": 0.10, # 10% - Home, food, hobbies, travel
  508. "education": 0.05, # 5% - Education & arts
  509. }
  510. def generate_random_query() -> str:
  511. """Generate a random query from templates with category-weighted sampling."""
  512. # Select category based on weights
  513. categories = list(CATEGORY_WEIGHTS.keys())
  514. weights = list(CATEGORY_WEIGHTS.values())
  515. selected_category = random.choices(categories, weights=weights, k=1)[0]
  516. # Select template from that category
  517. template_idx = random.choice(TEMPLATE_CATEGORIES[selected_category])
  518. template = QUERY_TEMPLATES[template_idx]
  519. # Build replacements based on template type
  520. replacements = {
  521. "{action}": random.choice(ACTIONS),
  522. "{technology}": random.choice(TECHNOLOGIES),
  523. "{technology2}": random.choice(TECHNOLOGIES_2),
  524. "{concept}": random.choice(CONCEPTS),
  525. "{use_case}": random.choice(USE_CASES),
  526. "{error_type}": random.choice(ERROR_TYPES),
  527. "{topic}": random.choice(TOPICS),
  528. "{project}": random.choice(
  529. ["website", "app", "CLI tool", "API", "library", "service", "platform"]
  530. ),
  531. "{date}": random.choice(
  532. # Emphasize 2025/2026 for recency queries (current era)
  533. [
  534. "2026",
  535. "2026",
  536. "2025",
  537. "2025",
  538. "January 2026",
  539. "February 2026",
  540. "March 2026",
  541. "last month",
  542. "this week",
  543. "yesterday",
  544. "today",
  545. "recently",
  546. "latest",
  547. ]
  548. ),
  549. "{thing1}": random.choice(CONCEPTS[:10]),
  550. "{thing2}": random.choice(CONCEPTS[10:] if len(CONCEPTS) > 10 else CONCEPTS),
  551. "{skill}": random.choice(TECHNOLOGIES),
  552. "{keyword}": random.choice(KEYWORDS),
  553. "{modifier}": random.choice(MODIFIERS),
  554. "{named_entity}": random.choice(NAMED_ENTITIES),
  555. "{person}": random.choice(PERSONS),
  556. "{organization}": random.choice(ORGANIZATIONS),
  557. "{product}": random.choice(PRODUCTS),
  558. }
  559. query = template
  560. for key, value in replacements.items():
  561. query = query.replace(key, value)
  562. return query
  563. def generate_expansion(client: anthropic.Anthropic, query: str) -> str | None:
  564. """Generate expansion using Claude API."""
  565. try:
  566. response = client.messages.create(
  567. model="claude-sonnet-4-20250514",
  568. max_tokens=300,
  569. system=SYSTEM_PROMPT,
  570. messages=[
  571. {"role": "user", "content": USER_PROMPT_TEMPLATE.format(query=query)}
  572. ],
  573. )
  574. return response.content[0].text.strip()
  575. except Exception as e:
  576. print(f"Error generating expansion for '{query}': {e}")
  577. return None
  578. def validate_output(output: str) -> bool:
  579. """Validate that output follows the expected format."""
  580. lines = output.strip().split("\n")
  581. if not lines:
  582. return False
  583. has_lex = False
  584. has_vec = False
  585. for line in lines:
  586. line = line.strip()
  587. if not line:
  588. continue
  589. if line.startswith("lex:"):
  590. has_lex = True
  591. elif line.startswith("vec:"):
  592. has_vec = True
  593. elif line.startswith("hyde:"):
  594. pass
  595. else:
  596. return False # Invalid line type
  597. return has_lex and has_vec
  598. def main():
  599. parser = argparse.ArgumentParser(
  600. description="Generate QMD query expansion training data"
  601. )
  602. parser.add_argument(
  603. "--count", type=int, default=100, help="Number of examples to generate"
  604. )
  605. parser.add_argument(
  606. "--output",
  607. type=str,
  608. default="data/qmd_expansion.jsonl",
  609. help="Output file path",
  610. )
  611. parser.add_argument(
  612. "--queries", type=str, help="Optional file with custom queries (one per line)"
  613. )
  614. args = parser.parse_args()
  615. api_key = os.environ.get("ANTHROPIC_API_KEY")
  616. if not api_key:
  617. print("Error: ANTHROPIC_API_KEY environment variable not set")
  618. exit(1)
  619. client = anthropic.Anthropic(api_key=api_key)
  620. output_path = Path(args.output)
  621. output_path.parent.mkdir(parents=True, exist_ok=True)
  622. # Load custom queries if provided
  623. custom_queries = []
  624. if args.queries and Path(args.queries).exists():
  625. custom_queries = Path(args.queries).read_text().strip().split("\n")
  626. print(f"Loaded {len(custom_queries)} custom queries")
  627. examples = []
  628. seen_queries = set()
  629. print(f"Generating {args.count} examples...")
  630. i = 0
  631. while len(examples) < args.count:
  632. # Use custom query or generate random one
  633. if custom_queries and i < len(custom_queries):
  634. query = custom_queries[i].strip()
  635. else:
  636. query = generate_random_query()
  637. i += 1
  638. # Skip duplicates
  639. if query in seen_queries:
  640. continue
  641. seen_queries.add(query)
  642. # Generate expansion
  643. output = generate_expansion(client, query)
  644. if output and validate_output(output):
  645. output_items = normalize_output_items(parse_output_text(output))
  646. examples.append({"query": query, "output": output_items})
  647. print(f"[{len(examples)}/{args.count}] {query[:50]}...")
  648. else:
  649. print(f" Skipped invalid output for: {query[:50]}...")
  650. # Write output
  651. with open(output_path, "w") as f:
  652. for example in examples:
  653. f.write(json.dumps(example) + "\n")
  654. print(f"\nGenerated {len(examples)} examples to {output_path}")
  655. if __name__ == "__main__":
  656. main()