generate_ollama.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. #!/usr/bin/env python3
  2. """Generate synthetic training data for QMD query expansion using local Ollama."""
import argparse
import json
import random
import re
import sys
import time
from pathlib import Path

from dataset.schema import normalize_output_items, parse_output_text
  10. try:
  11. import requests
  12. except ImportError:
  13. print("Install requests: pip install requests")
  14. exit(1)
  15. # Diverse query seeds across many domains
  16. QUERY_SEEDS = [
  17. # Programming & Tech
  18. "async await javascript",
  19. "rust ownership borrow checker",
  20. "kubernetes pod networking",
  21. "docker compose volumes",
  22. "nginx reverse proxy",
  23. "postgresql index optimization",
  24. "redis caching strategies",
  25. "graphql mutations",
  26. "websocket authentication",
  27. "terraform state management",
  28. "ansible playbook variables",
  29. "prometheus alerting rules",
  30. "elasticsearch aggregations",
  31. "kafka consumer groups",
  32. "grpc streaming",
  33. "oauth2 refresh tokens",
  34. "jwt token expiration",
  35. "cors preflight requests",
  36. "css grid layout",
  37. "react hooks useEffect",
  38. "vue composition api",
  39. "svelte stores",
  40. "nextjs middleware",
  41. "webpack code splitting",
  42. "typescript generics constraints",
  43. "python asyncio gather",
  44. "go goroutines channels",
  45. "java streams filter map",
  46. "c++ smart pointers",
  47. "swift optionals unwrapping",
  48. # DevOps & Infrastructure
  49. "ci cd pipeline best practices",
  50. "blue green deployment",
  51. "canary release strategy",
  52. "infrastructure as code",
  53. "secrets management vault",
  54. "load balancer health checks",
  55. "ssl certificate renewal",
  56. "dns propagation time",
  57. "cdn cache invalidation",
  58. "container orchestration",
  59. "service mesh istio",
  60. "observability tracing",
  61. "log aggregation elk",
  62. "metrics dashboards grafana",
  63. "incident response runbook",
  64. # Data & ML
  65. "pandas dataframe groupby",
  66. "numpy array broadcasting",
  67. "scikit learn pipeline",
  68. "pytorch autograd",
  69. "tensorflow keras layers",
  70. "huggingface transformers",
  71. "feature engineering techniques",
  72. "hyperparameter tuning",
  73. "model evaluation metrics",
  74. "data preprocessing normalization",
  75. "time series forecasting",
  76. "anomaly detection",
  77. "recommendation systems",
  78. "natural language processing",
  79. "computer vision cnn",
  80. "reinforcement learning",
  81. "transfer learning",
  82. "model deployment mlops",
  83. # Databases
  84. "sql join types explained",
  85. "database normalization forms",
  86. "acid transactions",
  87. "database sharding strategies",
  88. "read replicas setup",
  89. "connection pooling",
  90. "query optimization explain",
  91. "stored procedures triggers",
  92. "database migrations",
  93. "nosql document model",
  94. "graph database queries",
  95. "vector database similarity",
  96. # Security
  97. "xss prevention sanitization",
  98. "sql injection prepared statements",
  99. "csrf tokens",
  100. "content security policy",
  101. "rate limiting api",
  102. "input validation patterns",
  103. "password hashing bcrypt",
  104. "two factor authentication",
  105. "penetration testing",
  106. "security headers http",
  107. "vulnerability scanning",
  108. "audit logging",
  109. # System Administration
  110. "linux file permissions",
  111. "systemd service unit",
  112. "cron job scheduling",
  113. "ssh key management",
  114. "firewall rules iptables",
  115. "process monitoring",
  116. "disk space management",
  117. "memory leak debugging",
  118. "network troubleshooting",
  119. "backup restore strategies",
  120. "log rotation configuration",
  121. "performance profiling",
  122. # General Knowledge
  123. "climate change effects",
  124. "renewable energy sources",
  125. "electric vehicles",
  126. "artificial intelligence ethics",
  127. "blockchain technology",
  128. "quantum computing basics",
  129. "space exploration mars",
  130. "gene editing crispr",
  131. "vaccine development",
  132. "economic indicators gdp",
  133. "stock market investing",
  134. "cryptocurrency trading",
  135. "mental health awareness",
  136. "nutrition diet tips",
  137. "exercise fitness routine",
  138. "meditation mindfulness",
  139. "sleep hygiene habits",
  140. "stress management",
  141. "time management productivity",
  142. "remote work tips",
  143. "team collaboration",
  144. "project management agile",
  145. "design thinking process",
  146. "user experience research",
  147. # Short/Ambiguous Queries (important for training)
  148. "cache",
  149. "proxy",
  150. "queue",
  151. "mutex",
  152. "semaphore",
  153. "deadlock",
  154. "heap",
  155. "stack",
  156. "tree",
  157. "graph",
  158. "hash",
  159. "sort",
  160. "api",
  161. "sdk",
  162. "cli",
  163. "gui",
  164. "orm",
  165. "cdn",
  166. "auth",
  167. "cors",
  168. "csrf",
  169. "xss",
  170. "jwt",
  171. "ssh",
  172. ]
