train.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "datasets>=2.14.0",
  9. # "trackio",
  10. # "pyyaml",
  11. # ]
  12. # ///
  13. """
  14. SFT Training for QMD Query Expansion.
  15. Usage:
  16. uv run train.py --config configs/sft_v4.yaml
  17. uv run train.py --config configs/sft_v4.yaml --dry-run
  18. """
  19. import argparse
  20. import yaml
  21. import trackio
  22. from datasets import load_dataset
  23. from peft import LoraConfig
  24. from trl import SFTTrainer, SFTConfig
  25. def main():
  26. parser = argparse.ArgumentParser(description="Train QMD query expansion model")
  27. parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
  28. parser.add_argument("--dry-run", action="store_true", help="Print config and exit")
  29. args = parser.parse_args()
  30. # Load config
  31. with open(args.config) as f:
  32. cfg = yaml.safe_load(f)
  33. if args.dry_run:
  34. print("SFT Training Configuration:")
  35. print(yaml.dump(cfg, default_flow_style=False))
  36. return
  37. print(f"Loading dataset: {cfg['dataset']['name']}...")
  38. dataset = load_dataset(cfg["dataset"]["name"], split=cfg["dataset"]["split"])
  39. print(f"Dataset loaded: {len(dataset)} examples")
  40. # Create train/eval split
  41. print("Creating train/eval split...")
  42. split = dataset.train_test_split(test_size=cfg["dataset"]["eval_split"], seed=42)
  43. train_dataset = split["train"]
  44. eval_dataset = split["test"]
  45. print(f" Train: {len(train_dataset)} examples")
  46. print(f" Eval: {len(eval_dataset)} examples")
  47. # Training configuration
  48. config = SFTConfig(
  49. output_dir=cfg["model"]["output"].split("/")[-1],
  50. push_to_hub=True,
  51. hub_model_id=cfg["model"]["output"],
  52. hub_strategy="every_save",
  53. num_train_epochs=cfg["training"]["epochs"],
  54. per_device_train_batch_size=cfg["training"]["batch_size"],
  55. gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
  56. learning_rate=cfg["training"]["learning_rate"],
  57. max_length=cfg["training"]["max_length"],
  58. logging_steps=10,
  59. save_strategy="steps",
  60. save_steps=200,
  61. save_total_limit=2,
  62. eval_strategy="steps",
  63. eval_steps=200,
  64. warmup_ratio=cfg["training"]["warmup_ratio"],
  65. lr_scheduler_type=cfg["training"]["lr_scheduler"],
  66. report_to="trackio",
  67. project=cfg["tracking"]["project"],
  68. run_name=cfg["tracking"]["run_name"],
  69. )
  70. # LoRA configuration
  71. peft_config = LoraConfig(
  72. r=cfg["lora"]["rank"],
  73. lora_alpha=cfg["lora"]["alpha"],
  74. lora_dropout=cfg["lora"]["dropout"],
  75. bias="none",
  76. task_type="CAUSAL_LM",
  77. target_modules=cfg["lora"]["target_modules"],
  78. )
  79. # Initialize and train
  80. print("Initializing trainer...")
  81. trainer = SFTTrainer(
  82. model=cfg["model"]["base"],
  83. train_dataset=train_dataset,
  84. eval_dataset=eval_dataset,
  85. args=config,
  86. peft_config=peft_config,
  87. )
  88. print("Starting training...")
  89. trainer.train()
  90. print("Pushing to Hub...")
  91. trainer.push_to_hub()
  92. trackio.finish()
  93. print(f"Complete! Model at: https://huggingface.co/{cfg['model']['output']}")
  94. if __name__ == "__main__":
  95. main()