---
# GRPO Training Config for QMD Query Expansion
# Target: Qwen3-1.7B, trained on top of merged SFT weights
#
# Usage: uv run train.py grpo --config configs/grpo.yaml
#
# The reward function (reward.py) scores expansions on format compliance,
# diversity, hyde quality, content quality, and named entity preservation.
# beta > 0 is critical to prevent drift from the SFT checkpoint.

# Model locations: base checkpoint, SFT weights to start from, output dir.
model:
  base: "Qwen/Qwen3-1.7B"
  sft: "outputs/sft"  # Use local SFT output (or HF path if uploaded)
  output: "outputs/grpo"  # Local training output (push to HF manually after eval)

# Training data source.
dataset:
  # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/"
  # HuggingFace: use "tobil/qmd-query-expansion-train" (already prepared)
  name: "data/train/"
  prompt_field: "messages"
  max_samples: 1000

# Optimizer / schedule settings.
training:
  epochs: 1
  batch_size: 2
  gradient_accumulation_steps: 8
  # Keep the decimal point: YAML 1.1 loaders (e.g. PyYAML) resolve a bare
  # `5e-7` as the string "5e-7" rather than a float.
  learning_rate: 5.0e-7
  max_grad_norm: 0.5
  max_steps: 200

# GRPO-specific settings.
grpo:
  num_generations: 4
  max_completion_length: 200
  beta: 0.04  # KL regularization - prevents drift from SFT checkpoint

# LoRA adapter settings.
lora:
  rank: 4
  alpha: 8
  dropout: 0.05
  target_modules:
    - "q_proj"
    - "v_proj"

# Experiment tracking identifiers.
tracking:
  project: "qmd-query-expansion"
  run_name: "grpo-1.7B"