  # GRPO training config for QMD query expansion (configs/grpo.yaml).
  # Target: Qwen3-1.7B, trained on top of the merged SFT weights.
  #
  # Usage: uv run train.py grpo --config configs/grpo.yaml
  #
  # The reward function (reward.py) scores expansions on format compliance,
  # diversity, HyDE quality, content quality, and named-entity preservation.
  # beta > 0 is critical to prevent drift from the SFT checkpoint.
  model:
    # Base model; the SFT weights below were trained on top of it.
    base: "Qwen/Qwen3-1.7B"
    # Use the local SFT output (or an HF Hub path if uploaded).
    sft: "outputs/sft"
    # Local training output; push to the Hub manually after eval.
    output: "outputs/grpo"
    push_to_hub: false
    torch_dtype: "bfloat16"
    # Both quantization flags are off, so the model presumably loads
    # unquantized in torch_dtype — confirm in train.py.
    load_in_4bit: false
    load_in_8bit: false
  dataset:
    # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/".
    # Hugging Face: use "tobil/qmd-query-expansion-train" (already prepared).
    name: "data/train/"
    # Column holding the prompt; presumably chat-formatted messages — confirm
    # against the prepared dataset schema.
    prompt_field: "messages"
    # Presumably caps how many examples are loaded — confirm in train.py.
    max_samples: 1000
  training:
    epochs: 1
    # NOTE(review): effective batch is presumably
    # batch_size * gradient_accumulation_steps = 16 per device — confirm how
    # train.py maps these onto the trainer config.
    batch_size: 2
    gradient_accumulation_steps: 8
    learning_rate: 0.0000005  # 5e-7
    max_grad_norm: 0.5
    # NOTE(review): if train.py forwards both epochs and max_steps to an HF
    # Trainer-style config, max_steps takes precedence over epochs — confirm.
    max_steps: 200
  grpo:
    # Completions sampled per prompt (the GRPO group size), presumably passed
    # through to TRL's GRPOConfig — confirm in train.py. Some TRL versions
    # require the batch to be divisible by this value.
    num_generations: 4
    max_completion_length: 200
    beta: 0.04  # KL regularization - prevents drift from SFT checkpoint
  lora:
    rank: 4
    # With PEFT-style LoRA the update is scaled by alpha / rank (= 2 here) —
    # confirm train.py maps these to LoraConfig(r, lora_alpha).
    alpha: 8
    dropout: 0.05
    # Presumably the attention query/value projections of the Qwen3 blocks.
    target_modules:
      - "q_proj"
      - "v_proj"
  tracking:
    # Experiment-tracking identifiers; the backend is chosen by train.py.
    project: "qmd-query-expansion"
    run_name: "grpo-1.7B"