| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664 |
- #!/usr/bin/env python3
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- # "transformers>=4.45.0",
- # "jinja2",
- # ]
- # ///
- """Prepare QMD query expansion data for training.
- See PROMPT_FORMAT.md for format specification.
- """
- import argparse
- import json
- import random
- from pathlib import Path
- from transformers import AutoTokenizer
# Module-level cache so the tokenizer is only loaded once per process.
_tokenizer = None


def get_tokenizer():
    """Return the shared tokenizer, loading it on first use.

    The tokenizer is created with
    ``AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")`` the first time this
    function is called and stored in the module-level ``_tokenizer`` so that
    later calls reuse the same object.
    """
    global _tokenizer
    if _tokenizer is not None:
        return _tokenizer
    _tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    return _tokenizer
# Short queries that need proper expansion examples - organized by category.
# Most entries are single terms, but the list also contains hyphenated terms
# and a few multi-word phrases (see "Common Short Phrases" at the end).
# main() generates synthetic examples for every entry whose lowercase form
# is not already present as an input in the loaded data.
SHORT_QUERIES = [
    # === Programming Languages & Runtimes ===
    "python",
    "typescript",
    "javascript",
    "rust",
    "golang",
    "java",
    "kotlin",
    "swift",
    "ruby",
    "php",
    "cpp",
    "c",
    "elixir",
    "scala",
    "clojure",
    "dart",
    "r",
    "node",
    "deno",
    "bun",
    # === Frontend Frameworks ===
    "react",
    "vue",
    "angular",
    "svelte",
    "solid",
    "htmx",
    "alpine",
    "nextjs",
    "nuxt",
    "jquery",
    "backbone",
    "ember",
    # === Backend Frameworks ===
    "django",
    "flask",
    "fastapi",
    "express",
    "rails",
    "spring",
    "laravel",
    "gin",
    # === Databases ===
    "postgres",
    "mysql",
    "mongodb",
    "redis",
    "elasticsearch",
    "sqlite",
    "dynamodb",
    "cassandra",
    "cockroachdb",
    "neo4j",
    "couchdb",
    # === Infrastructure & DevOps ===
    "docker",
    "kubernetes",
    "terraform",
    "ansible",
    "vagrant",
    "packer",
    "jenkins",
    "gitlab-ci",
    "github-actions",
    "circleci",
    "travis",
    "argo",
    "nginx",
    "apache",
    "caddy",
    "traefik",
    "haproxy",
    "envoy",
    # === Cloud Platforms ===
    "aws",
    "gcp",
    "azure",
    "vercel",
    "netlify",
    "heroku",
    "digitalocean",
    "cloudflare",
    "flyio",
    "render",
    # === Tools & Utilities ===
    "git",
    "linux",
    "bash",
    "zsh",
    "vim",
    "tmux",
    "curl",
    "wget",
    "ssh",
    "npm",
    "pip",
    "brew",
    "apt",
    "yum",
    "cargo",
    "gem",
    "composer",
    "maven",
    # === Security & Auth ===
    "auth",
    "oauth",
    "jwt",
    "saml",
    "ldap",
    "rbac",
    "cors",
    "csrf",
    "xss",
    "ssl",
    "tls",
    "cert",
    "encrypt",
    "hash",
    "cipher",
    # === Web Technologies ===
    "rest",
    "graphql",
    "grpc",
    "websocket",
    "sse",
    "http",
    "https",
    "html",
    "css",
    "sass",
    "less",
    "styled-components",
    "tailwind",
    # === Data & ML ===
    "pandas",
    "numpy",
    "tensorflow",
    "pytorch",
    "sklearn",
    "jupyter",
    "spark",
    "kafka",
    "airflow",
    "dbt",
    "hadoop",
    "hive",
    "presto",
    # === Testing ===
    "jest",
    "vitest",
    "pytest",
    "mocha",
    "cypress",
    "playwright",
    "selenium",
    "rspec",
    "junit",
    "testng",
    # === Build Tools ===
    "webpack",
    "vite",
    "esbuild",
    "rollup",
    "parcel",
    "turbopack",
    "babel",
    # === Monitoring & Observability ===
    "prometheus",
    "grafana",
    "datadog",
    "newrelic",
    "sentry",
    "jaeger",
    "logging",
    "metrics",
    "tracing",
    "observability",
    # === API & Integration ===
    "swagger",
    "openapi",
    "postman",
    "api",
    "webhook",
    "sdk",
    "cli",
    # === Architecture Patterns ===
    "microservices",
    "serverless",
    "monolith",
    "event-driven",
    "cqrs",
    "event-sourcing",
    "saga",
    "circuit-breaker",
    "retry",
    "idempotency",
    # === Development Concepts ===
    "config",
    "setup",
    "cache",
    "log",
    "debug",
    "deploy",
    "build",
    "lint",
    "format",
    "migrate",
    "backup",
    "restore",
    "env",
    "vars",
    "secrets",
    "rate-limit",
    "load-balance",
    "scale",
    "replicate",
    "shard",
    # === General Knowledge: Trivia ===
    "trivia",
    "quiz",
    "facts",
    "did-you-know",
    "random-facts",
    "world-records",
    # === General Knowledge: Geography ===
    "countries",
    "capitals",
    "continents",
    "oceans",
    "rivers",
    "mountains",
    "deserts",
    "islands",
    "climate",
    "population",
    "maps",
    "coordinates",
    # === General Knowledge: Philosophy ===
    "ethics",
    "metaphysics",
    "epistemology",
    "logic",
    "stoicism",
    "existentialism",
    "nihilism",
    "utilitarianism",
    "deontology",
    "virtue-ethics",
    "free-will",
    # === General Knowledge: History ===
    "ancient",
    "medieval",
    "renaissance",
    "industrial",
    "world-war",
    "cold-war",
    "revolution",
    "empire",
    "civilization",
    "archaeology",
    "timeline",
    # === General Knowledge: Science ===
    "physics",
    "chemistry",
    "biology",
    "astronomy",
    "geology",
    "ecology",
    "evolution",
    "genetics",
    "quantum",
    "relativity",
    "thermodynamics",
    # === General Knowledge: Arts & Culture ===
    "art",
    "music",
    "literature",
    "film",
    "theater",
    "dance",
    "sculpture",
    "painting",
    "photography",
    "architecture",
    "poetry",
    "novel",
    # === Common Short Phrases ===
    "env vars",
    "api keys",
    "error handling",
    "rate limiting",
    "file upload",
    "user auth",
    "db connection",
    "query params",
    "hot reload",
    "code split",
    "tree shake",
    "lazy load",
    "dependency injection",
    "event listener",
    "middleware chain",
    "route handler",
    "controller logic",
    "service layer",
    "repository pattern",
    "unit of work",
    "domain model",
    "value object",
    "aggregate root",
    "event bus",
    "message queue",
    "job scheduler",
    "web server",
    "app server",
    "proxy server",
    "load balancer",
    "cdn",
    "dns",
    "ssl-cert",
    "firewall",
    "subnet",
    "vpc",
    "gateway",
]
# Templates for generating short query expansions.
# Each template is a dict with three keys:
#   "lex":  keyword-style patterns (three per template),
#   "vec":  natural-language patterns (two per template),
#   "hyde": a single short passage.
# "{q}" is a str.format placeholder that generate_short_example() replaces
# with the query; that function samples two "lex" and two "vec" patterns, so
# every template must provide at least two of each.
# IMPORTANT: All lex lines MUST include {q} to preserve key terms
SHORT_TEMPLATES = [
    # Configuration/Setup templates
    {
        "lex": ["{q} configuration", "{q} settings", "{q} setup guide"],
        "vec": [
            "how to configure {q} in my project",
            "{q} setup and configuration tutorial",
        ],
        "hyde": "To set up {q}, first install the required dependencies. Then configure the settings in your project configuration file.",
    },
    # Tutorial/Learning templates
    {
        "lex": ["{q} tutorial", "{q} guide", "{q} basics"],
        "vec": ["beginner guide to {q}", "how to get started with {q}"],
        "hyde": "This guide covers the basics of {q}. Follow the steps below to get started with your first implementation.",
    },
    # Best practices templates
    {
        "lex": ["{q} best practices", "{q} patterns", "{q} tips"],
        "vec": ["best practices for using {q}", "recommended patterns for {q}"],
        "hyde": "When working with {q}, follow these best practices: use consistent naming, handle errors properly, and document your code.",
    },
    # Troubleshooting templates
    {
        "lex": ["{q} troubleshooting", "{q} fix", "{q} errors"],
        "vec": ["how to fix {q} errors", "troubleshooting common {q} problems"],
        "hyde": "If you encounter {q} issues, check your configuration first. Common problems include missing dependencies and incorrect settings.",
    },
    # Examples/Code templates
    {
        "lex": ["{q} examples", "{q} code samples", "{q} usage"],
        "vec": ["code examples for {q}", "practical {q} implementation examples"],
        "hyde": "Here are some practical examples of {q} in action. Each example demonstrates a common use case with working code.",
    },
    # Documentation/Reference templates
    {
        "lex": ["{q} documentation", "{q} reference", "{q} manual"],
        "vec": ["official {q} documentation", "{q} API reference guide"],
        "hyde": "The official documentation for {q} provides comprehensive information about features, configuration options, and usage examples.",
    },
    # Installation templates
    {
        "lex": ["{q} install", "{q} setup", "{q} getting started"],
        "vec": ["how to install {q} on my system", "{q} installation guide"],
        "hyde": "To install {q}, run the appropriate package manager command for your system. Verify the installation by checking the version.",
    },
    # Comparison templates
    {
        "lex": ["{q} comparison", "{q} vs alternatives", "{q} differences"],
        "vec": ["how does {q} compare to alternatives", "{q} pros and cons"],
        "hyde": "When comparing {q} to similar tools, consider factors like performance, ease of use, community support, and ecosystem compatibility.",
    },
    # Performance templates
    {
        "lex": ["{q} performance", "{q} optimization", "{q} speed"],
        "vec": ["how to optimize {q} performance", "{q} performance tuning tips"],
        "hyde": "To improve {q} performance, profile your application to identify bottlenecks. Common optimizations include caching, lazy loading, and query optimization.",
    },
    # Security templates
    {
        "lex": ["{q} security", "{q} hardening", "{q} vulnerabilities"],
        "vec": ["how to secure {q} configuration", "{q} security best practices"],
        "hyde": "Security considerations for {q} include input validation, authentication, authorization, and keeping dependencies up to date with security patches.",
    },
    # Testing templates
    {
        "lex": ["{q} testing", "{q} test suite", "{q} unit tests"],
        "vec": ["how to test {q} code", "{q} testing strategies and frameworks"],
        "hyde": "Testing {q} involves writing unit tests, integration tests, and end-to-end tests. Use appropriate testing frameworks for your language and platform.",
    },
    # Deployment templates
    {
        "lex": ["{q} deployment", "{q} production", "{q} release"],
        "vec": ["how to deploy {q} to production", "{q} production deployment guide"],
        "hyde": "Deploying {q} to production requires proper configuration, environment variables, monitoring, and rollback procedures for reliability.",
    },
    # Debugging templates
    {
        "lex": ["{q} debugging", "{q} troubleshooting", "{q} error handling"],
        "vec": ["how to debug {q} issues", "{q} debugging techniques and tools"],
        "hyde": "Debugging {q} involves using logging, breakpoints, stack traces, and specialized debugging tools to identify and fix issues efficiently.",
    },
    # Integration templates
    {
        "lex": ["{q} integration", "{q} connect", "{q} interoperability"],
        "vec": ["how to integrate {q} with other systems", "{q} integration patterns"],
        "hyde": "Integrating {q} with other systems requires understanding APIs, data formats, authentication mechanisms, and error handling strategies.",
    },
    # Migration templates
    {
        "lex": ["{q} migration", "{q} upgrade", "{q} versioning"],
        "vec": ["how to migrate to {q}", "{q} upgrade guide and breaking changes"],
        "hyde": "Migrating to {q} involves planning, testing compatibility, addressing breaking changes, and validating functionality before production deployment.",
    },
]
def truncate_hyde(hyde_text: str, max_len: int = 150) -> str:
    """Truncate a hyde passage to at most ``max_len`` characters.

    The cut prefers, in order:

    1. the full ``max_len`` window when a sentence ends exactly at the cut
       (the window ends with "." and the next character is whitespace);
    2. the last ". " sentence break in the second half of the window;
    3. the last space in the second half of the window, with a closing "."
       added;
    4. a hard cut at ``max_len - 1`` characters plus ".".

    Args:
        hyde_text: The passage text (without the ``hyde:`` prefix).
        max_len: Maximum length of the returned string.

    Returns:
        ``hyde_text`` unchanged when it already fits, otherwise a shortened
        string that ends with sentence punctuation.
    """
    if len(hyde_text) <= max_len:
        return hyde_text

    truncated = hyde_text[:max_len]

    # A sentence that ends exactly at the cut is invisible to rfind(". ")
    # because the following space lies outside the window; keep it whole
    # instead of dropping its last word.
    if truncated.endswith(".") and hyde_text[max_len].isspace():
        return truncated

    last_period = truncated.rfind(". ")
    if last_period > max_len // 2:
        return truncated[: last_period + 1]

    last_space = truncated.rfind(" ")
    if last_space > max_len // 2:
        # Drop dangling separators so we never emit ",." and do not double
        # the period when the cut already lands on one.
        cut = truncated[:last_space].rstrip(" ,;:")
        return cut if cut.endswith(".") else cut + "."

    # No usable boundary: hard cut, reserving one character for the period.
    return truncated[: max_len - 1] + "."
