| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- # "trl>=0.12.0",
- # "peft>=0.7.0",
- # "transformers>=4.45.0",
- # "accelerate>=0.24.0",
- # "huggingface_hub>=0.20.0",
- # "datasets",
- # "bitsandbytes",
- # "torch",
- # ]
- # ///
- """
- GRPO training for QMD query expansion (Qwen3-1.7B).
- Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs:
- hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py
- """
- import os
- import sys
- import torch
- from datasets import load_dataset
- from huggingface_hub import login
- from peft import LoraConfig, PeftModel, get_peft_model
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from trl import GRPOTrainer, GRPOConfig
# Download eval_common.py if running as a standalone script (e.g. HF Jobs),
# where only this file is shipped and the sibling helper module is missing.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_eval_common_path = os.path.join(_script_dir, "eval_common.py")
if not os.path.exists(_eval_common_path):
    import urllib.request

    _url = "https://huggingface.co/datasets/tobil/hf-cli-jobs-uv-run-scripts/resolve/main/eval_common.py"
    _opener = urllib.request.build_opener()
    _token = os.environ.get("HF_TOKEN", "")
    if _token:
        _opener.addheaders = [("Authorization", f"Bearer {_token}")]

    # Fetch the whole payload BEFORE touching the destination: if the request
    # fails, no empty/partial eval_common.py is left behind to satisfy the
    # os.path.exists() check on the next run and then break the import below.
    # The timeout keeps a stalled connection from hanging the job silently.
    with _opener.open(_url, timeout=60) as _resp:
        _payload = _resp.read()

    # Write to a temporary sibling and rename it into place so an interrupted
    # write can never leave a truncated module at the final path.
    _tmp_path = _eval_common_path + ".tmp"
    with open(_tmp_path, "wb") as _f:
        _f.write(_payload)
    os.replace(_tmp_path, _eval_common_path)

# Make the script directory importable so the (possibly just downloaded)
# helper module resolves regardless of the current working directory.
sys.path.insert(0, _script_dir)
from eval_common import QMDRewardFunction, run_eval

# --- Config (inlined from configs/grpo.yaml) ---
# Tokenizer and base weights are loaded from BASE_MODEL.
BASE_MODEL = "Qwen/Qwen3-1.7B"
# PEFT adapter that is merged into the base weights before GRPO starts.
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
# Hub repository the GRPO trainer pushes to.
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
# Hub dataset whose "train" split supplies the prompts.
DATASET = "tobil/qmd-query-expansion-train"
def _load_prompt_dataset(tokenizer):
    """Load DATASET and turn it into a shuffled set of at most 1000 prompts.

    Each example's first message content is wrapped as a single user turn and
    rendered through the tokenizer's chat template (with the generation prompt
    appended). All original columns are dropped; only "prompt" remains.
    """
    print(f"Loading dataset: {DATASET}...")
    dataset = load_dataset(DATASET, split="train")

    def extract_prompt(example):
        # Only the first message is used as the prompt source.
        content = example["messages"][0]["content"]
        messages = [{"role": "user", "content": content}]
        formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return {"prompt": formatted}

    dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
    # Fixed seed keeps the selected subset reproducible across runs.
    dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
    print(f"Using {len(dataset)} prompts for GRPO")
    return dataset


def _load_policy_model():
    """Load BASE_MODEL, merge the SFT adapter into it, and attach a fresh LoRA."""
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto",
    )

    print(f"Merging SFT adapter {SFT_MODEL}...")
    merged = PeftModel.from_pretrained(base_model, SFT_MODEL).merge_and_unload()
    print("SFT adapter merged.")

    # Fresh LoRA for GRPO (small: rank 4, q/v only)
    grpo_lora = LoraConfig(
        r=4, lora_alpha=8, lora_dropout=0.05,
        bias="none", task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(merged, grpo_lora)
    model.print_trainable_parameters()
    return model


def _build_grpo_config():
    """Return the GRPOConfig used for this run (values inlined from the YAML config)."""
    return GRPOConfig(
        output_dir="qmd-query-expansion-1.7B-grpo",
        push_to_hub=True,
        hub_model_id=OUTPUT_MODEL,
        num_generations=4,
        max_completion_length=200,
        beta=0.04,  # KL regularization — prevents drift from SFT checkpoint
        # NOTE: per the transformers TrainingArguments docs, a positive
        # max_steps takes precedence over num_train_epochs, so the run length
        # is governed by max_steps below.
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-7,
        max_grad_norm=0.5,
        max_steps=200,
        logging_steps=10,
        save_strategy="epoch",
        bf16=True,
        report_to="none",
    )


def main():
    """Run GRPO on top of the merged SFT weights, push the result, then evaluate.

    Steps, in order: optional Hub login from HF_TOKEN, tokenizer load, prompt
    dataset preparation, model construction, GRPO training, push to the Hub,
    and a final call to run_eval on the trained model.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    print(f"Loading tokenizer from {BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        # Fall back to EOS so a pad token is always defined.
        tokenizer.pad_token = tokenizer.eos_token

    # Load and format dataset
    dataset = _load_prompt_dataset(tokenizer)

    # Load base model, merge SFT adapter, wrap with the GRPO LoRA
    model = _load_policy_model()

    config = _build_grpo_config()

    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=[QMDRewardFunction()],
    )

    print("Starting GRPO training...")
    trainer.train()

    # NOTE(review): the trained model is a PEFT wrapper around the merged SFT
    # weights, so presumably only the GRPO adapter is uploaded here — confirm
    # that consumers apply it on top of the merged SFT model, not BASE_MODEL.
    print("Pushing to Hub...")
    trainer.push_to_hub()
    print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")

    # --- Automatic evaluation ---
    print("\nStarting automatic evaluation...")
    trainer.model.eval()
    run_eval(trainer.model, tokenizer, "grpo")


if __name__ == "__main__":
    main()
|