train.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "torch",
  5. # "trl>=0.12.0",
  6. # "peft>=0.7.0",
  7. # "transformers>=4.45.0",
  8. # "accelerate>=0.24.0",
  9. # "huggingface_hub>=0.20.0",
  10. # "trackio",
  11. # "nvidia-ml-py",
  12. # "datasets",
  13. # "bitsandbytes",
  14. # "pyyaml",
  15. # "gguf",
  16. # ]
  17. # ///
  18. """
  19. Unified training script for QMD query expansion models.
  20. Primary pipeline is SFT-only:
  21. sft - Supervised fine-tuning on labeled examples
  22. GRPO was moved to `experiments/grpo/` and is not part of the main training
  23. pipeline by default.
  24. Usage:
  25. uv run train.py sft --config configs/sft.yaml
  26. """
  27. import argparse
  28. import os
  29. import subprocess
  30. import sys
  31. import time
  32. from pathlib import Path
  33. import yaml
  34. from transformers import TrainerCallback
  35. def export_gguf(model, tokenizer, output_dir: str, model_name: str):
  36. """Export model to GGUF at Q4_K_M, Q6_K, Q8_0 quantizations."""
  37. import shutil
  38. import tempfile
  39. output_path = Path(output_dir)
  40. gguf_dir = output_path / "gguf"
  41. gguf_dir.mkdir(exist_ok=True)
  42. # Save merged model to temp dir
  43. print("Saving merged model for GGUF conversion...")
  44. with tempfile.TemporaryDirectory() as tmp:
  45. merged_path = Path(tmp) / "merged"
  46. model.save_pretrained(merged_path, safe_serialization=True)
  47. tokenizer.save_pretrained(merged_path)
  48. # Setup llama.cpp
  49. llama_cpp = Path("/tmp/llama.cpp")
  50. if not llama_cpp.exists():
  51. print("Cloning llama.cpp...")
  52. subprocess.run(
  53. [
  54. "git",
  55. "clone",
  56. "--depth",
  57. "1",
  58. "https://github.com/ggerganov/llama.cpp.git",
  59. str(llama_cpp),
  60. ],
  61. capture_output=True,
  62. )
  63. subprocess.run(
  64. [
  65. sys.executable,
  66. "-m",
  67. "pip",
  68. "install",
  69. "-q",
  70. "-r",
  71. str(llama_cpp / "requirements.txt"),
  72. ],
  73. capture_output=True,
  74. )
  75. # Build quantize tool if needed
  76. quantize_bin = llama_cpp / "build" / "bin" / "llama-quantize"
  77. if not quantize_bin.exists():
  78. print("Building llama-quantize...")
  79. build_dir = llama_cpp / "build"
  80. build_dir.mkdir(exist_ok=True)
  81. subprocess.run(
  82. [
  83. "cmake",
  84. "-B",
  85. str(build_dir),
  86. "-S",
  87. str(llama_cpp),
  88. "-DGGML_CUDA=OFF",
  89. ],
  90. capture_output=True,
  91. )
  92. subprocess.run(
  93. [
  94. "cmake",
  95. "--build",
  96. str(build_dir),
  97. "--target",
  98. "llama-quantize",
  99. "-j",
  100. "4",
  101. ],
  102. capture_output=True,
  103. )
  104. # Convert to FP16 first
  105. fp16_file = gguf_dir / f"{model_name}-f16.gguf"
  106. print(f"Converting to FP16: {fp16_file}")
  107. log_out = Path("/tmp/qmd-gguf-convert.log")
  108. log_err = Path("/tmp/qmd-gguf-convert.err")
  109. with log_out.open("w") as out_f, log_err.open("w") as err_f:
  110. result = subprocess.run(
  111. [
  112. sys.executable,
  113. str(llama_cpp / "convert_hf_to_gguf.py"),
  114. str(merged_path),
  115. "--outfile",
  116. str(fp16_file),
  117. "--outtype",
  118. "f16",
  119. ],
  120. stdout=out_f,
  121. stderr=err_f,
  122. text=True,
  123. )
  124. if result.returncode != 0:
  125. print("GGUF conversion failed.")
  126. print(f"stdout: {log_out}")
  127. print(f"stderr: {log_err}")
  128. return
  129. # Quantize to 4, 6, 8 bit
  130. for quant_type in ["Q4_K_M", "Q6_K", "Q8_0"]:
  131. out_file = gguf_dir / f"{model_name}-{quant_type.lower()}.gguf"
  132. print(f"Quantizing {quant_type}: {out_file}")
  133. subprocess.run(
  134. [str(quantize_bin), str(fp16_file), str(out_file), quant_type],
  135. capture_output=True,
  136. )
  137. if out_file.exists():
  138. size_mb = out_file.stat().st_size / (1024 * 1024)
  139. print(f" {quant_type}: {size_mb:.1f} MB")
  140. # Remove FP16 to save space
  141. if fp16_file.exists():
  142. fp16_file.unlink()
  143. print(f"GGUF files saved to: {gguf_dir}")
  144. class TimedSaveCallback(TrainerCallback):
  145. """Trigger periodic checkpoint saves based on elapsed wall-clock time."""
  146. def __init__(self, interval_minutes: float):
  147. self.interval_seconds = float(interval_minutes) * 60.0
  148. self.last_save_time = time.time()
  149. def on_step_end(self, args, state, control, **kwargs):
  150. if not getattr(state, "is_world_process_zero", False):
  151. return control
  152. now = time.time()
  153. if now - self.last_save_time >= self.interval_seconds:
  154. control.should_save = True
  155. self.last_save_time = now
  156. return control
  157. def run_eval(model_path: str) -> float | None:
  158. """Run eval.py on the trained model and return average score."""
  159. print("\n" + "=" * 60)
  160. print("Running evaluation...")
  161. print("=" * 60)
  162. eval_script = Path(__file__).parent / "eval.py"
  163. result = subprocess.run(
  164. [sys.executable, str(eval_script), model_path],
  165. cwd=str(Path(__file__).parent),
  166. capture_output=True,
  167. text=True,
  168. )
  169. if result.stdout:
  170. print(result.stdout, end="")
  171. if result.stderr:
  172. print(result.stderr, end="")
  173. avg = None
  174. for line in (result.stdout or "").splitlines():
  175. if line.strip().startswith("Average:"):
  176. try:
  177. avg = float(line.split("Average:", 1)[1].split("%", 1)[0].strip())
  178. except ValueError:
  179. pass
  180. break
  181. return avg
  182. def cmd_sft(args):
  183. """Run supervised fine-tuning."""
  184. import torch
  185. from datasets import load_dataset
  186. import torch.distributed as dist
  187. from peft import LoraConfig
  188. from transformers import AutoTokenizer, AutoModelForCausalLM
  189. from transformers.utils import logging as hf_logging
  190. hf_logging.set_verbosity_error()
  191. from trl import SFTTrainer, SFTConfig
  192. with open(args.config) as f:
  193. cfg = yaml.safe_load(f)
  194. os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")
  195. if args.dry_run:
  196. print("SFT Training Configuration:")
  197. print(yaml.dump(cfg, default_flow_style=False))
  198. return
  199. dataset_name = cfg["dataset"]["name"]
  200. print(f"Loading dataset: {dataset_name}...")
  201. # Support local JSONL files and glob patterns
  202. if dataset_name.startswith("data/") or dataset_name.endswith(".jsonl"):
  203. from pathlib import Path
  204. import glob
  205. # Handle glob patterns like "data/*.jsonl"
  206. if "*" in dataset_name:
  207. jsonl_files = sorted(glob.glob(dataset_name))
  208. if not jsonl_files:
  209. raise ValueError(f"No files found matching: {dataset_name}")
  210. print(
  211. f" Found {len(jsonl_files)} JSONL files: {[Path(f).name for f in jsonl_files]}"
  212. )
  213. dataset = load_dataset("json", data_files=jsonl_files, split="train")
  214. else:
  215. data_path = Path(dataset_name)
  216. if data_path.is_dir():
  217. train_file = data_path / "train.jsonl"
  218. dataset = load_dataset(
  219. "json", data_files=str(train_file), split="train"
  220. )
  221. else:
  222. dataset = load_dataset("json", data_files=dataset_name, split="train")
  223. else:
  224. dataset = load_dataset(dataset_name, split=cfg["dataset"]["split"])
  225. print(f"Dataset loaded: {len(dataset)} examples")
  226. dataset = dataset.shuffle(seed=42)
  227. split = dataset.train_test_split(test_size=cfg["dataset"]["eval_split"], seed=42)
  228. train_dataset = split["train"]
  229. eval_dataset = split["test"]
  230. print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
  231. # Check if output looks like a HF Hub path (contains /)
  232. output_name = cfg["model"]["output"]
  233. push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
  234. if "push_to_hub" in cfg["model"]:
  235. push_to_hub = bool(cfg["model"]["push_to_hub"])
  236. output_dir = output_name.split("/")[-1] if push_to_hub else output_name
  237. report_to = "none"
  238. if os.environ.get("HF_TOKEN"):
  239. try:
  240. import trackio # noqa: F401
  241. report_to = "trackio"
  242. except Exception:
  243. print("Trackio not installed; disabling tracking.")
  244. tracking = cfg.get("tracking", {})
  245. if report_to == "trackio":
  246. project = tracking.get("project")
  247. if project:
  248. os.environ.setdefault("TRACKIO_PROJECT", project)
  249. run_name = tracking.get("run_name")
  250. if run_name and "{" in run_name:
  251. from datetime import datetime
  252. now = datetime.now()
  253. run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
  254. "{time}", now.strftime("%H:%M")
  255. )
  256. save_interval_minutes = cfg["training"].get("save_interval_minutes")
  257. save_steps = cfg["training"].get("save_steps", 200)
  258. save_total_limit = cfg["training"].get("save_total_limit", 2)
  259. if save_interval_minutes:
  260. # Prefer wall-clock checkpointing (for long jobs / preemption safety)
  261. save_steps = max(save_steps, 10_000_000)
  262. callbacks = []
  263. if save_interval_minutes:
  264. try:
  265. interval_value = float(save_interval_minutes)
  266. except (TypeError, ValueError):
  267. interval_value = None
  268. if interval_value and interval_value > 0:
  269. callbacks.append(TimedSaveCallback(interval_value))
  270. config = SFTConfig(
  271. output_dir=output_dir,
  272. push_to_hub=push_to_hub,
  273. hub_model_id=output_name if push_to_hub else None,
  274. hub_strategy="every_save" if push_to_hub else "end",
  275. num_train_epochs=cfg["training"]["epochs"],
  276. per_device_train_batch_size=cfg["training"]["batch_size"],
  277. gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
  278. learning_rate=cfg["training"]["learning_rate"],
  279. max_length=cfg["training"]["max_length"],
  280. logging_steps=10,
  281. save_strategy="steps",
  282. save_steps=save_steps,
  283. save_total_limit=save_total_limit,
  284. eval_strategy="steps",
  285. eval_steps=cfg["training"].get("eval_steps", 200),
  286. warmup_ratio=cfg["training"]["warmup_ratio"],
  287. lr_scheduler_type=cfg["training"]["lr_scheduler"],
  288. ddp_find_unused_parameters=cfg["training"].get(
  289. "ddp_find_unused_parameters", False
  290. ),
  291. bf16=True,
  292. report_to=report_to,
  293. run_name=run_name if report_to == "trackio" else None,
  294. )
  295. # LoRA config with modules_to_save for embedding layers
  296. # This prevents token ID mismatches during inference
  297. peft_config = LoraConfig(
  298. r=cfg["lora"]["rank"],
  299. lora_alpha=cfg["lora"]["alpha"],
  300. lora_dropout=cfg["lora"]["dropout"],
  301. bias="none",
  302. task_type="CAUSAL_LM",
  303. target_modules=cfg["lora"]["target_modules"],
  304. modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
  305. ensure_weight_tying=True,
  306. )
  307. print("Loading tokenizer...")
  308. base_model = cfg["model"]["base"]
  309. tokenizer = AutoTokenizer.from_pretrained(base_model)
  310. if tokenizer.pad_token is None:
  311. tokenizer.pad_token = tokenizer.eos_token
  312. print("Initializing SFT trainer...")
  313. trainer = SFTTrainer(
  314. model=base_model,
  315. train_dataset=train_dataset,
  316. eval_dataset=eval_dataset,
  317. args=config,
  318. peft_config=peft_config,
  319. processing_class=tokenizer,
  320. callbacks=callbacks,
  321. )
  322. print("Starting SFT training...")
  323. trainer.train()
  324. is_main = os.environ.get("RANK", "0") == "0"
  325. if dist.is_available() and dist.is_initialized():
  326. dist.barrier()
  327. if not is_main:
  328. return
  329. if push_to_hub:
  330. print("Pushing to Hub...")
  331. trainer.push_to_hub()
  332. print(f"Done! Model: https://huggingface.co/{output_name}")
  333. else:
  334. trainer.save_model()
  335. print(f"Done! Model saved to: {output_dir}")
  336. # Export GGUF
  337. print("\nExporting to GGUF...")
  338. # Need to get the merged model for GGUF
  339. print("Loading model for GGUF export...")
  340. from peft import PeftModel
  341. base = AutoModelForCausalLM.from_pretrained(
  342. base_model, torch_dtype=torch.bfloat16, device_map="auto"
  343. )
  344. base.config.tie_word_embeddings = False
  345. model = PeftModel.from_pretrained(base, output_dir, local_files_only=True)
  346. model = model.merge_and_unload()
  347. export_gguf(model, tokenizer, output_dir, Path(output_dir).name)
  348. # Run eval
  349. eval_avg = run_eval(output_dir)
  350. if report_to == "trackio":
  351. try:
  352. import trackio
  353. if eval_avg is not None:
  354. trackio.log({"eval.avg": eval_avg})
  355. except Exception:
  356. pass
  357. def cmd_grpo(args):
  358. """Run GRPO reinforcement learning on top of merged SFT weights."""
  359. print(
  360. "GRPO is not part of the main training pipeline and has been moved to `experiments/grpo/`."
  361. )
  362. print("To run experimental GRPO, use:")
  363. print(" cd finetune && uv run python experiments/grpo/grpo.py")
  364. print("Or, if you have local config wiring ready:")
  365. print(" uv run train.py grpo --config experiments/grpo/grpo.yaml")
  366. return
  367. import torch
  368. import torch.distributed as dist
  369. import os
  370. from datasets import load_dataset
  371. from peft import LoraConfig, PeftModel, get_peft_model
  372. from transformers import AutoModelForCausalLM, AutoTokenizer
  373. from transformers.utils import logging as hf_logging
  374. hf_logging.set_verbosity_error()
  375. from trl import GRPOTrainer, GRPOConfig
  376. # Import reward from the shared module
  377. sys.path.insert(0, os.path.dirname(__file__))
  378. from reward import QMDRewardFunction
  379. with open(args.config) as f:
  380. cfg = yaml.safe_load(f)
  381. os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")
  382. if args.dry_run:
  383. print("GRPO Training Configuration:")
  384. print(yaml.dump(cfg, default_flow_style=False))
  385. return
  386. # Tracking
  387. report_to = "none"
  388. if os.environ.get("HF_TOKEN"):
  389. try:
  390. import trackio # noqa: F401
  391. report_to = "trackio"
  392. except Exception:
  393. print("Trackio not installed; disabling tracking.")
  394. tracking = cfg.get("tracking", {})
  395. if report_to == "trackio":
  396. project = tracking.get("project")
  397. if project:
  398. os.environ.setdefault("TRACKIO_PROJECT", project)
  399. run_name = tracking.get("run_name")
  400. if run_name and "{" in run_name:
  401. from datetime import datetime
  402. now = datetime.now()
  403. run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
  404. "{time}", now.strftime("%H:%M")
  405. )
  406. # Load tokenizer
  407. base_model_name = cfg["model"]["base"]
  408. print(f"Loading tokenizer from {base_model_name}...")
  409. tokenizer = AutoTokenizer.from_pretrained(base_model_name)
  410. if tokenizer.pad_token is None:
  411. tokenizer.pad_token = tokenizer.eos_token
  412. # Load and format dataset
  413. print("Loading dataset...")
  414. dataset = load_dataset(cfg["dataset"]["name"], split="train")
  415. def extract_prompt(example):
  416. content = example[cfg["dataset"]["prompt_field"]][0]["content"]
  417. messages = [{"role": "user", "content": content}]
  418. formatted = tokenizer.apply_chat_template(
  419. messages, tokenize=False, add_generation_prompt=True
  420. )
  421. return {"prompt": formatted}
  422. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  423. max_samples = cfg["dataset"].get("max_samples", len(dataset))
  424. dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
  425. print(f"Using {len(dataset)} prompts for GRPO")
  426. # Load base model, merge SFT adapter
  427. sft_model_name = cfg["model"]["sft"]
  428. print(f"Loading SFT model from {sft_model_name}...")
  429. local_rank = int(os.environ.get("LOCAL_RANK", "0"))
  430. if torch.cuda.is_available():
  431. available = torch.cuda.device_count()
  432. if available == 0:
  433. raise RuntimeError("CUDA is available but no devices were detected.")
  434. if local_rank >= available:
  435. print(
  436. f"Warning: LOCAL_RANK={local_rank} but only {available} CUDA device(s) visible. "
  437. "Falling back to the last available device."
  438. )
  439. local_rank = available - 1
  440. torch.cuda.set_device(local_rank)
  441. dtype_name = cfg["model"].get("torch_dtype", "bfloat16")
  442. dtype_map = {
  443. "bfloat16": torch.bfloat16,
  444. "float16": torch.float16,
  445. "float32": torch.float32,
  446. }
  447. torch_dtype = dtype_map.get(dtype_name, torch.bfloat16)
  448. model_kwargs = {
  449. "torch_dtype": torch_dtype,
  450. "device_map": {"": local_rank} if torch.cuda.is_available() else "auto",
  451. }
  452. base_model = AutoModelForCausalLM.from_pretrained(
  453. base_model_name,
  454. **model_kwargs,
  455. )
  456. model = PeftModel.from_pretrained(base_model, sft_model_name)
  457. model = model.merge_and_unload()
  458. print("SFT adapter merged.")
  459. # Add fresh LoRA for GRPO with modules_to_save
  460. grpo_lora_config = LoraConfig(
  461. r=cfg["lora"]["rank"],
  462. lora_alpha=cfg["lora"]["alpha"],
  463. lora_dropout=cfg["lora"]["dropout"],
  464. bias="none",
  465. task_type="CAUSAL_LM",
  466. target_modules=cfg["lora"]["target_modules"],
  467. modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
  468. ensure_weight_tying=True,
  469. )
  470. model = get_peft_model(model, grpo_lora_config)
  471. model.print_trainable_parameters()
  472. # Build GRPO config
  473. output_name = cfg["model"]["output"]
  474. push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
  475. if "push_to_hub" in cfg["model"]:
  476. push_to_hub = bool(cfg["model"]["push_to_hub"])
  477. output_dir = output_name.split("/")[-1] if push_to_hub else output_name
  478. grpo_cfg = cfg.get("grpo", {})
  479. learning_rate = cfg["training"]["learning_rate"]
  480. if isinstance(learning_rate, str):
  481. learning_rate = float(learning_rate)
  482. save_interval_minutes = cfg["training"].get("save_interval_minutes")
  483. save_steps = cfg["training"].get("save_steps", 200)
  484. save_total_limit = cfg["training"].get("save_total_limit", 2)
  485. save_strategy = cfg["training"].get("save_strategy", "epoch")
  486. if save_interval_minutes:
  487. # Prefer wall-clock checkpointing (for long jobs / preemption safety)
  488. save_steps = max(save_steps, 10_000_000)
  489. save_strategy = "steps"
  490. callbacks = []
  491. if save_interval_minutes:
  492. try:
  493. interval_value = float(save_interval_minutes)
  494. except (TypeError, ValueError):
  495. interval_value = None
  496. if interval_value and interval_value > 0:
  497. callbacks.append(TimedSaveCallback(interval_value))
  498. config = GRPOConfig(
  499. output_dir=output_dir,
  500. push_to_hub=push_to_hub,
  501. hub_model_id=output_name if push_to_hub else None,
  502. num_generations=grpo_cfg.get("num_generations", 4),
  503. max_completion_length=grpo_cfg.get("max_completion_length", 200),
  504. beta=grpo_cfg.get("beta", 0.04),
  505. num_train_epochs=cfg["training"]["epochs"],
  506. per_device_train_batch_size=cfg["training"]["batch_size"],
  507. gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
  508. learning_rate=learning_rate,
  509. max_grad_norm=cfg["training"]["max_grad_norm"],
  510. max_steps=cfg["training"].get("max_steps", -1),
  511. logging_steps=10,
  512. save_strategy=save_strategy,
  513. save_steps=save_steps,
  514. save_total_limit=save_total_limit,
  515. bf16=True,
  516. skip_memory_metrics=True,
  517. report_to=report_to,
  518. run_name=run_name if report_to == "trackio" else None,
  519. )
  520. # Train
  521. print("Initializing GRPO trainer...")
  522. trainer = GRPOTrainer(
  523. model=model,
  524. processing_class=tokenizer,
  525. args=config,
  526. train_dataset=dataset,
  527. reward_funcs=[QMDRewardFunction()],
  528. callbacks=callbacks,
  529. )
  530. print("Starting GRPO training...")
  531. trainer.train()
  532. is_main = os.environ.get("RANK", "0") == "0"
  533. if dist.is_available() and dist.is_initialized():
  534. dist.barrier()
  535. if not is_main:
  536. return
  537. if push_to_hub:
  538. print("Pushing to Hub...")
  539. trainer.push_to_hub()
  540. trainer.save_model()
  541. if report_to == "trackio":
  542. try:
  543. import trackio
  544. trackio.finish()
  545. except Exception:
  546. pass
  547. print(f"Done! Model saved to: {output_dir}")
  548. # Export GGUF
  549. print("\nExporting to GGUF...")
  550. merged = model.merge_and_unload()
  551. export_gguf(merged, tokenizer, output_dir, Path(output_dir).name)
  552. # Run eval
  553. eval_avg = run_eval(output_dir)
  554. if report_to == "trackio" and eval_avg is not None:
  555. try:
  556. import trackio
  557. trackio.log({"eval.avg": eval_avg})
  558. except Exception:
  559. pass
  560. def main():
  561. parser = argparse.ArgumentParser(
  562. description="QMD Query Expansion Training",
  563. formatter_class=argparse.RawDescriptionHelpFormatter,
  564. epilog="""
  565. Examples:
  566. uv run train.py sft --config configs/sft.yaml
  567. """,
  568. )
  569. sub = parser.add_subparsers(dest="stage", required=True)
  570. sft_parser = sub.add_parser("sft", help="Supervised fine-tuning")
  571. sft_parser.add_argument("--config", required=True, help="Path to SFT config YAML")
  572. sft_parser.add_argument(
  573. "--dry-run", action="store_true", help="Print config and exit"
  574. )
  575. grpo_parser = sub.add_parser(
  576. "grpo",
  577. help="Experimental: GRPO reinforcement learning (moved to experiments/grpo/)",
  578. )
  579. grpo_parser.add_argument("--config", required=True, help="Path to GRPO config YAML")
  580. grpo_parser.add_argument(
  581. "--dry-run", action="store_true", help="Print config, test reward, and exit"
  582. )
  583. args = parser.parse_args()
  584. if args.stage == "sft":
  585. cmd_sft(args)
  586. elif args.stage == "grpo":
  587. cmd_grpo(args)
  588. if __name__ == "__main__":
  589. main()