def clean_output(output: str) -> str:
    """Normalize a raw expansion output.

    Keeps only lines that start with ``lex:``, ``vec:`` or ``hyde:`` (after
    stripping surrounding whitespace). ``lex:``/``vec:`` lines are kept as-is;
    the text of each ``hyde:`` line is passed through ``truncate_hyde`` and
    re-emitted as ``hyde: <text>``. Blank and unrecognized lines are dropped.

    Args:
        output: Multi-line expansion text.

    Returns:
        The retained lines joined with newlines (empty string if none).
    """
    hyde_prefix = "hyde:"
    kept = []
    for raw in output.strip().split("\n"):
        entry = raw.strip()
        if entry.startswith(("lex:", "vec:")):
            kept.append(entry)
            continue
        if entry.startswith(hyde_prefix):
            passage = truncate_hyde(entry[len(hyde_prefix):].strip())
            kept.append(f"hyde: {passage}")
    return "\n".join(kept)
def generate_short_example(query: str) -> dict:
    """Build one synthetic training example for a short query.

    A template is picked at random from ``SHORT_TEMPLATES``; two of its
    ``lex`` patterns and two of its ``vec`` patterns are sampled, and its
    ``hyde`` passage is used as-is. Every pattern has ``{q}`` replaced by
    ``query``.

    Args:
        query: The short query to expand.

    Returns:
        A dict with ``"input"`` (the query) and ``"output"`` (two ``lex:``
        lines, two ``vec:`` lines and one ``hyde:`` line, newline-joined).
    """
    chosen = random.choice(SHORT_TEMPLATES)
    lex_picks = random.sample(chosen["lex"], 2)
    vec_picks = random.sample(chosen["vec"], 2)

    tagged = (
        [("lex", pattern) for pattern in lex_picks]
        + [("vec", pattern) for pattern in vec_picks]
        + [("hyde", chosen["hyde"])]
    )
    rendered = "\n".join(
        f"{tag}: {pattern.format(q=query)}" for tag, pattern in tagged
    )
    return {"input": query, "output": rendered}
