| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611 |
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- # "torch",
- # "trl>=0.12.0",
- # "peft>=0.7.0",
- # "transformers>=4.45.0",
- # "accelerate>=0.24.0",
- # "huggingface_hub>=0.20.0",
- # "trackio",
- # "nvidia-ml-py",
- # "datasets",
- # "bitsandbytes",
- # "pyyaml",
- # "gguf",
- # ]
- # ///
- """
- Unified training script for QMD query expansion models.
- Supports two stages:
- sft - Supervised fine-tuning on labeled examples
- grpo - Group Relative Policy Optimization (RL) on top of merged SFT weights
- Usage:
- uv run train.py sft --config configs/sft.yaml
- uv run train.py grpo --config configs/grpo.yaml
- uv run train.py grpo --config configs/grpo.yaml --dry-run
- """
- import argparse
- import os
- import subprocess
- import sys
- from pathlib import Path
- import yaml
def _run_quiet(cmd, description: str) -> bool:
    """Run *cmd* with captured output and report whether it succeeded.

    On failure (non-zero exit, or the executable cannot be launched) a short
    diagnostic including the tail of stderr is printed so the cause is not
    lost along with the captured output.

    Args:
        cmd: Argument list passed to ``subprocess.run`` (no shell).
        description: Human-readable name of the step, used in messages.

    Returns:
        True when the command exited with status 0, False otherwise.
    """
    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace"
        )
    except OSError as exc:
        # e.g. git/cmake not installed, or the binary was never built.
        print(f"{description} failed: {exc}")
        return False
    if result.returncode != 0:
        print(f"{description} failed (exit code {result.returncode}).")
        tail = (result.stderr or "").strip().splitlines()[-20:]
        if tail:
            print("\n".join(tail))
        return False
    return True


