train.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "torch",
  5. # "trl>=0.12.0",
  6. # "peft>=0.7.0",
  7. # "transformers>=4.45.0",
  8. # "accelerate>=0.24.0",
  9. # "huggingface_hub>=0.20.0",
  10. # "trackio",
  11. # "nvidia-ml-py",
  12. # "datasets",
  13. # "bitsandbytes",
  14. # "pyyaml",
  15. # "gguf",
  16. # ]
  17. # ///
  18. """
  19. Unified training script for QMD query expansion models.
  20. Supports two stages:
  21. sft - Supervised fine-tuning on labeled examples
  22. grpo - Group Relative Policy Optimization (RL) on top of merged SFT weights
  23. Usage:
  24. uv run train.py sft --config configs/sft.yaml
  25. uv run train.py grpo --config configs/grpo.yaml
  26. uv run train.py grpo --config configs/grpo.yaml --dry-run
  27. """
  28. import argparse
  29. import os
  30. import subprocess
  31. import sys
  32. from pathlib import Path
  33. import yaml
  34. def export_gguf(model, tokenizer, output_dir: str, model_name: str):
  35. """Export model to GGUF at Q4_K_M, Q6_K, Q8_0 quantizations."""
  36. import shutil
  37. import tempfile
  38. output_path = Path(output_dir)
  39. gguf_dir = output_path / "gguf"
  40. gguf_dir.mkdir(exist_ok=True)
  41. # Save merged model to temp dir
  42. print("Saving merged model for GGUF conversion...")
  43. with tempfile.TemporaryDirectory() as tmp:
  44. merged_path = Path(tmp) / "merged"
  45. model.save_pretrained(merged_path, safe_serialization=True)
  46. tokenizer.save_pretrained(merged_path)
  47. # Setup llama.cpp
  48. llama_cpp = Path("/tmp/llama.cpp")
  49. if not llama_cpp.exists():
  50. print("Cloning llama.cpp...")
  51. subprocess.run(
  52. [
  53. "git",
  54. "clone",
  55. "--depth",
  56. "1",
  57. "https://github.com/ggerganov/llama.cpp.git",
  58. str(llama_cpp),
  59. ],
  60. capture_output=True,
  61. )
  62. subprocess.run(
  63. [
  64. sys.executable,
  65. "-m",
  66. "pip",
  67. "install",
  68. "-q",
  69. "-r",
  70. str(llama_cpp / "requirements.txt"),
  71. ],
  72. capture_output=True,
  73. )
  74. # Build quantize tool if needed
  75. quantize_bin = llama_cpp / "build" / "bin" / "llama-quantize"
  76. if not quantize_bin.exists():
  77. print("Building llama-quantize...")
  78. build_dir = llama_cpp / "build"
  79. build_dir.mkdir(exist_ok=True)
  80. subprocess.run(
  81. [
  82. "cmake",
  83. "-B",
  84. str(build_dir),
  85. "-S",
  86. str(llama_cpp),
  87. "-DGGML_CUDA=OFF",
  88. ],
  89. capture_output=True,
  90. )
  91. subprocess.run(
  92. [
  93. "cmake",
  94. "--build",
  95. str(build_dir),
  96. "--target",
  97. "llama-quantize",
  98. "-j",
  99. "4",
  100. ],
  101. capture_output=True,
  102. )
  103. # Convert to FP16 first
  104. fp16_file = gguf_dir / f"{model_name}-f16.gguf"
  105. print(f"Converting to FP16: {fp16_file}")
  106. log_out = Path("/tmp/qmd-gguf-convert.log")
  107. log_err = Path("/tmp/qmd-gguf-convert.err")
  108. with log_out.open("w") as out_f, log_err.open("w") as err_f:
  109. result = subprocess.run(
  110. [
  111. sys.executable,
  112. str(llama_cpp / "convert_hf_to_gguf.py"),
  113. str(merged_path),
  114. "--outfile",
  115. str(fp16_file),
  116. "--outtype",
  117. "f16",
  118. ],
  119. stdout=out_f,
  120. stderr=err_f,
  121. text=True,
  122. )
  123. if result.returncode != 0:
  124. print("GGUF conversion failed.")
  125. print(f"stdout: {log_out}")
  126. print(f"stderr: {log_err}")
  127. return
  128. # Quantize to 4, 6, 8 bit
  129. for quant_type in ["Q4_K_M", "Q6_K", "Q8_0"]:
  130. out_file = gguf_dir / f"{model_name}-{quant_type.lower()}.gguf"
  131. print(f"Quantizing {quant_type}: {out_file}")
  132. subprocess.run(
  133. [str(quantize_bin), str(fp16_file), str(out_file), quant_type],
  134. capture_output=True,
  135. )
  136. if out_file.exists():
  137. size_mb = out_file.stat().st_size / (1024 * 1024)
  138. print(f" {quant_type}: {size_mb:.1f} MB")
  139. # Remove FP16 to save space
  140. if fp16_file.exists():
  141. fp16_file.unlink()
  142. print(f"GGUF files saved to: {gguf_dir}")
  143. def run_eval(model_path: str) -> float | None:
  144. """Run eval.py on the trained model and return average score."""
  145. print("\n" + "=" * 60)
  146. print("Running evaluation...")
  147. print("=" * 60)
  148. eval_script = Path(__file__).parent / "eval.py"
  149. result = subprocess.run(
  150. [sys.executable, str(eval_script), model_path],
  151. cwd=str(Path(__file__).parent),
  152. capture_output=True,
  153. text=True,
  154. )
  155. if result.stdout:
  156. print(result.stdout, end="")
  157. if result.stderr:
  158. print(result.stderr, end="")
  159. avg = None
  160. for line in (result.stdout or "").splitlines():
  161. if line.strip().startswith("Average:"):
  162. try:
  163. avg = float(line.split("Average:", 1)[1].split("%", 1)[0].strip())
  164. except ValueError:
  165. pass
  166. break
  167. return avg
  168. def cmd_sft(args):
  169. """Run supervised fine-tuning."""
  170. import torch
  171. import os
  172. from datasets import load_dataset
  173. import torch
  174. import torch.distributed as dist
  175. from peft import LoraConfig
  176. from transformers import AutoTokenizer, AutoModelForCausalLM
  177. from transformers.utils import logging as hf_logging
  178. hf_logging.set_verbosity_error()
  179. from trl import SFTTrainer, SFTConfig
  180. with open(args.config) as f:
  181. cfg = yaml.safe_load(f)
  182. os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")
  183. if args.dry_run:
  184. print("SFT Training Configuration:")
  185. print(yaml.dump(cfg, default_flow_style=False))
  186. return
  187. dataset_name = cfg["dataset"]["name"]
  188. print(f"Loading dataset: {dataset_name}...")
  189. # Support local JSONL files and glob patterns
  190. if dataset_name.startswith("data/") or dataset_name.endswith(".jsonl"):
  191. from pathlib import Path
  192. import glob
  193. # Handle glob patterns like "data/*.jsonl"
  194. if "*" in dataset_name:
  195. jsonl_files = sorted(glob.glob(dataset_name))
  196. if not jsonl_files:
  197. raise ValueError(f"No files found matching: {dataset_name}")
  198. print(
  199. f" Found {len(jsonl_files)} JSONL files: {[Path(f).name for f in jsonl_files]}"
  200. )
  201. dataset = load_dataset("json", data_files=jsonl_files, split="train")
  202. else:
  203. data_path = Path(dataset_name)
  204. if data_path.is_dir():
  205. train_file = data_path / "train.jsonl"
  206. dataset = load_dataset(
  207. "json", data_files=str(train_file), split="train"
  208. )
  209. else:
  210. dataset = load_dataset("json", data_files=dataset_name, split="train")
  211. else:
  212. dataset = load_dataset(dataset_name, split=cfg["dataset"]["split"])
  213. print(f"Dataset loaded: {len(dataset)} examples")
  214. dataset = dataset.shuffle(seed=42)
  215. split = dataset.train_test_split(test_size=cfg["dataset"]["eval_split"], seed=42)
  216. train_dataset = split["train"]
  217. eval_dataset = split["test"]
  218. print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
  219. # Check if output looks like a HF Hub path (contains /)
  220. output_name = cfg["model"]["output"]
  221. push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
  222. if "push_to_hub" in cfg["model"]:
  223. push_to_hub = bool(cfg["model"]["push_to_hub"])
  224. output_dir = output_name.split("/")[-1] if push_to_hub else output_name
  225. report_to = "none"
  226. if os.environ.get("HF_TOKEN"):
  227. try:
  228. import trackio # noqa: F401
  229. report_to = "trackio"
  230. except Exception:
  231. print("Trackio not installed; disabling tracking.")
  232. tracking = cfg.get("tracking", {})
  233. if report_to == "trackio":
  234. project = tracking.get("project")
  235. if project:
  236. os.environ.setdefault("TRACKIO_PROJECT", project)
  237. run_name = tracking.get("run_name")
  238. if run_name and "{" in run_name:
  239. from datetime import datetime
  240. now = datetime.now()
  241. run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
  242. "{time}", now.strftime("%H:%M")
  243. )
  244. config = SFTConfig(
  245. output_dir=output_dir,
  246. push_to_hub=push_to_hub,
  247. hub_model_id=output_name if push_to_hub else None,
  248. hub_strategy="every_save" if push_to_hub else "end",
  249. num_train_epochs=cfg["training"]["epochs"],
  250. per_device_train_batch_size=cfg["training"]["batch_size"],
  251. gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
  252. learning_rate=cfg["training"]["learning_rate"],
  253. max_length=cfg["training"]["max_length"],
  254. logging_steps=10,
  255. save_strategy="steps",
  256. save_steps=200,
  257. save_total_limit=2,
  258. eval_strategy="steps",
  259. eval_steps=200,
  260. warmup_ratio=cfg["training"]["warmup_ratio"],
  261. lr_scheduler_type=cfg["training"]["lr_scheduler"],
  262. ddp_find_unused_parameters=cfg["training"].get(
  263. "ddp_find_unused_parameters", False
  264. ),
  265. bf16=True,
  266. report_to=report_to,
  267. run_name=run_name if report_to == "trackio" else None,
  268. )
  269. # LoRA config with modules_to_save for embedding layers
  270. # This prevents token ID mismatches during inference
  271. peft_config = LoraConfig(
  272. r=cfg["lora"]["rank"],
  273. lora_alpha=cfg["lora"]["alpha"],
  274. lora_dropout=cfg["lora"]["dropout"],
  275. bias="none",
  276. task_type="CAUSAL_LM",
  277. target_modules=cfg["lora"]["target_modules"],
  278. modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
  279. ensure_weight_tying=True,
  280. )
  281. print("Loading tokenizer...")
  282. base_model = cfg["model"]["base"]
  283. tokenizer = AutoTokenizer.from_pretrained(base_model)
  284. if tokenizer.pad_token is None:
  285. tokenizer.pad_token = tokenizer.eos_token
  286. print("Initializing SFT trainer...")
  287. trainer = SFTTrainer(
  288. model=base_model,
  289. train_dataset=train_dataset,
  290. eval_dataset=eval_dataset,
  291. args=config,
  292. peft_config=peft_config,
  293. processing_class=tokenizer,
  294. )
  295. print("Starting SFT training...")
  296. trainer.train()
  297. is_main = os.environ.get("RANK", "0") == "0"
  298. if dist.is_available() and dist.is_initialized():
  299. dist.barrier()
  300. if not is_main:
  301. return
  302. if push_to_hub:
  303. print("Pushing to Hub...")
  304. trainer.push_to_hub()
  305. print(f"Done! Model: https://huggingface.co/{output_name}")
  306. else:
  307. trainer.save_model()
  308. print(f"Done! Model saved to: {output_dir}")
  309. # Export GGUF
  310. print("\nExporting to GGUF...")
  311. # Need to get the merged model for GGUF
  312. print("Loading model for GGUF export...")
  313. from peft import PeftModel
  314. base = AutoModelForCausalLM.from_pretrained(
  315. base_model, torch_dtype=torch.bfloat16, device_map="auto"
  316. )
  317. base.config.tie_word_embeddings = False
  318. model = PeftModel.from_pretrained(base, output_dir, local_files_only=True)
  319. model = model.merge_and_unload()
  320. export_gguf(model, tokenizer, output_dir, Path(output_dir).name)
  321. # Run eval
  322. eval_avg = run_eval(output_dir)
  323. if report_to == "trackio":
  324. try:
  325. import trackio
  326. if eval_avg is not None:
  327. trackio.log({"eval.avg": eval_avg})
  328. except Exception:
  329. pass
  330. def cmd_grpo(args):
  331. """Run GRPO reinforcement learning on top of merged SFT weights."""
  332. import torch
  333. import os
  334. from datasets import load_dataset
  335. from peft import LoraConfig, PeftModel, get_peft_model
  336. from transformers import AutoModelForCausalLM, AutoTokenizer
  337. from transformers.utils import logging as hf_logging
  338. hf_logging.set_verbosity_error()
  339. from trl import GRPOTrainer, GRPOConfig
  340. # Import reward from the shared module
  341. sys.path.insert(0, os.path.dirname(__file__))
  342. from reward import QMDRewardFunction
  343. with open(args.config) as f:
  344. cfg = yaml.safe_load(f)
  345. os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")
  346. if args.dry_run:
  347. print("GRPO Training Configuration:")
  348. print(yaml.dump(cfg, default_flow_style=False))
  349. return
  350. # Tracking
  351. report_to = "none"
  352. if os.environ.get("HF_TOKEN"):
  353. try:
  354. import trackio # noqa: F401
  355. report_to = "trackio"
  356. except Exception:
  357. print("Trackio not installed; disabling tracking.")
  358. tracking = cfg.get("tracking", {})
  359. if report_to == "trackio":
  360. project = tracking.get("project")
  361. if project:
  362. os.environ.setdefault("TRACKIO_PROJECT", project)
  363. run_name = tracking.get("run_name")
  364. if run_name and "{" in run_name:
  365. from datetime import datetime
  366. now = datetime.now()
  367. run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
  368. "{time}", now.strftime("%H:%M")
  369. )
  370. # Load tokenizer
  371. base_model_name = cfg["model"]["base"]
  372. print(f"Loading tokenizer from {base_model_name}...")
  373. tokenizer = AutoTokenizer.from_pretrained(base_model_name)
  374. if tokenizer.pad_token is None:
  375. tokenizer.pad_token = tokenizer.eos_token
  376. # Load and format dataset
  377. print("Loading dataset...")
  378. dataset = load_dataset(cfg["dataset"]["name"], split="train")
  379. def extract_prompt(example):
  380. content = example[cfg["dataset"]["prompt_field"]][0]["content"]
  381. messages = [{"role": "user", "content": content}]
  382. formatted = tokenizer.apply_chat_template(
  383. messages, tokenize=False, add_generation_prompt=True
  384. )
  385. return {"prompt": formatted}
  386. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  387. max_samples = cfg["dataset"].get("max_samples", len(dataset))
  388. dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
  389. print(f"Using {len(dataset)} prompts for GRPO")
  390. # Load base model, merge SFT adapter
  391. sft_model_name = cfg["model"]["sft"]
  392. print(f"Loading SFT model from {sft_model_name}...")
  393. local_rank = int(os.environ.get("LOCAL_RANK", "0"))
  394. if torch.cuda.is_available():
  395. available = torch.cuda.device_count()
  396. if available == 0:
  397. raise RuntimeError("CUDA is available but no devices were detected.")
  398. if local_rank >= available:
  399. print(
  400. f"Warning: LOCAL_RANK={local_rank} but only {available} CUDA device(s) visible. "
  401. "Falling back to the last available device."
  402. )
  403. local_rank = available - 1
  404. torch.cuda.set_device(local_rank)
  405. dtype_name = cfg["model"].get("torch_dtype", "bfloat16")
  406. dtype_map = {
  407. "bfloat16": torch.bfloat16,
  408. "float16": torch.float16,
  409. "float32": torch.float32,
  410. }
  411. torch_dtype = dtype_map.get(dtype_name, torch.bfloat16)
  412. model_kwargs = {
  413. "torch_dtype": torch_dtype,
  414. "device_map": {"": local_rank} if torch.cuda.is_available() else "auto",
  415. }
  416. base_model = AutoModelForCausalLM.from_pretrained(
  417. base_model_name,
  418. **model_kwargs,
  419. )
  420. model = PeftModel.from_pretrained(base_model, sft_model_name)
  421. model = model.merge_and_unload()
  422. print("SFT adapter merged.")
  423. # Add fresh LoRA for GRPO with modules_to_save
  424. grpo_lora_config = LoraConfig(
  425. r=cfg["lora"]["rank"],
  426. lora_alpha=cfg["lora"]["alpha"],
  427. lora_dropout=cfg["lora"]["dropout"],
  428. bias="none",
  429. task_type="CAUSAL_LM",
  430. target_modules=cfg["lora"]["target_modules"],
  431. modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
  432. )
  433. model = get_peft_model(model, grpo_lora_config)
  434. model.print_trainable_parameters()
  435. # Build GRPO config
  436. output_name = cfg["model"]["output"]
  437. push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
  438. if "push_to_hub" in cfg["model"]:
  439. push_to_hub = bool(cfg["model"]["push_to_hub"])
  440. output_dir = output_name.split("/")[-1] if push_to_hub else output_name
  441. grpo_cfg = cfg.get("grpo", {})
  442. learning_rate = cfg["training"]["learning_rate"]
  443. if isinstance(learning_rate, str):
  444. learning_rate = float(learning_rate)
  445. config = GRPOConfig(
  446. output_dir=output_dir,
  447. push_to_hub=push_to_hub,
  448. hub_model_id=output_name if push_to_hub else None,
  449. num_generations=grpo_cfg.get("num_generations", 4),
  450. max_completion_length=grpo_cfg.get("max_completion_length", 200),
  451. beta=grpo_cfg.get("beta", 0.04),
  452. num_train_epochs=cfg["training"]["epochs"],
  453. per_device_train_batch_size=cfg["training"]["batch_size"],
  454. gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
  455. learning_rate=learning_rate,
  456. max_grad_norm=cfg["training"]["max_grad_norm"],
  457. max_steps=cfg["training"].get("max_steps", -1),
  458. logging_steps=10,
  459. save_strategy="epoch",
  460. bf16=True,
  461. skip_memory_metrics=True,
  462. report_to=report_to,
  463. run_name=run_name if report_to == "trackio" else None,
  464. )
  465. # Train
  466. print("Initializing GRPO trainer...")
  467. trainer = GRPOTrainer(
  468. model=model,
  469. processing_class=tokenizer,
  470. args=config,
  471. train_dataset=dataset,
  472. reward_funcs=[QMDRewardFunction()],
  473. )
  474. print("Starting GRPO training...")
  475. trainer.train()
  476. if push_to_hub:
  477. print("Pushing to Hub...")
  478. trainer.push_to_hub()
  479. trainer.save_model()
  480. if report_to == "trackio":
  481. try:
  482. import trackio
  483. trackio.finish()
  484. except Exception:
  485. pass
  486. print(f"Done! Model saved to: {output_dir}")
  487. # Export GGUF
  488. print("\nExporting to GGUF...")
  489. merged = model.merge_and_unload()
  490. export_gguf(merged, tokenizer, output_dir, Path(output_dir).name)
  491. # Run eval
  492. eval_avg = run_eval(output_dir)
  493. if report_to == "trackio" and eval_avg is not None:
  494. try:
  495. import trackio
  496. trackio.log({"eval.avg": eval_avg})
  497. except Exception:
  498. pass
  499. def main():
  500. parser = argparse.ArgumentParser(
  501. description="QMD Query Expansion Training",
  502. formatter_class=argparse.RawDescriptionHelpFormatter,
  503. epilog="""
  504. Examples:
  505. uv run train.py sft --config configs/sft.yaml
  506. uv run train.py grpo --config configs/grpo.yaml
  507. uv run train.py grpo --config configs/grpo.yaml --dry-run
  508. """,
  509. )
  510. sub = parser.add_subparsers(dest="stage", required=True)
  511. sft_parser = sub.add_parser("sft", help="Supervised fine-tuning")
  512. sft_parser.add_argument("--config", required=True, help="Path to SFT config YAML")
  513. sft_parser.add_argument(
  514. "--dry-run", action="store_true", help="Print config and exit"
  515. )
  516. grpo_parser = sub.add_parser("grpo", help="GRPO reinforcement learning")
  517. grpo_parser.add_argument("--config", required=True, help="Path to GRPO config YAML")
  518. grpo_parser.add_argument(
  519. "--dry-run", action="store_true", help="Print config, test reward, and exit"
  520. )
  521. args = parser.parse_args()
  522. if args.stage == "sft":
  523. cmd_sft(args)
  524. elif args.stage == "grpo":
  525. cmd_grpo(args)
  526. if __name__ == "__main__":
  527. main()