def format_for_training(input_text: str, output_text: str) -> dict:
    """Format a single example for SFT training using the Qwen chat format.

    Builds a two-turn conversation (user prompt + assistant answer), renders
    it with the tokenizer's chat template, and removes the empty thinking
    block the template inserts so the training text contains only the direct
    answer.

    Args:
        input_text: The search query to expand.
        output_text: The expected expansion (``lex:``/``vec:``/``hyde:`` lines).

    Returns:
        A dict with ``"text"`` (the rendered chat string) and ``"messages"``
        (the list of role/content dicts the text was rendered from).
    """
    tokenizer = get_tokenizer()

    # Use /no_think to disable thinking mode - we want direct output
    messages = [
        {
            "role": "user",
            "content": f"/no_think Expand this search query: {input_text}",
        },
        {"role": "assistant", "content": output_text},
    ]

    # Use tokenizer to generate proper chat format with special tokens
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )

    # Strip the empty <think> block - we don't want thinking mode.
    # The Qwen3 template emits "<think>\n\n</think>\n\n" in front of the final
    # assistant answer when there is no reasoning content; the literal must
    # match that exactly or the replace is a silent no-op.
    text = text.replace("<think>\n\n</think>\n\n", "")

    return {
        "text": text,
        "messages": messages,
    }
