---
# GRPO Training Config for QMD Query Expansion
# Target: Qwen3-1.7B, trained on top of merged SFT weights
#
# Usage: uv run train.py grpo --config experiments/grpo/grpo.yaml
#
# The reward function (reward.py) scores expansions on format compliance,
# diversity, hyde quality, content quality, and named entity preservation.
# beta > 0 is critical to prevent drift from the SFT checkpoint.

# Model sources, output location, and load-time precision/quantization flags.
model:
  base: "Qwen/Qwen3-1.7B"
  sft: "outputs/sft"  # Use local SFT output (or HF path if uploaded)
  output: "outputs/grpo"  # Local training output (push to HF manually after eval)
  push_to_hub: false
  torch_dtype: "bfloat16"
  load_in_4bit: false
  load_in_8bit: false

# Training data location and which field holds the prompt.
dataset:
  # Local: run `uv run dataset/prepare_data.py` first, then use "data/train/"
  # HuggingFace: use "tobil/qmd-query-expansion-train" (already prepared)
  name: "data/train/"
  prompt_field: "messages"
  max_samples: 1000

# Optimizer / schedule settings.
# NOTE(review): both `epochs` and `max_steps` are set; which one bounds the run
# depends on how train.py forwards them - confirm against the trainer.
training:
  epochs: 1
  batch_size: 2
  gradient_accumulation_steps: 8
  learning_rate: 0.0000005
  max_grad_norm: 0.5
  max_steps: 200
  # Save checkpoints every 30 minutes
  save_interval_minutes: 30
  # Fallback step-based save cadence if needed (not used for wall-clock mode)
  save_steps: 50

# GRPO-specific settings.
grpo:
  num_generations: 4
  max_completion_length: 200
  beta: 0.04  # KL regularization - prevents drift from SFT checkpoint

# LoRA adapter settings; only the listed projection modules are targeted.
lora:
  rank: 4
  alpha: 8
  dropout: 0.05
  target_modules:
    - "q_proj"
    - "v_proj"

# Experiment-tracking identifiers.
tracking:
  project: "qmd-query-expansion"
  run_name: "grpo-1.7B"