generate_data.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. #!/usr/bin/env python3
  2. """Generate synthetic training data for QMD query expansion using Claude API."""
  3. import argparse
  4. import json
  5. import os
  6. import random
  7. from pathlib import Path
  8. try:
  9. import anthropic
  10. except ImportError:
  11. print("Install anthropic: pip install anthropic")
  12. exit(1)
  # Query templates for synthetic training queries. Each "{name}" placeholder
  # is filled in by generate_random_query() with a word from the lists below.
  QUERY_TEMPLATES = [
      # Technical documentation
      "how to {action} {technology}",
      "{technology} {concept} example",
      "configure {technology} for {use_case}",
      "{error_type} error in {technology}",
      "best practices for {concept}",
      # Personal notes / journals
      "meeting notes {topic}",
      "ideas for {project}",
      "{date} journal entry",
      "thoughts on {topic}",
      # Research / learning
      "what is {concept}",
      "difference between {thing1} and {thing2}",
      "{topic} tutorial",
      "learn {skill}",
      # Short queries
      "{keyword}",
      "{keyword} {modifier}",
  ]

  # Placeholder vocabularies. Every entry is a single whitespace-free token,
  # so each list is written as one string and split into a list of words.
  ACTIONS = "install configure setup debug deploy test optimize migrate".split()
  TECHNOLOGIES = "python typescript react docker kubernetes postgres redis nginx git linux".split()
  # Order matters: generate_random_query() draws {thing1} from CONCEPTS[:4]
  # and {thing2} from CONCEPTS[4:].
  CONCEPTS = "authentication caching logging testing deployment API database security".split()
  USE_CASES = "production development CI/CD local cloud".split()
  ERROR_TYPES = "connection timeout permission memory syntax".split()
  TOPICS = "productivity workflow architecture design performance".split()
  KEYWORDS = "auth config setup api data cache log test".split()
  MODIFIERS = "best fast simple advanced secure".split()
  # System prompt sent with every expansion request. It defines the three
  # output line types (lex / vec / hyde) whose prefixes validate_output()
  # checks. Assembled line by line; every piece except the last ends in "\n".
  SYSTEM_PROMPT = (
      "You are a search query optimization expert for a markdown document search system called QMD.\n"
      "Your task is to transform user queries into retrieval-optimized outputs with THREE distinct types:\n"
      # lex: keyword lines for BM25
      "1. **lex** lines: Keyword variations optimized for BM25 full-text search\n"
      "- Short, keyword-focused\n"
      "- Good for exact term matching\n"
      "- 1-3 lines\n"
      # vec: semantic reformulations for embedding search
      "2. **vec** lines: Semantic reformulations for vector/embedding search\n"
      "- Complete phrases or questions\n"
      "- Capture semantic meaning\n"
      "- 1-3 lines\n"
      # hyde: one hypothetical document passage
      "3. **hyde** line: A hypothetical document passage (HyDE technique)\n"
      "- A realistic passage that would answer the query\n"
      "- Contains domain-specific terminology\n"
      "- Written as if it's FROM a document, not ABOUT the query\n"
      "- MAX 1 line\n"
      # Worked example of the exact output shape
      "Output format (STRICT - follow exactly):\n"
      "```\n"
      "hyde: A passage that would appear in a document answering this query.\n"
      "lex: keyword1\n"
      "lex: keyword2\n"
      "vec: semantic query reformulation\n"
      "```\n"
      "Rules:\n"
      '- Each line must start with "lex:", "vec:", or "hyde:"\n'
      "- No blank lines\n"
      "- No repetition between lines\n"
      "- hyde should be a realistic document excerpt, not a question\n"
      "- Stay focused on the original query intent"
  )

  # Per-query user message; "{query}" is substituted via str.format().
  USER_PROMPT_TEMPLATE = (
      "Generate query expansion outputs for this search query:\n"
      "Query: {query}\n"
      "Respond with ONLY the lex/vec/hyde lines, nothing else."
  )
  def generate_random_query() -> str:
      """Build one synthetic search query from a randomly chosen template.

      Picks a template from QUERY_TEMPLATES, then substitutes every
      ``{name}`` placeholder with a random word from its vocabulary list.

      Returns:
          The filled-in query string.
      """
      chosen = random.choice(QUERY_TEMPLATES)
      # A value is drawn for every placeholder, in this fixed order, whether
      # or not the chosen template actually uses it.
      fillers = {
          "action": random.choice(ACTIONS),
          "technology": random.choice(TECHNOLOGIES),
          "concept": random.choice(CONCEPTS),
          "use_case": random.choice(USE_CASES),
          "error_type": random.choice(ERROR_TYPES),
          "topic": random.choice(TOPICS),
          "project": random.choice(["website", "app", "CLI tool", "API", "library"]),
          "date": random.choice(["2024-01", "2024-06", "yesterday", "today"]),
          # Disjoint slices of CONCEPTS, so "difference between X and Y"
          # never compares a concept with itself.
          "thing1": random.choice(CONCEPTS[:4]),
          "thing2": random.choice(CONCEPTS[4:]),
          "skill": random.choice(TECHNOLOGIES),
          "keyword": random.choice(KEYWORDS),
          "modifier": random.choice(MODIFIERS),
      }
      result = chosen
      for name, word in fillers.items():
          result = result.replace("{" + name + "}", word)
      return result
  def generate_expansion(
      client: "anthropic.Anthropic",
      query: str,
      *,
      model: str = "claude-sonnet-4-20250514",
      max_tokens: int = 300,
  ) -> str | None:
      """Ask Claude for lex/vec/hyde expansion lines for one query.

      Args:
          client: Anthropic client used to call ``messages.create``.
          query: Search query; substituted into USER_PROMPT_TEMPLATE.
          model: Model name passed to ``messages.create``.
          max_tokens: Generation cap passed to ``messages.create``.

      Returns:
          The stripped text of the reply, or None when the request raised or
          the reply stopped because it hit ``max_tokens``. A truncated reply
          can still look well-formed (a cut-off ``vec:`` line still starts
          with ``vec:``), so it is discarded instead of becoming bad
          training data.
      """
      try:
          response = client.messages.create(
              model=model,
              max_tokens=max_tokens,
              system=SYSTEM_PROMPT,
              messages=[
                  {"role": "user", "content": USER_PROMPT_TEMPLATE.format(query=query)}
              ],
          )
      except Exception as e:
          # Best-effort: one failed request is reported and skipped so the
          # caller's generation loop can continue.
          print(f"Error generating expansion for '{query}': {e}")
          return None
      if response.stop_reason == "max_tokens":
          print(f"Error generating expansion for '{query}': reply truncated at max_tokens={max_tokens}")
          return None
      # Concatenate all text blocks rather than assuming content[0] is text.
      text = "".join(block.text for block in response.content if block.type == "text")
      return text.strip()
  def validate_output(output: str) -> bool:
      """Check that a model reply follows the lex/vec/hyde line format.

      Blank lines and whitespace around each line are tolerated. Every other
      line must look like ``<type>: <text>``, where ``<type>`` is ``lex``,
      ``vec`` or ``hyde`` and ``<text>`` is non-empty. Matching the limits
      stated in SYSTEM_PROMPT, a valid reply has 1-3 ``lex`` lines, 1-3
      ``vec`` lines, at most one ``hyde`` line, and no repeated lines.

      Args:
          output: Raw reply text from the model.

      Returns:
          True if the reply is usable as a training target, otherwise False.
      """
      # (minimum, maximum) number of lines allowed per line type.
      limits = {"lex": (1, 3), "vec": (1, 3), "hyde": (0, 1)}
      counts = dict.fromkeys(limits, 0)
      seen = set()
      for raw_line in output.split("\n"):
          line = raw_line.strip()
          if not line:
              continue
          kind, sep, text = line.partition(":")
          if not sep or kind not in limits or not text.strip():
              return False  # unknown line type, or a prefix with no content
          if line in seen:
              return False  # the prompt forbids repeating a line
          seen.add(line)
          counts[kind] += 1
      return all(lo <= counts[kind] <= hi for kind, (lo, hi) in limits.items())
  def _load_custom_queries(path_str: str) -> list[str]:
      """Return the stripped, non-blank lines of the ``--queries`` file.

      Raises:
          SystemExit: if *path_str* is not an existing file, so a mistyped
              path fails loudly instead of silently falling back to random
              queries.
      """
      path = Path(path_str)
      if not path.is_file():
          raise SystemExit(f"Error: queries file not found: {path}")
      stripped = (line.strip() for line in path.read_text(encoding="utf-8").splitlines())
      return [line for line in stripped if line]


  def _iter_queries(custom_queries: list[str]):
      """Yield every custom query in order, then random template queries forever."""
      yield from custom_queries
      while True:
          yield generate_random_query()


  def main():
      """Generate QMD query-expansion examples and write them as JSON Lines.

      Queries come from ``--queries`` first (blank lines ignored), then from
      generate_random_query(). Each new query is expanded with
      generate_expansion(). Replies that pass validate_output() are written
      to ``--output`` as ``{"input": ..., "output": ...}`` lines as soon as
      they are produced, so an interrupted run keeps what it already
      generated.

      Raises:
          SystemExit: if ANTHROPIC_API_KEY is unset or the ``--queries`` file
              does not exist (message on stderr, exit status 1).
      """
      parser = argparse.ArgumentParser(description="Generate QMD query expansion training data")
      parser.add_argument("--count", type=int, default=100, help="Number of examples to generate")
      parser.add_argument("--output", type=str, default="data/qmd_expansion.jsonl", help="Output file path")
      parser.add_argument("--queries", type=str, help="Optional file with custom queries (one per line)")
      args = parser.parse_args()

      api_key = os.environ.get("ANTHROPIC_API_KEY")
      if not api_key:
          # SystemExit with a message prints it to stderr and exits with status 1.
          raise SystemExit("Error: ANTHROPIC_API_KEY environment variable not set")
      client = anthropic.Anthropic(api_key=api_key)

      output_path = Path(args.output)
      output_path.parent.mkdir(parents=True, exist_ok=True)

      custom_queries: list[str] = []
      if args.queries:
          custom_queries = _load_custom_queries(args.queries)
          print(f"Loaded {len(custom_queries)} custom queries")

      # Give-up thresholds that keep the loop from running forever. The random
      # templates only produce a few hundred distinct queries, so a large
      # --count would otherwise spin on duplicates. If every API call fails
      # (e.g. a bad key), nothing would ever be generated.
      max_duplicates_in_a_row = 10_000
      max_api_errors_in_a_row = 5

      queries = _iter_queries(custom_queries)
      seen_queries: set[str] = set()
      generated = duplicates_in_a_row = api_errors_in_a_row = 0

      print(f"Generating {args.count} examples...")
      with open(output_path, "w", encoding="utf-8") as out_file:
          while generated < args.count:
              query = next(queries)
              if query in seen_queries:
                  duplicates_in_a_row += 1
                  if duplicates_in_a_row >= max_duplicates_in_a_row:
                      print("Stopping: no new unique queries left to try")
                      break
                  continue
              duplicates_in_a_row = 0
              seen_queries.add(query)

              output = generate_expansion(client, query)
              if output is None:  # request failed; generate_expansion printed why
                  api_errors_in_a_row += 1
                  if api_errors_in_a_row >= max_api_errors_in_a_row:
                      print(f"Stopping: {api_errors_in_a_row} API errors in a row")
                      break
                  continue
              api_errors_in_a_row = 0

              if validate_output(output):
                  out_file.write(json.dumps({"input": query, "output": output}) + "\n")
                  out_file.flush()  # persist each example immediately
                  generated += 1
                  print(f"[{generated}/{args.count}] {query[:50]}...")
              else:
                  print(f" Skipped invalid output for: {query[:50]}...")

      print(f"\nGenerated {generated} examples to {output_path}")


  if __name__ == "__main__":
      main()