| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
---
# GRPO training config for QMD query expansion.
# Target: Qwen3-1.7B, trained on top of the merged SFT weights.
#
# Usage: uv run train.py grpo --config experiments/grpo/grpo.yaml
#
# The reward function (reward.py) scores expansions on format compliance,
# diversity, HyDE quality, content quality, and named-entity preservation.
# Keep grpo.beta > 0: the KL penalty is what keeps the policy from drifting
# away from the SFT checkpoint.

model:
  base: "Qwen/Qwen3-1.7B"
  sft: "outputs/sft"  # local SFT output (or an HF path if uploaded)
  output: "outputs/grpo"  # local output; push to HF manually after eval
  push_to_hub: false
  torch_dtype: "bfloat16"
  load_in_4bit: false
  load_in_8bit: false

dataset:
  # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/".
  # Hugging Face: use "tobil/qmd-query-expansion-train" (already prepared).
  name: "data/train/"
  prompt_field: "messages"
  max_samples: 1000

training:
  # NOTE(review): with HF Trainer semantics, max_steps > 0 overrides epochs;
  # presumably train.py passes both through that way - confirm.
  epochs: 1
  # NOTE(review): TRL's GRPOTrainer requires its generation batch to be
  # divisible by grpo.num_generations. Some TRL releases check only
  # batch_size x num_processes (2 here, not divisible by 4); newer ones also
  # fold in gradient_accumulation_steps (2 x 8 = 16). Confirm against the
  # pinned TRL version and how train.py maps these keys.
  batch_size: 2
  gradient_accumulation_steps: 8
  # 5e-7, written as a plain decimal on purpose: YAML 1.1 loaders such as
  # PyYAML parse a bare "5e-7" (no dot) as a string, not a float.
  learning_rate: 0.0000005
  max_grad_norm: 0.5
  max_steps: 200
  # Save a checkpoint every 30 minutes of wall-clock time.
  save_interval_minutes: 30
  # Step-based fallback save cadence; not used in wall-clock mode.
  save_steps: 50

grpo:
  num_generations: 4
  max_completion_length: 200
  beta: 0.04  # KL regularization; prevents drift from the SFT checkpoint

lora:
  rank: 4
  alpha: 8
  dropout: 0.05
  target_modules:
    - "q_proj"
    - "v_proj"

tracking:
  project: "qmd-query-expansion"
  run_name: "grpo-1.7B"
|