train_1.7B.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "trackio",
  9. # "datasets",
  10. # "bitsandbytes",
  11. # ]
  12. # ///
  13. import trackio
  14. from datasets import load_dataset
  15. from peft import LoraConfig
  16. from trl import SFTTrainer, SFTConfig
  17. # Load dataset from Hub
  18. print("Loading dataset...")
  19. dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
  20. print(f"Loaded {len(dataset)} examples")
  21. # Create train/eval split
  22. dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
  23. train_dataset = dataset_split["train"]
  24. eval_dataset = dataset_split["test"]
  25. print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
  26. # Training configuration
  27. config = SFTConfig(
  28. output_dir="qmd-query-expansion-1.7B",
  29. push_to_hub=True,
  30. hub_model_id="tobil/qmd-query-expansion-1.7B",
  31. hub_strategy="every_save",
  32. # Training parameters - slightly smaller batch for larger model
  33. num_train_epochs=3,
  34. per_device_train_batch_size=2,
  35. gradient_accumulation_steps=8,
  36. learning_rate=2e-4,
  37. max_length=512,
  38. # Logging & checkpointing
  39. logging_steps=25,
  40. save_strategy="steps",
  41. save_steps=200,
  42. save_total_limit=2,
  43. # Evaluation
  44. eval_strategy="steps",
  45. eval_steps=200,
  46. # Optimization
  47. warmup_ratio=0.1,
  48. lr_scheduler_type="cosine",
  49. bf16=True,
  50. gradient_checkpointing=True, # Save memory for larger model
  51. # Monitoring
  52. report_to="trackio",
  53. project="qmd-query-expansion",
  54. run_name="qwen3-1.7B-lora",
  55. )
  56. # LoRA configuration
  57. peft_config = LoraConfig(
  58. r=16,
  59. lora_alpha=32,
  60. lora_dropout=0.05,
  61. bias="none",
  62. task_type="CAUSAL_LM",
  63. target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
  64. )
  65. # Initialize trainer
  66. print("Initializing trainer with Qwen/Qwen3-1.7B...")
  67. trainer = SFTTrainer(
  68. model="Qwen/Qwen3-1.7B",
  69. train_dataset=train_dataset,
  70. eval_dataset=eval_dataset,
  71. args=config,
  72. peft_config=peft_config,
  73. )
  74. print("Starting training...")
  75. trainer.train()
  76. print("Pushing to Hub...")
  77. trainer.push_to_hub()
  78. trackio.finish()
  79. print("Done! Model at: https://huggingface.co/tobil/qmd-query-expansion-1.7B")