|
@@ -417,8 +417,6 @@ def cmd_grpo(args):
|
|
|
)
|
|
)
|
|
|
print("To run experimental GRPO, use:")
|
|
print("To run experimental GRPO, use:")
|
|
|
print(" cd finetune && uv run python experiments/grpo/grpo.py")
|
|
print(" cd finetune && uv run python experiments/grpo/grpo.py")
|
|
|
- print("Or, if you have local config wiring ready:")
|
|
|
|
|
- print(" uv run train.py grpo --config experiments/grpo/grpo.yaml")
|
|
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
import torch
|
|
import torch
|
|
@@ -664,22 +662,9 @@ Examples:
|
|
|
"--dry-run", action="store_true", help="Print config and exit"
|
|
"--dry-run", action="store_true", help="Print config and exit"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- grpo_parser = sub.add_parser(
|
|
|
|
|
- "grpo",
|
|
|
|
|
- help="Experimental: GRPO reinforcement learning (moved to experiments/grpo/)",
|
|
|
|
|
- )
|
|
|
|
|
- grpo_parser.add_argument("--config", required=True, help="Path to GRPO config YAML")
|
|
|
|
|
- grpo_parser.add_argument(
|
|
|
|
|
- "--dry-run", action="store_true", help="Print config, test reward, and exit"
|
|
|
|
|
- )
|
|
|
|
|
-
|
|
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
- if args.stage == "sft":
|
|
|
|
|
- cmd_sft(args)
|
|
|
|
|
- elif args.stage == "grpo":
|
|
|
|
|
- cmd_grpo(args)
|
|
|
|
|
-
|
|
|
|
|
|
|
+ cmd_sft(args)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
main()
|
|
main()
|