def _load_examples(input_path):
    """Load, clean and validate examples from a JSONL file.

    Each non-blank line must be a JSON object with ``"input"`` and
    ``"output"`` keys. The output is normalized with ``clean_output`` and the
    example is kept only if the cleaned output still contains both a ``lex:``
    and a ``vec:`` marker.

    Returns:
        ``(examples, seen_queries, truncated_count)`` where ``seen_queries``
        is the set of lowercase inputs of the kept examples and
        ``truncated_count`` is the number of ``hyde:`` lines (in all parsed
        records, kept or not) whose text was shortened.
    """
    examples = []
    seen_queries = set()
    truncated_count = 0

    with open(input_path, encoding="utf-8") as f:
        for raw in f:
            if not raw.strip():
                continue
            ex = json.loads(raw)

            # Clean output (truncate hyde, remove invalid lines)
            original_output = ex["output"]
            ex["output"] = clean_output(original_output)

            # Count truncations with the same function that performs them so
            # the reported number cannot drift from the real length limit.
            for orig_line in original_output.split("\n"):
                orig_line = orig_line.strip()
                if orig_line.startswith("hyde:"):
                    hyde_text = orig_line[5:].strip()
                    if truncate_hyde(hyde_text) != hyde_text:
                        truncated_count += 1

            # Validate cleaned output
            if "lex:" in ex["output"] and "vec:" in ex["output"]:
                examples.append(ex)
                seen_queries.add(ex["input"].lower())

    return examples, seen_queries, truncated_count


