convert_to_chatml.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #!/usr/bin/env python3
  2. """
  3. Convert QMD expansion v3 JSONL to ChatML format for LFM2.5 training.
  4. """
  5. import json
  6. import random
  7. from pathlib import Path
  8. def reorder_hyde_first(output_items):
  9. """Reorder output items to put hyde first, then lex, then vec."""
  10. hyde_items = [item for item in output_items if item[0] == "hyde"]
  11. lex_items = [item for item in output_items if item[0] == "lex"]
  12. vec_items = [item for item in output_items if item[0] == "vec"]
  13. return hyde_items + lex_items + vec_items
  14. def convert_entry(entry):
  15. """Convert a single QMD entry to ChatML format."""
  16. query = entry["query"]
  17. output_items = entry["output"]
  18. # Reorder: hyde first, then lex, then vec
  19. output_items = reorder_hyde_first(output_items)
  20. # Build the assistant response
  21. assistant_lines = []
  22. for item_type, content in output_items:
  23. assistant_lines.append(f"{item_type}: {content}")
  24. assistant_response = "\n".join(assistant_lines)
  25. # Create ChatML formatted text
  26. chatml_text = (
  27. "<|startoftext|>"
  28. "<|im_start|>user\n"
  29. f"Expand this search query: {query}"
  30. "<|im_end|>\n"
  31. "<|im_start|>assistant\n"
  32. f"{assistant_response}"
  33. "<|im_end|>\n"
  34. )
  35. return {"text": chatml_text}
  36. def main():
  37. # Use paths relative to this script's location
  38. script_dir = Path(__file__).parent
  39. input_file = script_dir / "qmd_expansion_v3.jsonl"
  40. output_dir = script_dir / "train-lfm2"
  41. # Load all data
  42. print(f"Loading data from {input_file}...")
  43. all_entries = []
  44. with open(input_file, 'r', encoding='utf-8') as f:
  45. for line_num, line in enumerate(f, 1):
  46. try:
  47. entry = json.loads(line.strip())
  48. converted = convert_entry(entry)
  49. all_entries.append(converted)
  50. except json.JSONDecodeError as e:
  51. print(f"Warning: Skipping invalid JSON on line {line_num}: {e}")
  52. except Exception as e:
  53. print(f"Warning: Error processing line {line_num}: {e}")
  54. print(f"Successfully converted {len(all_entries)} entries")
  55. # Shuffle for better training
  56. random.seed(42) # For reproducibility
  57. random.shuffle(all_entries)
  58. # Split into train (90%) and validation (10%)
  59. split_idx = int(len(all_entries) * 0.9)
  60. train_entries = all_entries[:split_idx]
  61. val_entries = all_entries[split_idx:]
  62. print(f"Train set: {len(train_entries)} entries")
  63. print(f"Validation set: {len(val_entries)} entries")
  64. # Write train set
  65. train_file = output_dir / "train.jsonl"
  66. print(f"Writing train set to {train_file}...")
  67. with open(train_file, 'w', encoding='utf-8') as f:
  68. for entry in train_entries:
  69. f.write(json.dumps(entry, ensure_ascii=False) + '\n')
  70. # Write validation set
  71. val_file = output_dir / "val.jsonl"
  72. print(f"Writing validation set to {val_file}...")
  73. with open(val_file, 'w', encoding='utf-8') as f:
  74. for entry in val_entries:
  75. f.write(json.dumps(entry, ensure_ascii=False) + '\n')
  76. print("Conversion complete!")
  77. # Show some sample entries
  78. print("\nSample train entries:")
  79. for i, entry in enumerate(train_entries[:2]):
  80. print(f"\n--- Sample {i+1} ---")
  81. print(entry["text"][:300] + "..." if len(entry["text"]) > 300 else entry["text"])
  82. if __name__ == "__main__":
  83. main()