#!/usr/bin/env python3
"""
Convert QMD expansion v3 JSONL to ChatML format for LFM2.5 training.
"""
import json
import random
from pathlib import Path


  8. def convert_entry(entry):
  9. """Convert a single QMD entry to ChatML format."""
  10. query = entry["query"]
  11. output_items = entry["output"]
  12. # Build the assistant response
  13. assistant_lines = []
  14. for item_type, content in output_items:
  15. assistant_lines.append(f"{item_type}: {content}")
  16. assistant_response = "\n".join(assistant_lines)
  17. # Create ChatML formatted text
  18. chatml_text = (
  19. "<|startoftext|>"
  20. "<|im_start|>user\n"
  21. f"Expand this search query: {query}"
  22. "<|im_end|>\n"
  23. "<|im_start|>assistant\n"
  24. f"{assistant_response}"
  25. "<|im_end|>\n"
  26. )
  27. return {"text": chatml_text}
  28. def main():
  29. input_file = Path("~/src/github.com/tobi/qmd/finetune/data/qmd_expansion_v3.jsonl").expanduser()
  30. output_dir = Path("~/src/github.com/tobi/qmd/finetune/data/train-lfm2").expanduser()
  31. # Load all data
  32. print(f"Loading data from {input_file}...")
  33. all_entries = []
  34. with open(input_file, 'r', encoding='utf-8') as f:
  35. for line_num, line in enumerate(f, 1):
  36. try:
  37. entry = json.loads(line.strip())
  38. converted = convert_entry(entry)
  39. all_entries.append(converted)
  40. except json.JSONDecodeError as e:
  41. print(f"Warning: Skipping invalid JSON on line {line_num}: {e}")
  42. except Exception as e:
  43. print(f"Warning: Error processing line {line_num}: {e}")
  44. print(f"Successfully converted {len(all_entries)} entries")
  45. # Shuffle for better training
  46. random.seed(42) # For reproducibility
  47. random.shuffle(all_entries)
  48. # Split into train (90%) and validation (10%)
  49. split_idx = int(len(all_entries) * 0.9)
  50. train_entries = all_entries[:split_idx]
  51. val_entries = all_entries[split_idx:]
  52. print(f"Train set: {len(train_entries)} entries")
  53. print(f"Validation set: {len(val_entries)} entries")
  54. # Write train set
  55. train_file = output_dir / "train.jsonl"
  56. print(f"Writing train set to {train_file}...")
  57. with open(train_file, 'w', encoding='utf-8') as f:
  58. for entry in train_entries:
  59. f.write(json.dumps(entry, ensure_ascii=False) + '\n')
  60. # Write validation set
  61. val_file = output_dir / "val.jsonl"
  62. print(f"Writing validation set to {val_file}...")
  63. with open(val_file, 'w', encoding='utf-8') as f:
  64. for entry in val_entries:
  65. f.write(json.dumps(entry, ensure_ascii=False) + '\n')
  66. print("Conversion complete!")
  67. # Show some sample entries
  68. print("\nSample train entries:")
  69. for i, entry in enumerate(train_entries[:2]):
  70. print(f"\n--- Sample {i+1} ---")
  71. print(entry["text"][:300] + "..." if len(entry["text"]) > 300 else entry["text"])
  72. if __name__ == "__main__":
  73. main()