# Instruction prompt sent to Ollama for every query. `{query}` is filled in
# with str.format by generate_with_ollama. The "lex"/"vec"/"hyde" line labels
# are the item kinds that parse_expansion counts when validating a reply.
# Note that parse_expansion accepts replies with fewer lines (>= 2 lex,
# >= 1 vec, >= 1 hyde) than the 3/2/1 requested here.
PROMPT_TEMPLATE = """Generate search query expansions for: {query}
Output EXACTLY this format (3 lex, 2 vec, 1 hyde):
lex: keyword phrase 1
lex: keyword phrase 2
lex: keyword phrase 3
vec: natural language search query
vec: alternative semantic query
hyde: A specific 2-sentence document passage answering this query.
Output:"""
  182. def generate_with_ollama(
  183. query: str, model: str = "gemma3:4b", base_url: str = "http://localhost:11434"
  184. ) -> str | None:
  185. """Generate query expansion using Ollama API."""
  186. try:
  187. response = requests.post(
  188. f"{base_url}/api/generate",
  189. json={
  190. "model": model,
  191. "prompt": PROMPT_TEMPLATE.format(query=query),
  192. "stream": False,
  193. "options": {
  194. "temperature": 0.7,
  195. "top_p": 0.9,
  196. "num_predict": 800, # More tokens for thinking models
  197. },
  198. },
  199. timeout=120,
  200. )
  201. response.raise_for_status()
  202. return response.json().get("response", "").strip()
  203. except Exception as e:
  204. print(f"Error generating for '{query}': {e}", file=sys.stderr)
  205. return None
  206. def parse_expansion(output: str) -> list[list[str]] | None:
  207. """Parse the model output into structured format."""
  208. items = normalize_output_items(parse_output_text(output))
  209. lex_count = sum(1 for kind, _ in items if kind == "lex")
  210. vec_count = sum(1 for kind, _ in items if kind == "vec")
  211. hyde_count = sum(1 for kind, _ in items if kind == "hyde")
  212. if lex_count >= 2 and vec_count >= 1 and hyde_count >= 1:
  213. return items
  214. return None
  215. def generate_query_variations(seed: str) -> list[str]:
  216. """Generate variations of a seed query."""
  217. variations = [seed]
  218. # Add question forms
  219. if not seed.startswith(("how", "what", "why", "when", "where")):
  220. variations.append(f"how to {seed}")
  221. variations.append(f"what is {seed}")
  222. # Add context
  223. variations.append(f"{seed} tutorial")
  224. variations.append(f"{seed} best practices")
  225. variations.append(f"{seed} examples")
  226. return variations
  227. def main():
  228. parser = argparse.ArgumentParser(description="Generate training data using Ollama")
  229. parser.add_argument(
  230. "--output", "-o", default="data/qmd_expansion_ollama.jsonl", help="Output file"
  231. )
  232. parser.add_argument(
  233. "--count", "-n", type=int, default=1000, help="Number of examples to generate"
  234. )
  235. parser.add_argument("--model", "-m", default="gemma3:4b", help="Ollama model name")
  236. parser.add_argument(
  237. "--base-url", default="http://localhost:11434", help="Ollama base URL"
  238. )
  239. parser.add_argument(
  240. "--resume", action="store_true", help="Resume from existing file"
  241. )
  242. args = parser.parse_args()
  243. output_path = Path(args.output)
  244. output_path.parent.mkdir(parents=True, exist_ok=True)
  245. # Load existing if resuming
  246. existing_queries = set()
  247. if args.resume and output_path.exists():
  248. with open(output_path) as f:
  249. for line in f:
  250. obj = json.loads(line)
  251. existing_queries.add(obj.get("query", obj.get("input", "")).lower())
  252. print(
  253. f"Resuming with {len(existing_queries)} existing examples", file=sys.stderr
  254. )
  255. # Generate query pool
  256. all_queries = []
  257. for seed in QUERY_SEEDS:
  258. all_queries.extend(generate_query_variations(seed))
  259. # Shuffle and filter
  260. random.shuffle(all_queries)
  261. queries_to_process = [q for q in all_queries if q.lower() not in existing_queries]
  262. print(
  263. f"Processing {min(args.count, len(queries_to_process))} queries with {args.model}...",
  264. file=sys.stderr,
  265. )
  266. generated = 0
  267. errors = 0
  268. mode = "a" if args.resume else "w"
  269. with open(output_path, mode) as f:
  270. for i, query in enumerate(queries_to_process):
  271. if generated >= args.count:
  272. break
  273. output = generate_with_ollama(query, args.model, args.base_url)
  274. if output:
  275. parsed = parse_expansion(output)
  276. if parsed:
  277. example = {"query": query, "output": parsed}
  278. f.write(json.dumps(example) + "\n")
  279. f.flush()
  280. generated += 1
  281. if generated % 10 == 0:
  282. print(
  283. f"Generated {generated}/{args.count} ({errors} errors)",
  284. file=sys.stderr,
  285. )
  286. else:
  287. errors += 1
  288. else:
  289. errors += 1
  290. # Small delay to avoid overwhelming the API
  291. time.sleep(0.1)
  292. print(f"\nDone! Generated {generated} examples, {errors} errors", file=sys.stderr)
  293. print(f"Output: {output_path}", file=sys.stderr)
# Run the command-line entry point only when executed as a script, not when
# this module is imported.
if __name__ == "__main__":
    main()