# GRPO training config for QMD query expansion.
# Target model: Qwen3-1.7B, trained on top of the merged SFT weights.
#
# Usage:
#   uv run train.py grpo --config configs/grpo.yaml
#
# The reward function (reward.py) scores each expansion on format
# compliance, diversity, hyde quality, content quality, and named entity
# preservation. Keep beta > 0: it is critical for preventing drift away
# from the SFT checkpoint.
#
# NOTE(review): the stanza nesting below was reconstructed from a
# whitespace-collapsed copy of this file - verify it against the schema
# that train.py expects.
---
model:
  base: "Qwen/Qwen3-1.7B"
  # Use the local SFT output (or an HF path if it was uploaded).
  sft: "outputs/sft"
  # Local training output; push to HF manually after eval.
  output: "outputs/grpo"
  push_to_hub: false
  torch_dtype: "bfloat16"
  load_in_4bit: false
  load_in_8bit: false

dataset:
  # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/".
  # HuggingFace: use "tobil/qmd-query-expansion-train" (already prepared).
  name: "data/train/"
  prompt_field: "messages"
  max_samples: 1000

training:
  # NOTE(review): both `epochs` and `max_steps` are set; presumably the
  # step cap wins if train.py forwards both to the trainer - confirm.
  epochs: 1
  batch_size: 2
  gradient_accumulation_steps: 8
  learning_rate: 0.0000005  # i.e. 5e-7
  max_grad_norm: 0.5
  max_steps: 200

grpo:
  # NOTE(review): some trainers require the effective batch size to be
  # divisible by num_generations - confirm against train.py before changing
  # batch_size, gradient_accumulation_steps, or this value.
  num_generations: 4
  max_completion_length: 200
  # KL regularization - prevents drift from the SFT checkpoint.
  beta: 0.04

lora:
  rank: 4
  alpha: 8
  dropout: 0.05
  target_modules:
    - "q_proj"
    - "v_proj"

tracking:
  project: "qmd-query-expansion"
  run_name: "grpo-1.7B"