grpo.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "huggingface_hub>=0.20.0",
  9. # "datasets",
  10. # "bitsandbytes",
  11. # "torch",
  12. # ]
  13. # ///
  14. """
  15. GRPO training for QMD query expansion (Qwen3-1.7B).
  16. Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs:
  17. hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py
  18. """
  19. import os
  20. import sys
  21. import torch
  22. from datasets import load_dataset
  23. from huggingface_hub import login
  24. from peft import LoraConfig, PeftModel, get_peft_model
  25. from transformers import AutoModelForCausalLM, AutoTokenizer
  26. from trl import GRPOTrainer, GRPOConfig
  27. # Download eval_common.py if running as a standalone script (e.g. HF Jobs)
  28. _eval_common_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "eval_common.py")
  29. if not os.path.exists(_eval_common_path):
  30. import urllib.request
  31. _url = "https://huggingface.co/datasets/tobil/hf-cli-jobs-uv-run-scripts/resolve/main/eval_common.py"
  32. _opener = urllib.request.build_opener()
  33. _token = os.environ.get("HF_TOKEN", "")
  34. if _token:
  35. _opener.addheaders = [("Authorization", f"Bearer {_token}")]
  36. with open(_eval_common_path, "wb") as _f:
  37. _f.write(_opener.open(_url).read())
  38. sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
  39. from eval_common import QMDRewardFunction, run_eval
# --- Config (inlined from configs/grpo.yaml) ---
# Hub repo that main() loads the tokenizer and the base weights from.
BASE_MODEL = "Qwen/Qwen3-1.7B"
# SFT adapter that main() merges into the base model before GRPO training.
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
# Hub repo id the GRPO trainer pushes to (passed as `hub_model_id`).
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
# Hub dataset whose "train" split supplies the training prompts.
DATASET = "tobil/qmd-query-expansion-train"
  45. def main():
  46. hf_token = os.environ.get("HF_TOKEN")
  47. if hf_token:
  48. login(token=hf_token)
  49. print(f"Loading tokenizer from {BASE_MODEL}...")
  50. tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  51. if tokenizer.pad_token is None:
  52. tokenizer.pad_token = tokenizer.eos_token
  53. # Load and format dataset
  54. print(f"Loading dataset: {DATASET}...")
  55. dataset = load_dataset(DATASET, split="train")
  56. def extract_prompt(example):
  57. content = example["messages"][0]["content"]
  58. messages = [{"role": "user", "content": content}]
  59. formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  60. return {"prompt": formatted}
  61. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  62. dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
  63. print(f"Using {len(dataset)} prompts for GRPO")
  64. # Load base model, merge SFT adapter
  65. print(f"Loading base model {BASE_MODEL}...")
  66. base_model = AutoModelForCausalLM.from_pretrained(
  67. BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto",
  68. )
  69. print(f"Merging SFT adapter {SFT_MODEL}...")
  70. model = PeftModel.from_pretrained(base_model, SFT_MODEL)
  71. model = model.merge_and_unload()
  72. print("SFT adapter merged.")
  73. # Fresh LoRA for GRPO (small: rank 4, q/v only)
  74. grpo_lora = LoraConfig(
  75. r=4, lora_alpha=8, lora_dropout=0.05,
  76. bias="none", task_type="CAUSAL_LM",
  77. target_modules=["q_proj", "v_proj"],
  78. )
  79. model = get_peft_model(model, grpo_lora)
  80. model.print_trainable_parameters()
  81. config = GRPOConfig(
  82. output_dir="qmd-query-expansion-1.7B-grpo",
  83. push_to_hub=True,
  84. hub_model_id=OUTPUT_MODEL,
  85. num_generations=4,
  86. max_completion_length=200,
  87. beta=0.04, # KL regularization — prevents drift from SFT checkpoint
  88. num_train_epochs=1,
  89. per_device_train_batch_size=2,
  90. gradient_accumulation_steps=8,
  91. learning_rate=5e-7,
  92. max_grad_norm=0.5,
  93. max_steps=200,
  94. logging_steps=10,
  95. save_strategy="epoch",
  96. bf16=True,
  97. report_to="none",
  98. )
  99. print("Initializing GRPO trainer...")
  100. trainer = GRPOTrainer(
  101. model=model,
  102. processing_class=tokenizer,
  103. args=config,
  104. train_dataset=dataset,
  105. reward_funcs=[QMDRewardFunction()],
  106. )
  107. print("Starting GRPO training...")
  108. trainer.train()
  109. print("Pushing to Hub...")
  110. trainer.push_to_hub()
  111. print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")
  112. # --- Automatic evaluation ---
  113. print("\nStarting automatic evaluation...")
  114. trainer.model.eval()
  115. run_eval(trainer.model, tokenizer, "grpo")
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()