# prepare_v4_dataset.py (2.8 KB)
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = []
  4. # ///
  5. """Prepare v4 dataset: high-quality expansions + /only: variants."""
  6. import json
  7. import random
  8. from pathlib import Path
  9. def to_chat_format(query: str, output: str) -> dict:
  10. """Convert input/output to chat format with /no_think."""
  11. # For /only: queries, keep the suffix in the prompt
  12. prompt = f"/no_think Expand this search query: {query}"
  13. text = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n{output}<|im_end|>\n"
  14. messages = [
  15. {"role": "user", "content": prompt},
  16. {"role": "assistant", "content": output}
  17. ]
  18. return {"text": text, "messages": messages}
  19. def load_jsonl(path: Path) -> list[dict]:
  20. """Load JSONL file."""
  21. data = []
  22. with open(path) as f:
  23. for line in f:
  24. line = line.strip()
  25. if line:
  26. data.append(json.loads(line))
  27. return data
  28. def main():
  29. data_dir = Path("data")
  30. # High-quality sources
  31. sources = [
  32. ("qmd_expansion_v2.jsonl", "v2"),
  33. ("qmd_expansion_handcrafted.jsonl", "handcrafted"),
  34. ("qmd_only_variants.jsonl", "only"),
  35. ]
  36. all_examples = []
  37. stats = {}
  38. for filename, label in sources:
  39. path = data_dir / filename
  40. if not path.exists():
  41. print(f" Skipping {filename} (not found)")
  42. continue
  43. raw = load_jsonl(path)
  44. converted = []
  45. for item in raw:
  46. query = item.get("input", "")
  47. output = item.get("output", "")
  48. if query and output:
  49. converted.append(to_chat_format(query, output))
  50. all_examples.extend(converted)
  51. stats[label] = len(converted)
  52. print(f" {label}: {len(converted)} examples")
  53. # Shuffle
  54. random.seed(42)
  55. random.shuffle(all_examples)
  56. # Split 90/10
  57. split_idx = int(len(all_examples) * 0.9)
  58. train = all_examples[:split_idx]
  59. val = all_examples[split_idx:]
  60. # Write output
  61. out_dir = data_dir / "train_v4"
  62. out_dir.mkdir(exist_ok=True)
  63. with open(out_dir / "train.jsonl", "w") as f:
  64. for ex in train:
  65. f.write(json.dumps(ex) + "\n")
  66. with open(out_dir / "val.jsonl", "w") as f:
  67. for ex in val:
  68. f.write(json.dumps(ex) + "\n")
  69. # Dataset info
  70. info = {
  71. "dataset_name": "qmd-query-expansion-v4",
  72. "train_samples": len(train),
  73. "val_samples": len(val),
  74. "sources": stats,
  75. }
  76. with open(out_dir / "dataset_info.json", "w") as f:
  77. json.dump(info, f, indent=2)
  78. print(f"\n✓ Dataset prepared in {out_dir}/")
  79. print(f" Train: {len(train)}")
  80. print(f" Val: {len(val)}")
  81. print(f" Total: {len(all_examples)}")
  82. if __name__ == "__main__":
  83. main()