sft.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.45.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "datasets",
#     "bitsandbytes",
#     "torch",
# ]
# ///
  14. """
  15. SFT training for QMD query expansion (Qwen3-1.7B).
  16. Self-contained script for HuggingFace Jobs:
  17. hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 2h jobs/sft.py
  18. """
  19. import os
  20. import sys
  21. from huggingface_hub import login
  22. # --- Config (inlined from configs/sft.yaml) ---
  23. BASE_MODEL = "Qwen/Qwen3-1.7B"
  24. OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
  25. DATASET = "tobil/qmd-query-expansion-train-v2"
  26. hf_token = os.environ.get("HF_TOKEN")
  27. if hf_token:
  28. login(token=hf_token)
  29. from datasets import load_dataset
  30. from peft import LoraConfig
  31. from transformers import AutoTokenizer
  32. from trl import SFTTrainer, SFTConfig
  33. # Load and split dataset
  34. print(f"Loading dataset: {DATASET}...")
  35. dataset = load_dataset(DATASET, split="train")
  36. print(f"Dataset loaded: {len(dataset)} examples")
  37. split = dataset.train_test_split(test_size=0.1, seed=42)
  38. train_dataset = split["train"]
  39. eval_dataset = split["test"]
  40. print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
  41. # SFT config
  42. config = SFTConfig(
  43. output_dir="qmd-query-expansion-1.7B-sft",
  44. push_to_hub=True,
  45. hub_model_id=OUTPUT_MODEL,
  46. hub_strategy="every_save",
  47. num_train_epochs=5,
  48. per_device_train_batch_size=4,
  49. gradient_accumulation_steps=4,
  50. learning_rate=2e-4,
  51. max_length=512,
  52. logging_steps=10,
  53. save_strategy="steps",
  54. save_steps=200,
  55. save_total_limit=2,
  56. eval_strategy="steps",
  57. eval_steps=200,
  58. warmup_ratio=0.03,
  59. lr_scheduler_type="cosine",
  60. bf16=True,
  61. report_to="none",
  62. )
  63. # LoRA: rank 16, all projection layers
  64. peft_config = LoraConfig(
  65. r=16,
  66. lora_alpha=32,
  67. lora_dropout=0.0,
  68. bias="none",
  69. task_type="CAUSAL_LM",
  70. target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
  71. )
  72. print("Initializing SFT trainer...")
  73. trainer = SFTTrainer(
  74. model=BASE_MODEL,
  75. train_dataset=train_dataset,
  76. eval_dataset=eval_dataset,
  77. args=config,
  78. peft_config=peft_config,
  79. )
  80. print("Starting SFT training...")
  81. trainer.train()
  82. print("Pushing to Hub...")
  83. trainer.push_to_hub()
  84. print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")
  85. # --- Automatic evaluation ---
  86. sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
  87. from eval_common import run_eval
  88. print("\nStarting automatic evaluation...")
  89. eval_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  90. if eval_tokenizer.pad_token is None:
  91. eval_tokenizer.pad_token = eval_tokenizer.eos_token
  92. trainer.model.eval()
  93. run_eval(trainer.model, eval_tokenizer, "sft")