train.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "torch",
  5. # "trl>=0.12.0",
  6. # "peft>=0.7.0",
  7. # "transformers>=4.45.0",
  8. # "accelerate>=0.24.0",
  9. # "huggingface_hub>=0.20.0",
  10. # "trackio",
  11. # "nvidia-ml-py",
  12. # "datasets",
  13. # "bitsandbytes",
  14. # "pyyaml",
  15. # "gguf",
  16. # ]
  17. # ///
  18. """
  19. Unified training script for QMD query expansion models.
  20. Supports two stages:
  21. sft - Supervised fine-tuning on labeled examples
  22. grpo - Group Relative Policy Optimization (RL) on top of merged SFT weights
  23. Usage:
  24. uv run train.py sft --config configs/sft.yaml
  25. uv run train.py grpo --config configs/grpo.yaml
  26. uv run train.py grpo --config configs/grpo.yaml --dry-run
  27. """
  28. import argparse
  29. import os
  30. import subprocess
  31. import sys
  32. import time
  33. from pathlib import Path
  34. import yaml
  35. from transformers import TrainerCallback
  36. def export_gguf(model, tokenizer, output_dir: str, model_name: str):
  37. """Export model to GGUF at Q4_K_M, Q6_K, Q8_0 quantizations."""
  38. import shutil
  39. import tempfile
  40. output_path = Path(output_dir)
  41. gguf_dir = output_path / "gguf"
  42. gguf_dir.mkdir(exist_ok=True)
  43. # Save merged model to temp dir
  44. print("Saving merged model for GGUF conversion...")
  45. with tempfile.TemporaryDirectory() as tmp:
  46. merged_path = Path(tmp) / "merged"
  47. model.save_pretrained(merged_path, safe_serialization=True)
  48. tokenizer.save_pretrained(merged_path)
  49. # Setup llama.cpp
  50. llama_cpp = Path("/tmp/llama.cpp")
  51. if not llama_cpp.exists():
  52. print("Cloning llama.cpp...")
  53. subprocess.run(
  54. [
  55. "git",
  56. "clone",
  57. "--depth",
  58. "1",
  59. "https://github.com/ggerganov/llama.cpp.git",
  60. str(llama_cpp),
  61. ],
  62. capture_output=True,
  63. )
  64. subprocess.run(
  65. [
  66. sys.executable,
  67. "-m",
  68. "pip",
  69. "install",
  70. "-q",
  71. "-r",
  72. str(llama_cpp / "requirements.txt"),
  73. ],
  74. capture_output=True,
  75. )
  76. # Build quantize tool if needed
  77. quantize_bin = llama_cpp / "build" / "bin" / "llama-quantize"
  78. if not quantize_bin.exists():
  79. print("Building llama-quantize...")
  80. build_dir = llama_cpp / "build"
  81. build_dir.mkdir(exist_ok=True)
  82. subprocess.run(
  83. [
  84. "cmake",
  85. "-B",
  86. str(build_dir),
  87. "-S",
  88. str(llama_cpp),
  89. "-DGGML_CUDA=OFF",
  90. ],
  91. capture_output=True,
  92. )
  93. subprocess.run(
  94. [
  95. "cmake",
  96. "--build",
  97. str(build_dir),
  98. "--target",
  99. "llama-quantize",
  100. "-j",
  101. "4",
  102. ],
  103. capture_output=True,
  104. )
  105. # Convert to FP16 first
  106. fp16_file = gguf_dir / f"{model_name}-f16.gguf"
  107. print(f"Converting to FP16: {fp16_file}")
  108. log_out = Path("/tmp/qmd-gguf-convert.log")
  109. log_err = Path("/tmp/qmd-gguf-convert.err")
  110. with log_out.open("w") as out_f, log_err.open("w") as err_f:
  111. result = subprocess.run(
  112. [
  113. sys.executable,
  114. str(llama_cpp / "convert_hf_to_gguf.py"),
  115. str(merged_path),
  116. "--outfile",
  117. str(fp16_file),
  118. "--outtype",
  119. "f16",
  120. ],
  121. stdout=out_f,
  122. stderr=err_f,
  123. text=True,
  124. )
  125. if result.returncode != 0:
  126. print("GGUF conversion failed.")
  127. print(f"stdout: {log_out}")
  128. print(f"stderr: {log_err}")
  129. return
  130. # Quantize to 4, 6, 8 bit
  131. for quant_type in ["Q4_K_M", "Q6_K", "Q8_0"]:
  132. out_file = gguf_dir / f"{model_name}-{quant_type.lower()}.gguf"
  133. print(f"Quantizing {quant_type}: {out_file}")
  134. subprocess.run(
  135. [str(quantize_bin), str(fp16_file), str(out_file), quant_type],
  136. capture_output=True,
  137. )
  138. if out_file.exists():
  139. size_mb = out_file.stat().st_size / (1024 * 1024)
  140. print(f" {quant_type}: {size_mb:.1f} MB")
  141. # Remove FP16 to save space
  142. if fp16_file.exists():
  143. fp16_file.unlink()
  144. print(f"GGUF files saved to: {gguf_dir}")
  145. class TimedSaveCallback(TrainerCallback):
  146. """Trigger periodic checkpoint saves based on elapsed wall-clock time."""
  147. def __init__(self, interval_minutes: float):
  148. self.interval_seconds = float(interval_minutes) * 60.0
  149. self.last_save_time = time.time()
  150. def on_step_end(self, args, state, control, **kwargs):
  151. if not getattr(state, "is_world_process_zero", False):
  152. return control
  153. now = time.time()
  154. if now - self.last_save_time >= self.interval_seconds:
  155. control.should_save = True
  156. self.last_save_time = now
  157. return control
  158. def run_eval(model_path: str) -> float | None:
  159. """Run eval.py on the trained model and return average score."""
  160. print("\n" + "=" * 60)
  161. print("Running evaluation...")
  162. print("=" * 60)
  163. eval_script = Path(__file__).parent / "eval.py"
  164. result = subprocess.run(
  165. [sys.executable, str(eval_script), model_path],
  166. cwd=str(Path(__file__).parent),
  167. capture_output=True,
  168. text=True,
  169. )
  170. if result.stdout:
  171. print(result.stdout, end="")
  172. if result.stderr:
  173. print(result.stderr, end="")
  174. avg = None
  175. for line in (result.stdout or "").splitlines():
  176. if line.strip().startswith("Average:"):
  177. try:
  178. avg = float(line.split("Average:", 1)[1].split("%", 1)[0].strip())
  179. except ValueError:
  180. pass
  181. break
  182. return avg
  183. def cmd_sft(args):
  184. """Run supervised fine-tuning."""
  185. import torch
  186. from datasets import load_dataset
  187. import torch.distributed as dist
  188. from peft import LoraConfig
  189. from transformers import AutoTokenizer, AutoModelForCausalLM
  190. from transformers.utils import logging as hf_logging
  191. hf_logging.set_verbosity_error()
  192. from trl import SFTTrainer, SFTConfig
  193. with open(args.config) as f:
  194. cfg = yaml.safe_load(f)
  195. os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")
  196. if args.dry_run:
  197. print("SFT Training Configuration:")
  198. print(yaml.dump(cfg, default_flow_style=False))
  199. return
  200. dataset_name = cfg["dataset"]["name"]
  201. print(f"Loading dataset: {dataset_name}...")
  202. # Support local JSONL files and glob patterns
  203. if dataset_name.startswith("data/") or dataset_name.endswith(".jsonl"):
  204. from pathlib import Path
  205. import glob
  206. # Handle glob patterns like "data/*.jsonl"
  207. if "*" in dataset_name:
  208. jsonl_files = sorted(glob.glob(dataset_name))
  209. if not jsonl_files:
  210. raise ValueError(f"No files found matching: {dataset_name}")
  211. print(
  212. f" Found {len(jsonl_files)} JSONL files: {[Path(f).name for f in jsonl_files]}"
  213. )
  214. dataset = load_dataset("json", data_files=jsonl_files, split="train")
  215. else:
  216. data_path = Path(dataset_name)
  217. if data_path.is_dir():
  218. train_file = data_path / "train.jsonl"
  219. dataset = load_dataset(
  220. "json", data_files=str(train_file), split="train"
  221. )
  222. else:
  223. dataset = load_dataset("json", data_files=dataset_name, split="train")
  224. else:
  225. dataset = load_dataset(dataset_name, split=cfg["dataset"]["split"])
  226. print(f"Dataset loaded: {len(dataset)} examples")
  227. dataset = dataset.shuffle(seed=42)
  228. split = dataset.train_test_split(test_size=cfg["dataset"]["eval_split"], seed=42)
  229. train_dataset = split["train"]
  230. eval_dataset = split["test"]
  231. print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
  232. # Check if output looks like a HF Hub path (contains /)
  233. output_name = cfg["model"]["output"]
  234. push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
  235. if "push_to_hub" in cfg["model"]:
  236. push_to_hub = bool(cfg["model"]["push_to_hub"])
  237. output_dir = output_name.split("/")[-1] if push_to_hub else output_name
  238. report_to = "none"
  239. if os.environ.get("HF_TOKEN"):
  240. try:
  241. import trackio # noqa: F401
  242. report_to = "trackio"
  243. except Exception:
  244. print("Trackio not installed; disabling tracking.")
  245. tracking = cfg.get("tracking", {})
  246. if report_to == "trackio":
  247. project = tracking.get("project")
  248. if project:
  249. os.environ.setdefault("TRACKIO_PROJECT", project)
  250. run_name = tracking.get("run_name")
  251. if run_name and "{" in run_name:
  252. from datetime import datetime
  253. now = datetime.now()
  254. run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
  255. "{time}", now.strftime("%H:%M")
  256. )
  257. save_interval_minutes = cfg["training"].get("save_interval_minutes")
  258. save_steps = cfg["training"].get("save_steps", 200)
  259. save_total_limit = cfg["training"].get("save_total_limit", 2)
  260. if save_interval_minutes:
  261. # Prefer wall-clock checkpointing (for long jobs / preemption safety)
  262. save_steps = max(save_steps, 10_000_000)
  263. callbacks = []
  264. if save_interval_minutes:
  265. try:
  266. interval_value = float(save_interval_minutes)
  267. except (TypeError, ValueError):
  268. interval_value = None
  269. if interval_value and interval_value > 0:
  270. callbacks.append(TimedSaveCallback(interval_value))
  271. config = SFTConfig(
  272. output_dir=output_dir,
  273. push_to_hub=push_to_hub,
  274. hub_model_id=output_name if push_to_hub else None,
  275. hub_strategy="every_save" if push_to_hub else "end",
  276. num_train_epochs=cfg["training"]["epochs"],
  277. per_device_train_batch_size=cfg["training"]["batch_size"],
  278. gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
  279. learning_rate=cfg["training"]["learning_rate"],
  280. max_length=cfg["training"]["max_length"],
  281. logging_steps=10,
  282. save_strategy="steps",
  283. save_steps=save_steps,
  284. save_total_limit=save_total_limit,
  285. eval_strategy="steps",
  286. eval_steps=cfg["training"].get("eval_steps", 200),
  287. warmup_ratio=cfg["training"]["warmup_ratio"],
  288. lr_scheduler_type=cfg["training"]["lr_scheduler"],
  289. ddp_find_unused_parameters=cfg["training"].get(
  290. "ddp_find_unused_parameters", False
  291. ),
  292. bf16=True,
  293. report_to=report_to,
  294. run_name=run_name if report_to == "trackio" else None,
  295. )
  296. # LoRA config with modules_to_save for embedding layers
  297. # This prevents token ID mismatches during inference
  298. peft_config = LoraConfig(
  299. r=cfg["lora"]["rank"],
  300. lora_alpha=cfg["lora"]["alpha"],
  301. lora_dropout=cfg["lora"]["dropout"],
  302. bias="none",
  303. task_type="CAUSAL_LM",
  304. target_modules=cfg["lora"]["target_modules"],
  305. modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
  306. ensure_weight_tying=True,
  307. )
  308. print("Loading tokenizer...")
  309. base_model = cfg["model"]["base"]
  310. tokenizer = AutoTokenizer.from_pretrained(base_model)
  311. if tokenizer.pad_token is None:
  312. tokenizer.pad_token = tokenizer.eos_token
  313. print("Initializing SFT trainer...")
  314. trainer = SFTTrainer(
  315. model=base_model,
  316. train_dataset=train_dataset,
  317. eval_dataset=eval_dataset,
  318. args=config,
  319. peft_config=peft_config,
  320. processing_class=tokenizer,
  321. callbacks=callbacks,
  322. )
  323. print("Starting SFT training...")
  324. trainer.train()
  325. is_main = os.environ.get("RANK", "0") == "0"
  326. if dist.is_available() and dist.is_initialized():
  327. dist.barrier()
  328. if not is_main:
  329. return
  330. if push_to_hub:
  331. print("Pushing to Hub...")
  332. trainer.push_to_hub()
  333. print(f"Done! Model: https://huggingface.co/{output_name}")
  334. else:
  335. trainer.save_model()
  336. print(f"Done! Model saved to: {output_dir}")
  337. # Export GGUF
  338. print("\nExporting to GGUF...")
  339. # Need to get the merged model for GGUF
  340. print("Loading model for GGUF export...")
  341. from peft import PeftModel
  342. base = AutoModelForCausalLM.from_pretrained(
  343. base_model, torch_dtype=torch.bfloat16, device_map="auto"
  344. )
  345. base.config.tie_word_embeddings = False
  346. model = PeftModel.from_pretrained(base, output_dir, local_files_only=True)
  347. model = model.merge_and_unload()
  348. export_gguf(model, tokenizer, output_dir, Path(output_dir).name)
  349. # Run eval
  350. eval_avg = run_eval(output_dir)
  351. if report_to == "trackio":
  352. try:
  353. import trackio
  354. if eval_avg is not None:
  355. trackio.log({"eval.avg": eval_avg})
  356. except Exception:
  357. pass
  358. def cmd_grpo(args):
  359. """Run GRPO reinforcement learning on top of merged SFT weights."""
  360. import torch
  361. import torch.distributed as dist
  362. import os
  363. from datasets import load_dataset
  364. from peft import LoraConfig, PeftModel, get_peft_model
  365. from transformers import AutoModelForCausalLM, AutoTokenizer
  366. from transformers.utils import logging as hf_logging
  367. hf_logging.set_verbosity_error()
  368. from trl import GRPOTrainer, GRPOConfig
  369. # Import reward from the shared module
  370. sys.path.insert(0, os.path.dirname(__file__))
  371. from reward import QMDRewardFunction
  372. with open(args.config) as f:
  373. cfg = yaml.safe_load(f)
  374. os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")
  375. if args.dry_run:
  376. print("GRPO Training Configuration:")
  377. print(yaml.dump(cfg, default_flow_style=False))
  378. return
  379. # Tracking
  380. report_to = "none"
  381. if os.environ.get("HF_TOKEN"):
  382. try:
  383. import trackio # noqa: F401
  384. report_to = "trackio"
  385. except Exception:
  386. print("Trackio not installed; disabling tracking.")
  387. tracking = cfg.get("tracking", {})
  388. if report_to == "trackio":
  389. project = tracking.get("project")
  390. if project:
  391. os.environ.setdefault("TRACKIO_PROJECT", project)
  392. run_name = tracking.get("run_name")
  393. if run_name and "{" in run_name:
  394. from datetime import datetime
  395. now = datetime.now()
  396. run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
  397. "{time}", now.strftime("%H:%M")
  398. )
  399. # Load tokenizer
  400. base_model_name = cfg["model"]["base"]
  401. print(f"Loading tokenizer from {base_model_name}...")
  402. tokenizer = AutoTokenizer.from_pretrained(base_model_name)
  403. if tokenizer.pad_token is None:
  404. tokenizer.pad_token = tokenizer.eos_token
  405. # Load and format dataset
  406. print("Loading dataset...")
  407. dataset = load_dataset(cfg["dataset"]["name"], split="train")
  408. def extract_prompt(example):
  409. content = example[cfg["dataset"]["prompt_field"]][0]["content"]
  410. messages = [{"role": "user", "content": content}]
  411. formatted = tokenizer.apply_chat_template(
  412. messages, tokenize=False, add_generation_prompt=True
  413. )
  414. return {"prompt": formatted}
  415. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  416. max_samples = cfg["dataset"].get("max_samples", len(dataset))
  417. dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
  418. print(f"Using {len(dataset)} prompts for GRPO")
  419. # Load base model, merge SFT adapter
  420. sft_model_name = cfg["model"]["sft"]
  421. print(f"Loading SFT model from {sft_model_name}...")
  422. local_rank = int(os.environ.get("LOCAL_RANK", "0"))
  423. if torch.cuda.is_available():
  424. available = torch.cuda.device_count()
  425. if available == 0:
  426. raise RuntimeError("CUDA is available but no devices were detected.")
  427. if local_rank >= available:
  428. print(
  429. f"Warning: LOCAL_RANK={local_rank} but only {available} CUDA device(s) visible. "
  430. "Falling back to the last available device."
  431. )
  432. local_rank = available - 1
  433. torch.cuda.set_device(local_rank)
  434. dtype_name = cfg["model"].get("torch_dtype", "bfloat16")
  435. dtype_map = {
  436. "bfloat16": torch.bfloat16,
  437. "float16": torch.float16,
  438. "float32": torch.float32,
  439. }
  440. torch_dtype = dtype_map.get(dtype_name, torch.bfloat16)
  441. model_kwargs = {
  442. "torch_dtype": torch_dtype,
  443. "device_map": {"": local_rank} if torch.cuda.is_available() else "auto",
  444. }
  445. base_model = AutoModelForCausalLM.from_pretrained(
  446. base_model_name,
  447. **model_kwargs,
  448. )
  449. model = PeftModel.from_pretrained(base_model, sft_model_name)
  450. model = model.merge_and_unload()
  451. print("SFT adapter merged.")
  452. # Add fresh LoRA for GRPO with modules_to_save
  453. grpo_lora_config = LoraConfig(
  454. r=cfg["lora"]["rank"],
  455. lora_alpha=cfg["lora"]["alpha"],
  456. lora_dropout=cfg["lora"]["dropout"],
  457. bias="none",
  458. task_type="CAUSAL_LM",
  459. target_modules=cfg["lora"]["target_modules"],
  460. modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
  461. ensure_weight_tying=True,
  462. )
  463. model = get_peft_model(model, grpo_lora_config)
  464. model.print_trainable_parameters()
  465. # Build GRPO config
  466. output_name = cfg["model"]["output"]
  467. push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
  468. if "push_to_hub" in cfg["model"]:
  469. push_to_hub = bool(cfg["model"]["push_to_hub"])
  470. output_dir = output_name.split("/")[-1] if push_to_hub else output_name
  471. grpo_cfg = cfg.get("grpo", {})
  472. learning_rate = cfg["training"]["learning_rate"]
  473. if isinstance(learning_rate, str):
  474. learning_rate = float(learning_rate)
  475. save_interval_minutes = cfg["training"].get("save_interval_minutes")
  476. save_steps = cfg["training"].get("save_steps", 200)
  477. save_total_limit = cfg["training"].get("save_total_limit", 2)
  478. save_strategy = cfg["training"].get("save_strategy", "epoch")
  479. if save_interval_minutes:
  480. # Prefer wall-clock checkpointing (for long jobs / preemption safety)
  481. save_steps = max(save_steps, 10_000_000)
  482. save_strategy = "steps"
  483. callbacks = []
  484. if save_interval_minutes:
  485. try:
  486. interval_value = float(save_interval_minutes)
  487. except (TypeError, ValueError):
  488. interval_value = None
  489. if interval_value and interval_value > 0:
  490. callbacks.append(TimedSaveCallback(interval_value))
  491. config = GRPOConfig(
  492. output_dir=output_dir,
  493. push_to_hub=push_to_hub,
  494. hub_model_id=output_name if push_to_hub else None,
  495. num_generations=grpo_cfg.get("num_generations", 4),
  496. max_completion_length=grpo_cfg.get("max_completion_length", 200),
  497. beta=grpo_cfg.get("beta", 0.04),
  498. num_train_epochs=cfg["training"]["epochs"],
  499. per_device_train_batch_size=cfg["training"]["batch_size"],
  500. gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
  501. learning_rate=learning_rate,
  502. max_grad_norm=cfg["training"]["max_grad_norm"],
  503. max_steps=cfg["training"].get("max_steps", -1),
  504. logging_steps=10,
  505. save_strategy=save_strategy,
  506. save_steps=save_steps,
  507. save_total_limit=save_total_limit,
  508. bf16=True,
  509. skip_memory_metrics=True,
  510. report_to=report_to,
  511. run_name=run_name if report_to == "trackio" else None,
  512. )
  513. # Train
  514. print("Initializing GRPO trainer...")
  515. trainer = GRPOTrainer(
  516. model=model,
  517. processing_class=tokenizer,
  518. args=config,
  519. train_dataset=dataset,
  520. reward_funcs=[QMDRewardFunction()],
  521. callbacks=callbacks,
  522. )
  523. print("Starting GRPO training...")
  524. trainer.train()
  525. is_main = os.environ.get("RANK", "0") == "0"
  526. if dist.is_available() and dist.is_initialized():
  527. dist.barrier()
  528. if not is_main:
  529. return
  530. if push_to_hub:
  531. print("Pushing to Hub...")
  532. trainer.push_to_hub()
  533. trainer.save_model()
  534. if report_to == "trackio":
  535. try:
  536. import trackio
  537. trackio.finish()
  538. except Exception:
  539. pass
  540. print(f"Done! Model saved to: {output_dir}")
  541. # Export GGUF
  542. print("\nExporting to GGUF...")
  543. merged = model.merge_and_unload()
  544. export_gguf(merged, tokenizer, output_dir, Path(output_dir).name)
  545. # Run eval
  546. eval_avg = run_eval(output_dir)
  547. if report_to == "trackio" and eval_avg is not None:
  548. try:
  549. import trackio
  550. trackio.log({"eval.avg": eval_avg})
  551. except Exception:
  552. pass
  553. def main():
  554. parser = argparse.ArgumentParser(
  555. description="QMD Query Expansion Training",
  556. formatter_class=argparse.RawDescriptionHelpFormatter,
  557. epilog="""
  558. Examples:
  559. uv run train.py sft --config configs/sft.yaml
  560. uv run train.py grpo --config configs/grpo.yaml
  561. uv run train.py grpo --config configs/grpo.yaml --dry-run
  562. """,
  563. )
  564. sub = parser.add_subparsers(dest="stage", required=True)
  565. sft_parser = sub.add_parser("sft", help="Supervised fine-tuning")
  566. sft_parser.add_argument("--config", required=True, help="Path to SFT config YAML")
  567. sft_parser.add_argument(
  568. "--dry-run", action="store_true", help="Print config and exit"
  569. )
  570. grpo_parser = sub.add_parser("grpo", help="GRPO reinforcement learning")
  571. grpo_parser.add_argument("--config", required=True, help="Path to GRPO config YAML")
  572. grpo_parser.add_argument(
  573. "--dry-run", action="store_true", help="Print config, test reward, and exit"
  574. )
  575. args = parser.parse_args()
  576. if args.stage == "sft":
  577. cmd_sft(args)
  578. elif args.stage == "grpo":
  579. cmd_grpo(args)
  580. if __name__ == "__main__":
  581. main()