def _write_jsonl(path, rows):
    """Write ``rows`` to ``path`` as one JSON document per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(row) + "\n" for row in rows)


def main():
    """Command-line entry point: build train/val files for query expansion.

    Reads the input JSONL, cleans it, tops it up with synthetic examples for
    every ``SHORT_QUERIES`` entry not already present, shuffles, renders each
    example with ``format_for_training`` and writes ``train.jsonl``,
    ``val.jsonl``, ``train_chat.jsonl`` and ``dataset_info.json`` into the
    output directory.

    Raises:
        SystemExit: With status 1 if the input file does not exist, or with
            status 2 (via argparse) if ``--split`` is outside ``[0, 1]``.
    """
    parser = argparse.ArgumentParser(description="Prepare data for training")
    parser.add_argument(
        "--input", type=str, default="data/qmd_expansion.jsonl", help="Input JSONL file"
    )
    parser.add_argument(
        "--output", type=str, default="data/train", help="Output directory"
    )
    parser.add_argument(
        "--split", type=float, default=0.1, help="Validation split ratio"
    )
    parser.add_argument(
        "--add-short", type=int, default=3, help="Variations per short query to add"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducible generation and shuffling",
    )
    args = parser.parse_args()

    if not 0.0 <= args.split <= 1.0:
        parser.error("--split must be between 0 and 1")

    # Validate the input before creating anything on disk.
    input_path = Path(args.input)
    if not input_path.exists():
        # SystemExit with a string prints it to stderr and exits with status 1.
        raise SystemExit(f"Error: Input file not found: {input_path}")

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    if args.seed is not None:
        random.seed(args.seed)

    # Load and clean existing examples
    examples, seen_queries, long_hyde_count = _load_examples(input_path)
    print(f"Loaded and cleaned {len(examples)} examples")
    print(f"Truncated {long_hyde_count} long hyde sections")

    # Count existing short queries
    short_existing = sum(1 for ex in examples if len(ex["input"].split()) <= 2)
    print(f"Existing short queries (1-2 words): {short_existing}")

    # Generate additional short query examples
    new_short = []
    for query in SHORT_QUERIES:
        key = query.lower()
        if key in seen_queries:
            continue
        new_short.extend(generate_short_example(query) for _ in range(args.add_short))
        seen_queries.add(key)
    print(f"Generated {len(new_short)} new short query examples")

    # Combine and shuffle
    all_examples = examples + new_short
    random.shuffle(all_examples)

    # Format for training
    formatted = [format_for_training(ex["input"], ex["output"]) for ex in all_examples]

    # Split into train/val
    split_idx = int(len(formatted) * (1 - args.split))
    train_data = formatted[:split_idx]
    val_data = formatted[split_idx:]

    _write_jsonl(output_dir / "train.jsonl", train_data)
    _write_jsonl(output_dir / "val.jsonl", val_data)
    # Chat format (for TRL): messages only
    _write_jsonl(
        output_dir / "train_chat.jsonl",
        ({"messages": item["messages"]} for item in train_data),
    )

    # Stats (guard the percentage against an empty dataset)
    short_final = sum(1 for ex in all_examples if len(ex["input"].split()) <= 2)
    short_pct = 100 * short_final / len(all_examples) if all_examples else 0.0
    print("\n=== Summary ===")
    print(f"Total examples: {len(all_examples)}")
    print(f"Short queries: {short_final} ({short_pct:.1f}%)")
    print(f"Train: {len(train_data)}, Val: {len(val_data)}")
    print(f"Output: {output_dir}")

    # Dataset info
    dataset_info = {
        "dataset_name": "qmd-query-expansion",
        "train_samples": len(train_data),
        "val_samples": len(val_data),
        "short_query_pct": round(short_pct, 1),
        # Only the keys actually written to train.jsonl / val.jsonl.
        "columns": ["text", "messages"],
    }
    with open(output_dir / "dataset_info.json", "w", encoding="utf-8") as f:
        json.dump(dataset_info, f, indent=2)


if __name__ == "__main__":
    main()
|