train_1.7B_v2.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "trackio",
  9. # "datasets",
  10. # "bitsandbytes",
  11. # ]
  12. # ///
  13. """
  14. Improved Qwen3-1.7B training with best practices for larger models:
  15. - Lower learning rate (1e-4 instead of 2e-4)
  16. - Higher LoRA rank (32 instead of 16)
  17. - More epochs (5 instead of 3)
  18. - Weight decay for regularization
  19. """
  20. import trackio
  21. from datasets import load_dataset
  22. from peft import LoraConfig
  23. from trl import SFTTrainer, SFTConfig
  24. # Load dataset from Hub
  25. print("Loading dataset...")
  26. dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
  27. print(f"Loaded {len(dataset)} examples")
  28. # Create train/eval split
  29. dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
  30. train_dataset = dataset_split["train"]
  31. eval_dataset = dataset_split["test"]
  32. print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
  33. # Training configuration - optimized for larger model
  34. config = SFTConfig(
  35. output_dir="qmd-query-expansion-1.7B-v2",
  36. push_to_hub=True,
  37. hub_model_id="tobil/qmd-query-expansion-1.7B-v2",
  38. hub_strategy="every_save",
  39. # Training parameters - lower LR, more epochs for larger model
  40. num_train_epochs=5,
  41. per_device_train_batch_size=2,
  42. gradient_accumulation_steps=8,
  43. learning_rate=1e-4, # Lowered from 2e-4
  44. weight_decay=0.01, # Added regularization
  45. max_length=512,
  46. # Logging & checkpointing
  47. logging_steps=25,
  48. save_strategy="steps",
  49. save_steps=200,
  50. save_total_limit=3,
  51. # Evaluation
  52. eval_strategy="steps",
  53. eval_steps=200,
  54. # Optimization
  55. warmup_ratio=0.1,
  56. lr_scheduler_type="cosine",
  57. bf16=True,
  58. gradient_checkpointing=True,
  59. gradient_checkpointing_kwargs={"use_reentrant": False},
  60. # Monitoring
  61. report_to="trackio",
  62. project="qmd-query-expansion",
  63. run_name="qwen3-1.7B-lora-v2",
  64. )
  65. # LoRA configuration - higher rank for better learning
  66. peft_config = LoraConfig(
  67. r=32, # Increased from 16
  68. lora_alpha=64, # Increased from 32 (2x rank)
  69. lora_dropout=0.05,
  70. bias="none",
  71. task_type="CAUSAL_LM",
  72. target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
  73. )
  74. # Initialize trainer
  75. print("Initializing trainer with Qwen/Qwen3-1.7B...")
  76. trainer = SFTTrainer(
  77. model="Qwen/Qwen3-1.7B",
  78. train_dataset=train_dataset,
  79. eval_dataset=eval_dataset,
  80. args=config,
  81. peft_config=peft_config,
  82. )
  83. print("Starting training...")
  84. trainer.train()
  85. print("Pushing to Hub...")
  86. trainer.push_to_hub()
  87. trackio.finish()
  88. print("Done! Model at: https://huggingface.co/tobil/qmd-query-expansion-1.7B-v2")