# grpo.py

# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.45.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "datasets",
#     "bitsandbytes",
#     "torch",
# ]
# ///
"""
GRPO training for QMD query expansion (Qwen3-1.7B).

Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs:

    hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py
"""
  19. import os
  20. import sys
  21. import torch
  22. from datasets import load_dataset
  23. from huggingface_hub import login
  24. from peft import LoraConfig, PeftModel, get_peft_model
  25. from transformers import AutoModelForCausalLM, AutoTokenizer
  26. from trl import GRPOTrainer, GRPOConfig
  27. sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
  28. from eval_common import QMDRewardFunction, run_eval
  29. # --- Config (inlined from configs/grpo.yaml) ---
  30. BASE_MODEL = "Qwen/Qwen3-1.7B"
  31. SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
  32. OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
  33. DATASET = "tobil/qmd-query-expansion-train-v2"
  34. def main():
  35. hf_token = os.environ.get("HF_TOKEN")
  36. if hf_token:
  37. login(token=hf_token)
  38. print(f"Loading tokenizer from {BASE_MODEL}...")
  39. tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  40. if tokenizer.pad_token is None:
  41. tokenizer.pad_token = tokenizer.eos_token
  42. # Load and format dataset
  43. print(f"Loading dataset: {DATASET}...")
  44. dataset = load_dataset(DATASET, split="train")
  45. def extract_prompt(example):
  46. content = example["messages"][0]["content"]
  47. messages = [{"role": "user", "content": content}]
  48. formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  49. return {"prompt": formatted}
  50. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  51. dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
  52. print(f"Using {len(dataset)} prompts for GRPO")
  53. # Load base model, merge SFT adapter
  54. print(f"Loading base model {BASE_MODEL}...")
  55. base_model = AutoModelForCausalLM.from_pretrained(
  56. BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto",
  57. )
  58. print(f"Merging SFT adapter {SFT_MODEL}...")
  59. model = PeftModel.from_pretrained(base_model, SFT_MODEL)
  60. model = model.merge_and_unload()
  61. print("SFT adapter merged.")
  62. # Fresh LoRA for GRPO (small: rank 4, q/v only)
  63. grpo_lora = LoraConfig(
  64. r=4, lora_alpha=8, lora_dropout=0.05,
  65. bias="none", task_type="CAUSAL_LM",
  66. target_modules=["q_proj", "v_proj"],
  67. )
  68. model = get_peft_model(model, grpo_lora)
  69. model.print_trainable_parameters()
  70. config = GRPOConfig(
  71. output_dir="qmd-query-expansion-1.7B-grpo",
  72. push_to_hub=True,
  73. hub_model_id=OUTPUT_MODEL,
  74. num_generations=4,
  75. max_completion_length=200,
  76. beta=0.04, # KL regularization — prevents drift from SFT checkpoint
  77. num_train_epochs=1,
  78. per_device_train_batch_size=2,
  79. gradient_accumulation_steps=8,
  80. learning_rate=5e-7,
  81. max_grad_norm=0.5,
  82. max_steps=200,
  83. logging_steps=10,
  84. save_strategy="epoch",
  85. bf16=True,
  86. report_to="none",
  87. )
  88. print("Initializing GRPO trainer...")
  89. trainer = GRPOTrainer(
  90. model=model,
  91. processing_class=tokenizer,
  92. args=config,
  93. train_dataset=dataset,
  94. reward_funcs=[QMDRewardFunction()],
  95. )
  96. print("Starting GRPO training...")
  97. trainer.train()
  98. print("Pushing to Hub...")
  99. trainer.push_to_hub()
  100. print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")
  101. # --- Automatic evaluation ---
  102. print("\nStarting automatic evaluation...")
  103. trainer.model.eval()
  104. run_eval(trainer.model, tokenizer, "grpo")
  105. if __name__ == "__main__":
  106. main()