#!/usr/bin/env python3
# prepare_data.py
"""Prepare QMD query expansion data for training."""
import argparse
import json
import sys
from pathlib import Path
  6. # Prompt template matching QMD's llm.ts format (simplified for training)
  7. PROMPT_TEMPLATE = """You are a search query optimization expert. Transform the query into retrieval-optimized outputs.
  8. Query: {query}
  9. Output format:
  10. lex: {{keyword variation}}
  11. vec: {{semantic reformulation}}
  12. hyde: {{hypothetical document passage}}
  13. Output:"""
  14. def format_for_training(input_text: str, output_text: str) -> dict:
  15. """Format a single example for SFT training."""
  16. prompt = PROMPT_TEMPLATE.format(query=input_text)
  17. return {
  18. "prompt": prompt,
  19. "completion": output_text,
  20. # Alternative format for some trainers
  21. "text": f"{prompt}\n{output_text}",
  22. # Chat format
  23. "messages": [
  24. {"role": "user", "content": f"Expand this search query:\n\n{input_text}"},
  25. {"role": "assistant", "content": output_text}
  26. ]
  27. }
  28. def main():
  29. parser = argparse.ArgumentParser(description="Prepare data for training")
  30. parser.add_argument("--input", type=str, default="data/qmd_expansion.jsonl", help="Input JSONL file")
  31. parser.add_argument("--output", type=str, default="data/train", help="Output directory")
  32. parser.add_argument("--split", type=float, default=0.1, help="Validation split ratio")
  33. args = parser.parse_args()
  34. input_path = Path(args.input)
  35. output_dir = Path(args.output)
  36. output_dir.mkdir(parents=True, exist_ok=True)
  37. if not input_path.exists():
  38. print(f"Error: Input file not found: {input_path}")
  39. exit(1)
  40. # Load examples
  41. examples = []
  42. with open(input_path) as f:
  43. for line in f:
  44. if line.strip():
  45. examples.append(json.loads(line))
  46. print(f"Loaded {len(examples)} examples from {input_path}")
  47. # Format for training
  48. formatted = [format_for_training(ex["input"], ex["output"]) for ex in examples]
  49. # Split into train/val
  50. split_idx = int(len(formatted) * (1 - args.split))
  51. train_data = formatted[:split_idx]
  52. val_data = formatted[split_idx:]
  53. # Write train set
  54. train_path = output_dir / "train.jsonl"
  55. with open(train_path, "w") as f:
  56. for item in train_data:
  57. f.write(json.dumps(item) + "\n")
  58. # Write validation set
  59. val_path = output_dir / "val.jsonl"
  60. with open(val_path, "w") as f:
  61. for item in val_data:
  62. f.write(json.dumps(item) + "\n")
  63. # Write chat format (for TRL/Unsloth)
  64. chat_path = output_dir / "train_chat.jsonl"
  65. with open(chat_path, "w") as f:
  66. for item in train_data:
  67. f.write(json.dumps({"messages": item["messages"]}) + "\n")
  68. print(f"Written {len(train_data)} train examples to {train_path}")
  69. print(f"Written {len(val_data)} validation examples to {val_path}")
  70. print(f"Written chat format to {chat_path}")
  71. # Also save as HuggingFace datasets format info
  72. dataset_info = {
  73. "dataset_name": "qmd-query-expansion",
  74. "train_samples": len(train_data),
  75. "val_samples": len(val_data),
  76. "columns": ["prompt", "completion", "text", "messages"],
  77. }
  78. with open(output_dir / "dataset_info.json", "w") as f:
  79. json.dump(dataset_info, f, indent=2)
  80. if __name__ == "__main__":
  81. main()