# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.14.0",
#     "peft>=0.7.0",
#     "transformers>=4.45.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "datasets",
#     "bitsandbytes",
#     "torch",
# ]
# ///
"""
GRPO training for QMD query expansion (Qwen3-1.7B).

Experimental recipe run on top of merged SFT weights: the SFT LoRA adapter is
merged into the base model, a fresh small LoRA adapter is attached, and that
adapter is trained with GRPO against ``QMDRewardFunction``, pushed to the Hub,
and evaluated with ``run_eval``.

Self-contained runner:
    uv run experiments/grpo/grpo.py

(If using HF Jobs, run this script as the job entrypoint.)
"""

import os
import sys

import torch
from datasets import load_dataset
from huggingface_hub import login
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_EVAL_COMMON_PATH = os.path.join(_SCRIPT_DIR, "eval_common.py")
_EVAL_COMMON_URL = (
    "https://huggingface.co/datasets/tobil/hf-cli-jobs-uv-run-scripts"
    "/resolve/main/eval_common.py"
)


def _ensure_eval_common(dest_path: str, url: str, timeout: float = 60.0) -> None:
    """Download ``eval_common.py`` to *dest_path* unless it already exists.

    Needed when this script runs standalone (e.g. on HF Jobs) without the rest
    of the repository. When ``HF_TOKEN`` is set it is sent as a bearer token.

    The payload is fetched completely before anything is written and then moved
    into place with ``os.replace``. A failed or interrupted download therefore
    never leaves an empty/truncated ``eval_common.py`` behind, which later runs
    would otherwise trust (they only check that the file exists).

    Raises:
        urllib.error.URLError / OSError: if the download or write fails.
    """
    if os.path.exists(dest_path):
        return

    import urllib.request

    opener = urllib.request.build_opener()
    token = os.environ.get("HF_TOKEN", "")
    if token:
        opener.addheaders.append(("Authorization", f"Bearer {token}"))

    # Read the whole response first; the `with` closes the connection.
    with opener.open(url, timeout=timeout) as response:
        payload = response.read()

    tmp_path = f"{dest_path}.tmp"
    with open(tmp_path, "wb") as fh:
        fh.write(payload)
    os.replace(tmp_path, dest_path)


# Download eval_common.py if running as a standalone script (e.g. HF Jobs).
_ensure_eval_common(_EVAL_COMMON_PATH, _EVAL_COMMON_URL)
sys.path.insert(0, _SCRIPT_DIR)
from eval_common import QMDRewardFunction, run_eval  # noqa: E402

# --- Config (inlined from experiments/grpo/grpo.yaml) ---
BASE_MODEL = "Qwen/Qwen3-1.7B"
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
DATASET = "tobil/qmd-query-expansion-train"

# Prompt subsampling: at most MAX_PROMPTS rows after a shuffle with this seed.
MAX_PROMPTS = 1000
DATASET_SEED = 42


def _load_tokenizer():
    """Load the BASE_MODEL tokenizer, falling back to EOS as the pad token."""
    print(f"Loading tokenizer from {BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        # Padding batches requires a pad token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def _load_prompt_dataset(tokenizer):
    """Return up to MAX_PROMPTS chat-formatted prompts in a single ``prompt`` column."""
    print(f"Loading dataset: {DATASET}...")
    dataset = load_dataset(DATASET, split="train")

    def extract_prompt(example):
        # NOTE(review): assumes messages[0] is the user query (no leading system
        # turn) — confirm against the dataset schema.
        content = example["messages"][0]["content"]
        messages = [{"role": "user", "content": content}]
        formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return {"prompt": formatted}

    # Drop every original column so the trainer only sees the formatted prompt.
    dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
    dataset = dataset.shuffle(seed=DATASET_SEED).select(
        range(min(MAX_PROMPTS, len(dataset)))
    )
    print(f"Using {len(dataset)} prompts for GRPO")
    return dataset


def _load_policy_model():
    """Merge the SFT adapter into BASE_MODEL and attach a fresh GRPO LoRA adapter."""
    print(f"Loading base model {BASE_MODEL}...")
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    print(f"Merging SFT adapter {SFT_MODEL}...")
    model = PeftModel.from_pretrained(base_model, SFT_MODEL)
    model = model.merge_and_unload()
    print("SFT adapter merged.")

    # Fresh LoRA for GRPO (small: rank 4, q/v only). get_peft_model freezes the
    # merged weights, so only this adapter is trained.
    grpo_lora = LoraConfig(
        r=4,
        lora_alpha=8,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    model = get_peft_model(model, grpo_lora)
    model.print_trainable_parameters()
    return model


def _build_grpo_config():
    """Return the GRPO hyperparameters for this recipe."""
    return GRPOConfig(
        output_dir="qmd-query-expansion-1.7B-grpo",
        push_to_hub=True,
        hub_model_id=OUTPUT_MODEL,
        num_generations=4,  # completions sampled per prompt (the GRPO group)
        max_completion_length=200,
        beta=0.04,  # KL regularization — prevents drift from SFT checkpoint
        # A positive max_steps overrides num_train_epochs, so training stops
        # after 200 optimizer steps regardless of dataset size.
        num_train_epochs=1,
        max_steps=200,
        # NOTE(review): some TRL releases require
        # num_processes * per_device_train_batch_size to be divisible by
        # num_generations (2 vs 4 would fail); newer releases also count
        # gradient_accumulation_steps. Confirm the resolved TRL accepts this.
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        # NOTE(review): 5e-7 is far below typical LoRA learning rates, and
        # LoRA B starts at zero, so the adapter may barely move in 200 steps —
        # confirm this is intentional.
        learning_rate=5e-7,
        max_grad_norm=0.5,
        logging_steps=10,
        save_strategy="epoch",
        bf16=True,
        report_to="none",
    )


def main():
    """Train a GRPO LoRA on the merged SFT model, push it to the Hub, then evaluate it."""
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    tokenizer = _load_tokenizer()
    dataset = _load_prompt_dataset(tokenizer)
    model = _load_policy_model()
    config = _build_grpo_config()

    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=[QMDRewardFunction()],
    )

    print("Starting GRPO training...")
    trainer.train()

    # NOTE(review): the model is a PeftModel, so the Hub repo presumably receives
    # only the GRPO LoRA adapter. That adapter was trained on BASE_MODEL with
    # SFT_MODEL already merged in; applying it directly to BASE_MODEL would skip
    # the SFT weights. Confirm consumers merge SFT first, or push merged weights.
    print("Pushing to Hub...")
    trainer.push_to_hub()
    print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")

    # --- Automatic evaluation ---
    print("\nStarting automatic evaluation...")
    trainer.model.eval()
    run_eval(trainer.model, tokenizer, "grpo")


if __name__ == "__main__":
    main()