def export_gguf(model, tokenizer, output_dir: str, model_name: str):
    """Export model to GGUF at Q4_K_M, Q6_K, Q8_0 quantizations.

    The model and tokenizer are saved to a temporary directory, converted to
    an FP16 GGUF with llama.cpp's ``convert_hf_to_gguf.py`` and then quantized
    with ``llama-quantize``. llama.cpp is cloned to ``/tmp/llama.cpp`` and the
    quantize tool is built on first use.

    The export is best-effort: a failing step prints a diagnostic and the
    function returns (or skips that quantization) instead of raising.

    Args:
        model: Object exposing ``save_pretrained(path, safe_serialization=...)``.
        tokenizer: Object exposing ``save_pretrained(path)``.
        output_dir: Directory under which the ``gguf/`` folder is created.
        model_name: File-name stem for the generated ``.gguf`` files.
    """
    import tempfile

    gguf_dir = Path(output_dir) / "gguf"
    # parents=True: output_dir may not exist yet when called standalone.
    gguf_dir.mkdir(parents=True, exist_ok=True)

    # Save merged model to temp dir
    print("Saving merged model for GGUF conversion...")
    with tempfile.TemporaryDirectory() as tmp:
        merged_path = Path(tmp) / "merged"
        model.save_pretrained(merged_path, safe_serialization=True)
        tokenizer.save_pretrained(merged_path)

        # Setup llama.cpp
        llama_cpp = Path("/tmp/llama.cpp")
        if not llama_cpp.exists():
            print("Cloning llama.cpp...")
            cloned = _run_quiet(
                [
                    "git",
                    "clone",
                    "--depth",
                    "1",
                    "https://github.com/ggerganov/llama.cpp.git",
                    str(llama_cpp),
                ],
                "Cloning llama.cpp",
            )
            if not cloned:
                # Without the checkout there is no converter to run.
                print("GGUF export aborted.")
                return
            # Best effort: the converter may still work with the packages
            # that are already installed, so a failure here is not fatal.
            _run_quiet(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "-q",
                    "-r",
                    str(llama_cpp / "requirements.txt"),
                ],
                "Installing llama.cpp Python requirements",
            )

        # Build quantize tool if needed
        quantize_bin = llama_cpp / "build" / "bin" / "llama-quantize"
        if not quantize_bin.exists():
            print("Building llama-quantize...")
            build_dir = llama_cpp / "build"
            build_dir.mkdir(exist_ok=True)
            configured = _run_quiet(
                [
                    "cmake",
                    "-B",
                    str(build_dir),
                    "-S",
                    str(llama_cpp),
                    "-DGGML_CUDA=OFF",
                ],
                "Configuring llama.cpp build",
            )
            if configured:
                _run_quiet(
                    [
                        "cmake",
                        "--build",
                        str(build_dir),
                        "--target",
                        "llama-quantize",
                        "-j",
                        "4",
                    ],
                    "Building llama-quantize",
                )

        # Convert to FP16 first
        fp16_file = gguf_dir / f"{model_name}-f16.gguf"
        print(f"Converting to FP16: {fp16_file}")
        log_out = Path("/tmp/qmd-gguf-convert.log")
        log_err = Path("/tmp/qmd-gguf-convert.err")
        with log_out.open("w") as out_f, log_err.open("w") as err_f:
            result = subprocess.run(
                [
                    sys.executable,
                    str(llama_cpp / "convert_hf_to_gguf.py"),
                    str(merged_path),
                    "--outfile",
                    str(fp16_file),
                    "--outtype",
                    "f16",
                ],
                stdout=out_f,
                stderr=err_f,
                text=True,
            )
        if result.returncode != 0:
            print("GGUF conversion failed.")
            print(f"stdout: {log_out}")
            print(f"stderr: {log_err}")
            return

        # Quantize to 4, 6, 8 bit
        produced = 0
        if quantize_bin.exists():
            for quant_type in ("Q4_K_M", "Q6_K", "Q8_0"):
                out_file = gguf_dir / f"{model_name}-{quant_type.lower()}.gguf"
                print(f"Quantizing {quant_type}: {out_file}")
                succeeded = _run_quiet(
                    [str(quantize_bin), str(fp16_file), str(out_file), quant_type],
                    f"Quantizing {quant_type}",
                )
                if succeeded and out_file.exists():
                    size_mb = out_file.stat().st_size / (1024 * 1024)
                    print(f" {quant_type}: {size_mb:.1f} MB")
                    produced += 1
        else:
            print("llama-quantize is unavailable; skipping quantization.")

        if produced:
            # Remove FP16 to save space
            if fp16_file.exists():
                fp16_file.unlink()
        else:
            # Keep the only usable artifact rather than deleting it.
            print(f"No quantized files were produced; keeping {fp16_file}")

    print(f"GGUF files saved to: {gguf_dir}")
