grpo.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "huggingface_hub>=0.20.0",
  9. # "datasets",
  10. # "bitsandbytes",
  11. # "torch",
  12. # ]
  13. # ///
  14. """
  15. GRPO training for QMD query expansion (Qwen3-1.7B).
  16. Experimental recipe run on top of merged SFT weights. Self-contained runner:
  17. uv run experiments/grpo/grpo.py
  18. (If using HF Jobs, run this script as the job entrypoint.)
  19. """
  20. import os
  21. import sys
  22. import torch
  23. from datasets import load_dataset
  24. from huggingface_hub import login
  25. from peft import LoraConfig, PeftModel, get_peft_model
  26. from transformers import AutoModelForCausalLM, AutoTokenizer
  27. from trl import GRPOTrainer, GRPOConfig
  28. # Download eval_common.py if running as a standalone script (e.g. HF Jobs)
  29. _eval_common_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "eval_common.py")
  30. if not os.path.exists(_eval_common_path):
  31. import urllib.request
  32. _url = "https://huggingface.co/datasets/tobil/hf-cli-jobs-uv-run-scripts/resolve/main/eval_common.py"
  33. _opener = urllib.request.build_opener()
  34. _token = os.environ.get("HF_TOKEN", "")
  35. if _token:
  36. _opener.addheaders = [("Authorization", f"Bearer {_token}")]
  37. with open(_eval_common_path, "wb") as _f:
  38. _f.write(_opener.open(_url).read())
  39. sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
  40. from eval_common import QMDRewardFunction, run_eval
# --- Config (inlined from experiments/grpo/grpo.yaml) ---
# Hub ID used to load both the tokenizer and the base model weights in main().
BASE_MODEL = "Qwen/Qwen3-1.7B"
# PEFT adapter from the SFT stage; main() loads it onto the base model and merges it.
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
# Hub repo the GRPO result is pushed to (passed to GRPOConfig as hub_model_id).
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
# Training dataset; main() uses the first entry of each example's "messages" as the prompt.
DATASET = "tobil/qmd-query-expansion-train"
  46. def main():
  47. hf_token = os.environ.get("HF_TOKEN")
  48. if hf_token:
  49. login(token=hf_token)
  50. print(f"Loading tokenizer from {BASE_MODEL}...")
  51. tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  52. if tokenizer.pad_token is None:
  53. tokenizer.pad_token = tokenizer.eos_token
  54. # Load and format dataset
  55. print(f"Loading dataset: {DATASET}...")
  56. dataset = load_dataset(DATASET, split="train")
  57. def extract_prompt(example):
  58. content = example["messages"][0]["content"]
  59. messages = [{"role": "user", "content": content}]
  60. formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  61. return {"prompt": formatted}
  62. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  63. dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
  64. print(f"Using {len(dataset)} prompts for GRPO")
  65. # Load base model, merge SFT adapter
  66. print(f"Loading base model {BASE_MODEL}...")
  67. base_model = AutoModelForCausalLM.from_pretrained(
  68. BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto",
  69. )
  70. print(f"Merging SFT adapter {SFT_MODEL}...")
  71. model = PeftModel.from_pretrained(base_model, SFT_MODEL)
  72. model = model.merge_and_unload()
  73. print("SFT adapter merged.")
  74. # Fresh LoRA for GRPO (small: rank 4, q/v only)
  75. grpo_lora = LoraConfig(
  76. r=4, lora_alpha=8, lora_dropout=0.05,
  77. bias="none", task_type="CAUSAL_LM",
  78. target_modules=["q_proj", "v_proj"],
  79. )
  80. model = get_peft_model(model, grpo_lora)
  81. model.print_trainable_parameters()
  82. config = GRPOConfig(
  83. output_dir="qmd-query-expansion-1.7B-grpo",
  84. push_to_hub=True,
  85. hub_model_id=OUTPUT_MODEL,
  86. num_generations=4,
  87. max_completion_length=200,
  88. beta=0.04, # KL regularization — prevents drift from SFT checkpoint
  89. num_train_epochs=1,
  90. per_device_train_batch_size=2,
  91. gradient_accumulation_steps=8,
  92. learning_rate=5e-7,
  93. max_grad_norm=0.5,
  94. max_steps=200,
  95. logging_steps=10,
  96. save_strategy="epoch",
  97. bf16=True,
  98. report_to="none",
  99. )
  100. print("Initializing GRPO trainer...")
  101. trainer = GRPOTrainer(
  102. model=model,
  103. processing_class=tokenizer,
  104. args=config,
  105. train_dataset=dataset,
  106. reward_funcs=[QMDRewardFunction()],
  107. )
  108. print("Starting GRPO training...")
  109. trainer.train()
  110. print("Pushing to Hub...")
  111. trainer.push_to_hub()
  112. print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")
  113. # --- Automatic evaluation ---
  114. print("\nStarting automatic evaluation...")
  115. trainer.model.eval()
  116. run_eval(trainer.model, tokenizer, "grpo")
  117. if __name__ == "__main__":
  118. main()