prepare_data.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "transformers>=4.45.0",
  6. # "jinja2",
  7. # ]
  8. # ///
  9. """Prepare QMD query expansion data for training.
  10. See PROMPT_FORMAT.md for format specification.
  11. """
  12. import argparse
  13. import json
  14. import random
  15. import sys
  16. import os
  17. from pathlib import Path
  18. sys.path.insert(0, str(Path(__file__).parent.parent))
  19. from dataset.schema import (
  20. normalize_output_items,
  21. output_items_to_text,
  22. parse_output_text,
  23. has_type,
  24. )
  25. from transformers import AutoTokenizer
  26. _tokenizer = None
  27. _tokenizer_model = None
  28. def get_tokenizer():
  29. global _tokenizer, _tokenizer_model
  30. model_name = os.environ.get("QMD_BASE_MODEL", "Qwen/Qwen3-1.7B")
  31. if _tokenizer is None or _tokenizer_model != model_name:
  32. _tokenizer = AutoTokenizer.from_pretrained(model_name)
  33. _tokenizer_model = model_name
  34. return _tokenizer
  35. def format_for_training(query_text: str, output_items: list[list[str]]) -> dict:
  36. """Format a single example for SFT training using Qwen chat format."""
  37. tokenizer = get_tokenizer()
  38. output_text = output_items_to_text(output_items)
  39. # Use /no_think to disable thinking mode - we want direct output
  40. messages = [
  41. {
  42. "role": "user",
  43. "content": f"/no_think Expand this search query: {query_text}",
  44. },
  45. {"role": "assistant", "content": output_text},
  46. ]
  47. # Use tokenizer to generate proper chat format with special tokens
  48. text = tokenizer.apply_chat_template(
  49. messages,
  50. tokenize=False,
  51. add_generation_prompt=False,
  52. )
  53. # Strip empty <think> tags - we don't want thinking mode
  54. # The template adds "<think>\n\n</think>\n\n" which we remove
  55. text = text.replace("<think>\n\n</think>\n\n", "")
  56. return {
  57. "text": text,
  58. "messages": messages,
  59. }
  60. def main():
  61. parser = argparse.ArgumentParser(description="Prepare data for training")
  62. parser.add_argument(
  63. "--input",
  64. type=str,
  65. default="data/*.jsonl",
  66. help="Input JSONL file(s) - supports glob patterns",
  67. )
  68. parser.add_argument(
  69. "--output", type=str, default="data/train", help="Output directory"
  70. )
  71. parser.add_argument(
  72. "--split", type=float, default=0.1, help="Validation split ratio"
  73. )
  74. parser.add_argument(
  75. "--seed",
  76. type=int,
  77. default=42,
  78. help="Shuffle seed (default: 42)",
  79. )
  80. args = parser.parse_args()
  81. output_dir = Path(args.output)
  82. output_dir.mkdir(parents=True, exist_ok=True)
  83. # Support glob patterns for input
  84. import glob
  85. if "*" in args.input:
  86. input_files = sorted(glob.glob(args.input))
  87. if not input_files:
  88. print(f"Error: No files found matching: {args.input}")
  89. exit(1)
  90. print(
  91. f"Found {len(input_files)} input files: {[Path(f).name for f in input_files]}"
  92. )
  93. else:
  94. input_path = Path(args.input)
  95. if not input_path.exists():
  96. print(f"Error: Input file not found: {input_path}")
  97. exit(1)
  98. input_files = [str(input_path)]
  99. # Load all examples from all input files
  100. examples = []
  101. for input_file in input_files:
  102. file_count = 0
  103. with open(input_file) as f:
  104. for line in f:
  105. if line.strip():
  106. ex = json.loads(line)
  107. # Normalize legacy format
  108. if "query" not in ex and "input" in ex:
  109. ex["query"] = ex.pop("input")
  110. if isinstance(ex.get("output"), str):
  111. ex["output"] = parse_output_text(ex["output"])
  112. ex["output"] = normalize_output_items(ex.get("output", []))
  113. examples.append(ex)
  114. file_count += 1
  115. print(f" {Path(input_file).name}: {file_count} examples")
  116. print(f"Loaded {len(examples)} examples total")
  117. # Combine and shuffle
  118. all_examples = examples
  119. random.seed(args.seed)
  120. random.shuffle(all_examples)
  121. # Format for training
  122. formatted = [format_for_training(ex["query"], ex["output"]) for ex in all_examples]
  123. # Split into train/val
  124. split_idx = int(len(formatted) * (1 - args.split))
  125. train_data = formatted[:split_idx]
  126. val_data = formatted[split_idx:]
  127. # Write train set
  128. train_path = output_dir / "train.jsonl"
  129. with open(train_path, "w") as f:
  130. for item in train_data:
  131. f.write(json.dumps(item) + "\n")
  132. # Write validation set
  133. val_path = output_dir / "val.jsonl"
  134. with open(val_path, "w") as f:
  135. for item in val_data:
  136. f.write(json.dumps(item) + "\n")
  137. # Write chat format (for TRL)
  138. chat_path = output_dir / "train_chat.jsonl"
  139. with open(chat_path, "w") as f:
  140. for item in train_data:
  141. f.write(json.dumps({"messages": item["messages"]}) + "\n")
  142. # Stats
  143. short_final = sum(1 for ex in all_examples if len(ex["query"].split()) <= 2)
  144. print(f"\n=== Summary ===")
  145. print(f"Total examples: {len(all_examples)}")
  146. print(
  147. f"Short queries: {short_final} ({100 * short_final / len(all_examples):.1f}%)"
  148. )
  149. print(f"Train: {len(train_data)}, Val: {len(val_data)}")
  150. print(f"Output: {output_dir}")
  151. # Dataset info
  152. dataset_info = {
  153. "dataset_name": "qmd-query-expansion",
  154. "train_samples": len(train_data),
  155. "val_samples": len(val_data),
  156. "short_query_pct": round(100 * short_final / len(all_examples), 1),
  157. "columns": ["prompt", "completion", "text", "messages"],
  158. }
  159. with open(output_dir / "dataset_info.json", "w") as f:
  160. json.dump(dataset_info, f, indent=2)
  161. if __name__ == "__main__":
  162. main()