prepare_data.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "transformers>=4.45.0",
  6. # "pydantic>=2.0",
  7. # "jinja2",
  8. # ]
  9. # ///
  10. """Prepare QMD query expansion data for training.
  11. Loads all data/*.jsonl via the strict Pydantic schema, applies the Qwen3
  12. chat template, deduplicates by query, and writes train/val splits.
  13. The prepared train files are ephemeral build artifacts — the canonical
  14. data lives in data/*.jsonl and is always loaded through the schema.
  15. """
import argparse
import glob
import json
import os
import random
import sys
from pathlib import Path

from transformers import AutoTokenizer

from dataset.schema import (
    TrainingExample,
    load_examples,
    output_items_to_text,
)
  27. _tokenizer = None
  28. _tokenizer_model = None
  29. def get_tokenizer():
  30. global _tokenizer, _tokenizer_model
  31. model_name = os.environ.get("QMD_BASE_MODEL", "Qwen/Qwen3-1.7B")
  32. if _tokenizer is None or _tokenizer_model != model_name:
  33. _tokenizer = AutoTokenizer.from_pretrained(model_name)
  34. _tokenizer_model = model_name
  35. return _tokenizer
  36. def format_for_training(ex: TrainingExample) -> dict:
  37. """Format a validated TrainingExample for SFT training."""
  38. tokenizer = get_tokenizer()
  39. output_text = output_items_to_text(ex.output)
  40. user_prompt = f"/no_think Expand this search query: {ex.query}"
  41. if ex.intent:
  42. user_prompt = (
  43. f"/no_think Expand this search query: {ex.query}\n"
  44. f"Query intent: {ex.intent.strip()}"
  45. )
  46. messages = [
  47. {
  48. "role": "user",
  49. "content": user_prompt,
  50. },
  51. {"role": "assistant", "content": output_text},
  52. ]
  53. text = tokenizer.apply_chat_template(
  54. messages,
  55. tokenize=False,
  56. add_generation_prompt=False,
  57. )
  58. # Strip empty <think> tags — /no_think should suppress them
  59. text = text.replace("<think>\n\n</think>\n\n", "")
  60. return {
  61. "text": text,
  62. "messages": messages,
  63. }
  64. def main():
  65. parser = argparse.ArgumentParser(description="Prepare data for training")
  66. parser.add_argument(
  67. "--input",
  68. type=str,
  69. default="data/*.jsonl",
  70. help="Input JSONL file(s) - supports glob patterns",
  71. )
  72. parser.add_argument(
  73. "--output", type=str, default="data/train", help="Output directory"
  74. )
  75. parser.add_argument(
  76. "--split", type=float, default=0.1, help="Validation split ratio"
  77. )
  78. parser.add_argument(
  79. "--seed", type=int, default=42, help="Shuffle seed",
  80. )
  81. args = parser.parse_args()
  82. output_dir = Path(args.output)
  83. output_dir.mkdir(parents=True, exist_ok=True)
  84. # Resolve input files
  85. import glob as globmod
  86. if "*" in args.input:
  87. input_files = sorted(globmod.glob(args.input))
  88. if not input_files:
  89. print(f"Error: No files found matching: {args.input}")
  90. exit(1)
  91. print(f"Found {len(input_files)} input files")
  92. else:
  93. input_path = Path(args.input)
  94. if not input_path.exists():
  95. print(f"Error: Input file not found: {input_path}")
  96. exit(1)
  97. input_files = [str(input_path)]
  98. # Load all examples through strict Pydantic schema
  99. all_examples: list[TrainingExample] = []
  100. for input_file in input_files:
  101. examples = load_examples(input_file)
  102. print(f" {Path(input_file).name}: {len(examples)} examples")
  103. all_examples.extend(examples)
  104. print(f"Loaded {len(all_examples)} examples total")
  105. # Deduplicate by query (case-insensitive)
  106. seen: set[str] = set()
  107. deduped: list[TrainingExample] = []
  108. for ex in all_examples:
  109. key = ex.query.lower().strip()
  110. if key not in seen:
  111. seen.add(key)
  112. deduped.append(ex)
  113. if len(deduped) < len(all_examples):
  114. print(f"Deduplicated: {len(all_examples)} -> {len(deduped)}")
  115. all_examples = deduped
  116. # Shuffle
  117. random.seed(args.seed)
  118. random.shuffle(all_examples)
  119. # Format each example using the Pydantic model
  120. formatted = [format_for_training(ex) for ex in all_examples]
  121. # Split
  122. split_idx = int(len(formatted) * (1 - args.split))
  123. train_data = formatted[:split_idx]
  124. val_data = formatted[split_idx:]
  125. # Write (these are ephemeral build artifacts)
  126. for name, data in [("train.jsonl", train_data), ("val.jsonl", val_data)]:
  127. with open(output_dir / name, "w") as f:
  128. for item in data:
  129. f.write(json.dumps(item) + "\n")
  130. with open(output_dir / "train_chat.jsonl", "w") as f:
  131. for item in train_data:
  132. f.write(json.dumps({"messages": item["messages"]}) + "\n")
  133. # Stats
  134. short_final = sum(1 for ex in all_examples if len(ex.query.split()) <= 2)
  135. print(f"\n=== Summary ===")
  136. print(f"Total examples: {len(all_examples)}")
  137. print(f"Short queries: {short_final} ({100 * short_final / len(all_examples):.1f}%)")
  138. print(f"Train: {len(train_data)}, Val: {len(val_data)}")
  139. print(f"Output: {output_dir}")
  140. dataset_info = {
  141. "dataset_name": "qmd-query-expansion",
  142. "train_samples": len(train_data),
  143. "val_samples": len(val_data),
  144. "short_query_pct": round(100 * short_final / len(all_examples), 1),
  145. "columns": ["text", "messages"],
  146. }
  147. with open(output_dir / "dataset_info.json", "w") as f:
  148. json.dump(dataset_info, f, indent=2)
  149. if __name__ == "__main__":
  150. main()