- def run_eval(model_path: str) -> float | None:
- """Run eval.py on the trained model and return average score."""
- print("\n" + "=" * 60)
- print("Running evaluation...")
- print("=" * 60)
- eval_script = Path(__file__).parent / "eval.py"
- result = subprocess.run(
- [sys.executable, str(eval_script), model_path],
- cwd=str(Path(__file__).parent),
- capture_output=True,
- text=True,
- )
- if result.stdout:
- print(result.stdout, end="")
- if result.stderr:
- print(result.stderr, end="")
- avg = None
- for line in (result.stdout or "").splitlines():
- if line.strip().startswith("Average:"):
- try:
- avg = float(line.split("Average:", 1)[1].split("%", 1)[0].strip())
- except ValueError:
- pass
- break
- return avg
def cmd_sft(args):
    """Run supervised fine-tuning.

    Reads the YAML file at ``args.config`` and trains a LoRA adapter with
    TRL's ``SFTTrainer``. On the main process (``RANK`` unset or ``"0"``) the
    result is pushed to the Hub or saved locally, the adapter is merged into
    the base model for GGUF export, and ``eval.py`` is run.

    Config keys read: ``dataset.name``, ``dataset.split`` (Hub datasets only),
    ``dataset.eval_split``, ``model.base``, ``model.output``, optional
    ``model.push_to_hub``, ``training.*``, ``lora.*`` and optional
    ``tracking.project`` / ``tracking.run_name``.

    Args:
        args: Parsed CLI namespace providing ``config`` (path to the YAML
            file) and ``dry_run`` (print the config and return).

    Raises:
        ValueError: If ``dataset.name`` is a glob pattern matching no files.
    """
    # NOTE: `Path` and `os` deliberately come from the module-level imports.
    # A function-scope `from pathlib import Path` inside the local-dataset
    # branch made `Path` a local name, so runs using a Hub dataset hit
    # UnboundLocalError at the GGUF export step after training had finished.
    import glob

    import torch
    import torch.distributed as dist
    from datasets import load_dataset
    from peft import LoraConfig, PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils import logging as hf_logging

    hf_logging.set_verbosity_error()
    from trl import SFTConfig, SFTTrainer

    with open(args.config) as f:
        cfg = yaml.safe_load(f)
    os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")

    if args.dry_run:
        print("SFT Training Configuration:")
        print(yaml.dump(cfg, default_flow_style=False))
        return

    dataset_name = cfg["dataset"]["name"]
    print(f"Loading dataset: {dataset_name}...")
    # Support local JSONL files and glob patterns
    if dataset_name.startswith("data/") or dataset_name.endswith(".jsonl"):
        if "*" in dataset_name:
            # Handle glob patterns like "data/*.jsonl"
            jsonl_files = sorted(glob.glob(dataset_name))
            if not jsonl_files:
                raise ValueError(f"No files found matching: {dataset_name}")
            print(
                f" Found {len(jsonl_files)} JSONL files: {[Path(f).name for f in jsonl_files]}"
            )
            dataset = load_dataset("json", data_files=jsonl_files, split="train")
        else:
            data_path = Path(dataset_name)
            if data_path.is_dir():
                # A directory is expected to contain train.jsonl.
                train_file = data_path / "train.jsonl"
                dataset = load_dataset(
                    "json", data_files=str(train_file), split="train"
                )
            else:
                dataset = load_dataset("json", data_files=dataset_name, split="train")
    else:
        dataset = load_dataset(dataset_name, split=cfg["dataset"]["split"])
    print(f"Dataset loaded: {len(dataset)} examples")

    dataset = dataset.shuffle(seed=42)
    split = dataset.train_test_split(test_size=cfg["dataset"]["eval_split"], seed=42)
    train_dataset = split["train"]
    eval_dataset = split["test"]
    print(f" Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")

    # Check if output looks like a HF Hub path (contains /)
    output_name = cfg["model"]["output"]
    push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
    if "push_to_hub" in cfg["model"]:
        # An explicit setting overrides the heuristic above.
        push_to_hub = bool(cfg["model"]["push_to_hub"])
    output_dir = output_name.split("/")[-1] if push_to_hub else output_name

    # Tracking is enabled only when HF_TOKEN is set and trackio imports.
    report_to = "none"
    if os.environ.get("HF_TOKEN"):
        try:
            import trackio  # noqa: F401

            report_to = "trackio"
        except Exception:
            print("Trackio not installed; disabling tracking.")

    # `or {}` also covers a `tracking:` key that is present but null.
    tracking = cfg.get("tracking") or {}
    run_name = None
    if report_to == "trackio":
        project = tracking.get("project")
        if project:
            os.environ.setdefault("TRACKIO_PROJECT", project)
        run_name = tracking.get("run_name")
        if run_name and "{" in run_name:
            from datetime import datetime

            now = datetime.now()
            run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
                "{time}", now.strftime("%H:%M")
            )

    # PyYAML reads values such as "2e-4" (no decimal point) as strings;
    # coerce so the optimizer receives a float, matching cmd_grpo.
    learning_rate = cfg["training"]["learning_rate"]
    if isinstance(learning_rate, str):
        learning_rate = float(learning_rate)

    config = SFTConfig(
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        hub_model_id=output_name if push_to_hub else None,
        hub_strategy="every_save" if push_to_hub else "end",
        num_train_epochs=cfg["training"]["epochs"],
        per_device_train_batch_size=cfg["training"]["batch_size"],
        gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
        learning_rate=learning_rate,
        max_length=cfg["training"]["max_length"],
        logging_steps=10,
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        eval_strategy="steps",
        eval_steps=200,
        warmup_ratio=cfg["training"]["warmup_ratio"],
        lr_scheduler_type=cfg["training"]["lr_scheduler"],
        ddp_find_unused_parameters=cfg["training"].get(
            "ddp_find_unused_parameters", False
        ),
        bf16=True,
        report_to=report_to,
        run_name=run_name,
    )

    # LoRA config with modules_to_save for embedding layers
    # This prevents token ID mismatches during inference
    peft_config = LoraConfig(
        r=cfg["lora"]["rank"],
        lora_alpha=cfg["lora"]["alpha"],
        lora_dropout=cfg["lora"]["dropout"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg["lora"]["target_modules"],
        modules_to_save=["embed_tokens", "lm_head"],  # Critical for special tokens
        ensure_weight_tying=True,
    )

    print("Loading tokenizer...")
    base_model = cfg["model"]["base"]
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Initializing SFT trainer...")
    trainer = SFTTrainer(
        model=base_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        args=config,
        peft_config=peft_config,
        processing_class=tokenizer,
    )

    print("Starting SFT training...")
    trainer.train()

    # Everything after training runs on the main process only; the barrier
    # lets all ranks finish training before the others return.
    is_main = os.environ.get("RANK", "0") == "0"
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    if not is_main:
        return

    if push_to_hub:
        print("Pushing to Hub...")
        trainer.push_to_hub()
        print(f"Done! Model: https://huggingface.co/{output_name}")
    else:
        trainer.save_model()
        print(f"Done! Model saved to: {output_dir}")

    # Export GGUF
    print("\nExporting to GGUF...")
    # Need to get the merged model for GGUF
    print("Loading model for GGUF export...")
    base = AutoModelForCausalLM.from_pretrained(
        base_model, torch_dtype=torch.bfloat16, device_map="auto"
    )
    base.config.tie_word_embeddings = False
    model = PeftModel.from_pretrained(base, output_dir, local_files_only=True)
    model = model.merge_and_unload()
    export_gguf(model, tokenizer, output_dir, Path(output_dir).name)

    # Run eval
    eval_avg = run_eval(output_dir)
    if report_to == "trackio" and eval_avg is not None:
        try:
            import trackio

            trackio.log({"eval.avg": eval_avg})
        except Exception as exc:
            # Tracking is best-effort, but do not hide why it failed.
            print(f"Could not log eval.avg to trackio: {exc}")
def cmd_grpo(args):
    """Run GRPO reinforcement learning on top of merged SFT weights.

    Reads the YAML file at ``args.config``, merges the SFT adapter
    (``model.sft``) into the base model (``model.base``), attaches a fresh
    LoRA adapter and trains it with TRL's ``GRPOTrainer`` using
    ``QMDRewardFunction`` from the sibling ``reward`` module. After training
    the model is saved (and optionally pushed); the main process then exports
    GGUF files and runs ``eval.py``.

    Args:
        args: Parsed CLI namespace providing ``config`` (path to the YAML
            file) and ``dry_run`` (print the config and return).

    Raises:
        RuntimeError: If CUDA reports itself available but exposes no device.
    """
    import torch
    import torch.distributed as dist
    from datasets import load_dataset
    from peft import LoraConfig, PeftModel, get_peft_model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.utils import logging as hf_logging

    hf_logging.set_verbosity_error()
    from trl import GRPOConfig, GRPOTrainer

    # Import reward from the shared module. abspath: dirname(__file__) is ""
    # when the script is started by bare file name.
    sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
    from reward import QMDRewardFunction

    with open(args.config) as f:
        cfg = yaml.safe_load(f)
    os.environ.setdefault("HF_LOG_CUDA_MEMORY", "0")

    if args.dry_run:
        print("GRPO Training Configuration:")
        print(yaml.dump(cfg, default_flow_style=False))
        return

    # Tracking is enabled only when HF_TOKEN is set and trackio imports.
    report_to = "none"
    if os.environ.get("HF_TOKEN"):
        try:
            import trackio  # noqa: F401

            report_to = "trackio"
        except Exception:
            print("Trackio not installed; disabling tracking.")

    # `or {}` also covers a `tracking:` key that is present but null.
    tracking = cfg.get("tracking") or {}
    run_name = None
    if report_to == "trackio":
        project = tracking.get("project")
        if project:
            os.environ.setdefault("TRACKIO_PROJECT", project)
        run_name = tracking.get("run_name")
        if run_name and "{" in run_name:
            from datetime import datetime

            now = datetime.now()
            run_name = run_name.replace("{day}", now.strftime("%b %d")).replace(
                "{time}", now.strftime("%H:%M")
            )

    # Load tokenizer
    base_model_name = cfg["model"]["base"]
    print(f"Loading tokenizer from {base_model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load and format dataset
    print("Loading dataset...")
    dataset = load_dataset(cfg["dataset"]["name"], split="train")
    prompt_field = cfg["dataset"]["prompt_field"]

    def extract_prompt(example):
        # Takes the "content" of the first entry in the prompt field and
        # re-wraps it as a single user turn rendered by the chat template.
        content = example[prompt_field][0]["content"]
        messages = [{"role": "user", "content": content}]
        formatted = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return {"prompt": formatted}

    dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
    max_samples = cfg["dataset"].get("max_samples", len(dataset))
    dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
    print(f"Using {len(dataset)} prompts for GRPO")

    # Load base model, merge SFT adapter
    sft_model_name = cfg["model"]["sft"]
    print(f"Loading SFT model from {sft_model_name}...")
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        available = torch.cuda.device_count()
        if available == 0:
            raise RuntimeError("CUDA is available but no devices were detected.")
        if local_rank >= available:
            print(
                f"Warning: LOCAL_RANK={local_rank} but only {available} CUDA device(s) visible. "
                "Falling back to the last available device."
            )
            local_rank = available - 1
        torch.cuda.set_device(local_rank)

    dtype_name = cfg["model"].get("torch_dtype", "bfloat16")
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if dtype_name not in dtype_map:
        # The fallback is kept, but a typo in the config should be visible.
        print(f"Warning: unknown torch_dtype {dtype_name!r}; using bfloat16.")
    torch_dtype = dtype_map.get(dtype_name, torch.bfloat16)

    model_kwargs = {
        "torch_dtype": torch_dtype,
        # Pin the whole model to this process's device when CUDA is present.
        "device_map": {"": local_rank} if cuda_available else "auto",
    }
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        **model_kwargs,
    )
    model = PeftModel.from_pretrained(base_model, sft_model_name)
    model = model.merge_and_unload()
    print("SFT adapter merged.")

    # Add fresh LoRA for GRPO with modules_to_save
    grpo_lora_config = LoraConfig(
        r=cfg["lora"]["rank"],
        lora_alpha=cfg["lora"]["alpha"],
        lora_dropout=cfg["lora"]["dropout"],
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=cfg["lora"]["target_modules"],
        modules_to_save=["embed_tokens", "lm_head"],  # Critical for special tokens
    )
    model = get_peft_model(model, grpo_lora_config)
    model.print_trainable_parameters()

    # Build GRPO config
    output_name = cfg["model"]["output"]
    push_to_hub = "/" in output_name and not output_name.startswith("outputs/")
    if "push_to_hub" in cfg["model"]:
        # An explicit setting overrides the heuristic above.
        push_to_hub = bool(cfg["model"]["push_to_hub"])
    output_dir = output_name.split("/")[-1] if push_to_hub else output_name

    grpo_cfg = cfg.get("grpo") or {}
    # PyYAML reads values such as "2e-4" (no decimal point) as strings.
    learning_rate = cfg["training"]["learning_rate"]
    if isinstance(learning_rate, str):
        learning_rate = float(learning_rate)

    config = GRPOConfig(
        output_dir=output_dir,
        push_to_hub=push_to_hub,
        hub_model_id=output_name if push_to_hub else None,
        num_generations=grpo_cfg.get("num_generations", 4),
        max_completion_length=grpo_cfg.get("max_completion_length", 200),
        beta=grpo_cfg.get("beta", 0.04),
        num_train_epochs=cfg["training"]["epochs"],
        per_device_train_batch_size=cfg["training"]["batch_size"],
        gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
        learning_rate=learning_rate,
        max_grad_norm=cfg["training"]["max_grad_norm"],
        max_steps=cfg["training"].get("max_steps", -1),
        logging_steps=10,
        save_strategy="epoch",
        bf16=True,
        skip_memory_metrics=True,
        report_to=report_to,
        run_name=run_name,
    )

    # Train
    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        args=config,
        train_dataset=dataset,
        reward_funcs=[QMDRewardFunction()],
    )
    print("Starting GRPO training...")
    trainer.train()

    if push_to_hub:
        print("Pushing to Hub...")
        trainer.push_to_hub()
    trainer.save_model()

    # Same guard as cmd_sft: GGUF export and eval write to shared paths
    # (/tmp/llama.cpp, <output_dir>/gguf), so only the main process runs
    # them. The barrier lets every rank finish saving first.
    is_main = os.environ.get("RANK", "0") == "0"
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    if not is_main:
        return
    print(f"Done! Model saved to: {output_dir}")

    # Export GGUF
    print("\nExporting to GGUF...")
    merged = model.merge_and_unload()
    export_gguf(merged, tokenizer, output_dir, Path(output_dir).name)

    # Run eval
    eval_avg = run_eval(output_dir)

    if report_to == "trackio":
        try:
            import trackio

            # Log before finishing the run; the previous order called
            # finish() first. NOTE(review): whether trackio accepts log()
            # after finish() was not verified - confirm against its docs.
            if eval_avg is not None:
                trackio.log({"eval.avg": eval_avg})
            trackio.finish()
        except Exception as exc:
            # Tracking is best-effort, but do not hide why it failed.
            print(f"Trackio logging failed: {exc}")
