# train_0.6B.py
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "trackio",
  9. # "datasets",
  10. # "bitsandbytes",
  11. # ]
  12. # ///
  13. import trackio
  14. from datasets import load_dataset
  15. from peft import LoraConfig
  16. from trl import SFTTrainer, SFTConfig
  17. # Load dataset from Hub
  18. print("Loading dataset...")
  19. dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
  20. print(f"Loaded {len(dataset)} examples")
  21. # Create train/eval split
  22. dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
  23. train_dataset = dataset_split["train"]
  24. eval_dataset = dataset_split["test"]
  25. print(f"Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
  26. # Training configuration
  27. config = SFTConfig(
  28. output_dir="qmd-query-expansion-0.6B",
  29. push_to_hub=True,
  30. hub_model_id="tobil/qmd-query-expansion-0.6B",
  31. hub_strategy="every_save",
  32. # Training parameters
  33. num_train_epochs=3,
  34. per_device_train_batch_size=4,
  35. gradient_accumulation_steps=4,
  36. learning_rate=2e-4,
  37. max_length=512,
  38. # Logging & checkpointing
  39. logging_steps=25,
  40. save_strategy="steps",
  41. save_steps=200,
  42. save_total_limit=2,
  43. # Evaluation
  44. eval_strategy="steps",
  45. eval_steps=200,
  46. # Optimization
  47. warmup_ratio=0.1,
  48. lr_scheduler_type="cosine",
  49. bf16=True,
  50. # Monitoring
  51. report_to="trackio",
  52. project="qmd-query-expansion",
  53. run_name="qwen3-0.6B-lora",
  54. )
  55. # LoRA configuration
  56. peft_config = LoraConfig(
  57. r=16,
  58. lora_alpha=32,
  59. lora_dropout=0.05,
  60. bias="none",
  61. task_type="CAUSAL_LM",
  62. target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
  63. )
  64. # Initialize trainer
  65. print("Initializing trainer with Qwen/Qwen3-0.6B...")
  66. trainer = SFTTrainer(
  67. model="Qwen/Qwen3-0.6B",
  68. train_dataset=train_dataset,
  69. eval_dataset=eval_dataset,
  70. args=config,
  71. peft_config=peft_config,
  72. )
  73. print("Starting training...")
  74. trainer.train()
  75. print("Pushing to Hub...")
  76. trainer.push_to_hub()
  77. trackio.finish()
  78. print("Done! Model at: https://huggingface.co/tobil/qmd-query-expansion-0.6B")