# GRPO Training Config for QMD Query Expansion
# Target: Qwen3-1.7B, trained on top of merged SFT weights
#
# Usage: uv run train.py grpo --config configs/grpo.yaml
#
# The reward function (reward.py) scores expansions on format compliance,
# diversity, hyde quality, content quality, and named entity preservation.
# beta > 0 is critical to prevent drift from the SFT checkpoint.

# Base model, SFT starting point, and where GRPO output is written.
model:
  base: "Qwen/Qwen3-1.7B"
  sft: "outputs/sft"  # Use local SFT output (or HF path if uploaded)
  output: "outputs/grpo"  # Local training output (push to HF manually after eval)

# Prompt dataset and the number of samples drawn from it.
dataset:
  name: "tobil/qmd-query-expansion-train-v2"
  prompt_field: "messages"
  max_samples: 1000

# Optimizer / schedule settings.
# NOTE(review): both `epochs` and `max_steps` are set; which one bounds the
# run depends on how train.py forwards them - confirm.
training:
  epochs: 1
  batch_size: 2
  gradient_accumulation_steps: 8
  # Written with a decimal point so YAML 1.1 loaders (e.g. PyYAML) resolve it
  # as a float; a bare `5e-7` is loaded as the string "5e-7" by those parsers.
  learning_rate: 5.0e-7
  max_grad_norm: 0.5
  max_steps: 200

# GRPO-specific settings.
# NOTE(review): nesting was reconstructed from a flattened copy; confirm that
# train.py expects `grpo` and `lora` as top-level sections.
grpo:
  num_generations: 4
  max_completion_length: 200
  beta: 0.04  # KL regularization - prevents drift from SFT checkpoint

# LoRA adapter settings.
lora:
  rank: 4
  alpha: 8
  dropout: 0.05
  target_modules:
    - "q_proj"
    - "v_proj"

# Experiment tracking identifiers.
tracking:
  project: "qmd-query-expansion"
  run_name: "grpo-1.7B"