def main(argv=None):
    """Parse the command line and dispatch to the selected training stage.

    Args:
        argv: Argument list without the program name. Defaults to None, in
            which case argparse reads ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="QMD Query Expansion Training",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
uv run train.py sft --config configs/sft.yaml
uv run train.py grpo --config configs/grpo.yaml
uv run train.py grpo --config configs/grpo.yaml --dry-run
""",
    )
    sub = parser.add_subparsers(dest="stage", required=True)

    sft_parser = sub.add_parser("sft", help="Supervised fine-tuning")
    sft_parser.add_argument("--config", required=True, help="Path to SFT config YAML")
    sft_parser.add_argument(
        "--dry-run", action="store_true", help="Print config and exit"
    )

    grpo_parser = sub.add_parser("grpo", help="GRPO reinforcement learning")
    grpo_parser.add_argument("--config", required=True, help="Path to GRPO config YAML")
    # The GRPO dry run only prints the config; it does not exercise the
    # reward function, so the help text must not promise that.
    grpo_parser.add_argument(
        "--dry-run", action="store_true", help="Print config and exit"
    )

    args = parser.parse_args(argv)
    # Subparsers are required, so exactly one of these stages is selected.
    handlers = {"sft": cmd_sft, "grpo": cmd_grpo}
    handlers[args.stage](args)


if __name__ == "__main__":
    main()
|