| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- #!/usr/bin/env python3
- """Prepare QMD query expansion data for training."""
- import argparse
- import json
- from pathlib import Path
# Prompt template matching QMD's llm.ts format (simplified for training).
# `{query}` is the only substitution field filled in by str.format(); the
# doubled braces are format escapes, so the lex/vec/hyde lines are rendered
# with literal single braces around their placeholder descriptions.
# NOTE(review): blank lines inside the template may have been lost when this
# file was extracted -- confirm the exact layout against llm.ts.
PROMPT_TEMPLATE = """You are a search query optimization expert. Transform the query into retrieval-optimized outputs.
Query: {query}
Output format:
lex: {{keyword variation}}
vec: {{semantic reformulation}}
hyde: {{hypothetical document passage}}
Output:"""
def format_for_training(input_text: str, output_text: str) -> dict:
    """Build one SFT training record for a query / expansion pair.

    The same example is exposed in several shapes so different trainers can
    pick the one they expect:

    * ``prompt`` / ``completion`` -- PROMPT_TEMPLATE rendered with the query,
      and the target expansion.
    * ``text`` -- prompt and completion joined by a single newline.
    * ``messages`` -- a two-turn chat; the user turn uses a short
      "Expand this search query" instruction rather than PROMPT_TEMPLATE.

    Args:
        input_text: The raw search query.
        output_text: The expected expansion, stored verbatim.

    Returns:
        A dict with the keys ``prompt``, ``completion``, ``text`` and
        ``messages``, in that order.
    """
    rendered = PROMPT_TEMPLATE.format(query=input_text)

    user_turn = {"role": "user", "content": f"Expand this search query:\n\n{input_text}"}
    assistant_turn = {"role": "assistant", "content": output_text}

    # Keys are inserted in the order they should appear in the serialized JSON.
    record = {"prompt": rendered, "completion": output_text}
    record["text"] = f"{rendered}\n{output_text}"
    record["messages"] = [user_turn, assistant_turn]
    return record
def _load_examples(input_path):
    """Parse every non-blank line of a JSONL file into a list of objects.

    Raises:
        json.JSONDecodeError: if a line is not valid JSON; the message names
            the file and the 1-based line number.
    """
    examples = []
    # Explicit UTF-8: queries may contain non-ASCII text and the platform
    # default encoding is not UTF-8 everywhere.
    with open(input_path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                examples.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise json.JSONDecodeError(
                    f"{err.msg} ({input_path}, line {line_no})", err.doc, err.pos
                ) from err
    return examples


def _train_count(total, val_ratio):
    """Return how many of ``total`` examples go to the training split."""
    # Round away float noise before truncating: 1 - 0.9 evaluates to
    # 0.0999..., so a bare int() would turn 10 * (1 - 0.9) into 0, not 1.
    return int(round(total * (1 - val_ratio), 9))


def _write_jsonl(path, records):
    """Write ``records`` to ``path`` as UTF-8 JSONL, one object per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)


def main():
    """Convert a JSONL file of query expansions into training files.

    Command-line options:
        --input   JSONL file whose records have "input" and "output" fields.
        --output  Directory that receives the generated files.
        --split   Fraction of examples held out for validation (0 to 1).

    Files written to the output directory:
        train.jsonl        leading (1 - split) share of the formatted examples
        val.jsonl          remaining examples
        train_chat.jsonl   the training examples reduced to their "messages"
        dataset_info.json  dataset name, sample counts and column names

    Exits with status 1 when the input file does not exist and with status 2
    (argparse usage error) when --split is outside [0, 1]. Nothing is written,
    and the output directory is not created, until the input has been loaded
    and formatted.
    """
    import sys  # function-scope: the module header does not import sys

    parser = argparse.ArgumentParser(description="Prepare data for training")
    parser.add_argument("--input", type=str, default="data/qmd_expansion.jsonl", help="Input JSONL file")
    parser.add_argument("--output", type=str, default="data/train", help="Output directory")
    parser.add_argument("--split", type=float, default=0.1, help="Validation split ratio")
    args = parser.parse_args()

    # A ratio outside [0, 1] (or NaN) would yield a negative or overlong
    # slice index and a silently wrong split.
    if not 0.0 <= args.split <= 1.0:
        parser.error(f"--split must be between 0 and 1, got {args.split}")

    input_path = Path(args.input)
    output_dir = Path(args.output)

    if not input_path.exists():
        print(f"Error: Input file not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    # Load examples
    examples = _load_examples(input_path)
    print(f"Loaded {len(examples)} examples from {input_path}")

    # Format for training
    formatted = [format_for_training(ex["input"], ex["output"]) for ex in examples]

    # Split into train/val (order is preserved; no shuffling)
    split_idx = _train_count(len(formatted), args.split)
    train_data = formatted[:split_idx]
    val_data = formatted[split_idx:]

    # Create the directory only once there is something valid to write.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Write train set
    train_path = output_dir / "train.jsonl"
    _write_jsonl(train_path, train_data)

    # Write validation set
    val_path = output_dir / "val.jsonl"
    _write_jsonl(val_path, val_data)

    # Write chat format (for TRL/Unsloth)
    chat_path = output_dir / "train_chat.jsonl"
    _write_jsonl(chat_path, ({"messages": item["messages"]} for item in train_data))

    print(f"Written {len(train_data)} train examples to {train_path}")
    print(f"Written {len(val_data)} validation examples to {val_path}")
    print(f"Written chat format to {chat_path}")

    # Also save as HuggingFace datasets format info
    dataset_info = {
        "dataset_name": "qmd-query-expansion",
        "train_samples": len(train_data),
        "val_samples": len(val_data),
        "columns": ["prompt", "completion", "text", "messages"],
    }
    with open(output_dir / "dataset_info.json", "w", encoding="utf-8") as f:
        json.dump(dataset_info, f, indent=2)


if __name__ == "__main__":
    main()
|