#!/usr/bin/env python3
"""
QMD Query Expansion fine-tuning with Unsloth (Qwen3.5 support).

Usage:
    python train_unsloth.py --model 0.8B
    python train_unsloth.py --model 2B
    python train_unsloth.py --model 4B --epochs 3

Requires: pip install unsloth unsloth_zoo
"""

import argparse
import json
import sys
from pathlib import Path

# Maps the value accepted by --model to the base checkpoint name that is
# handed to FastLanguageModel.from_pretrained.
MODEL_MAP = {
    "0.8B": "unsloth/Qwen3.5-0.8B",
    "2B": "unsloth/Qwen3.5-2B",
    "4B": "unsloth/Qwen3.5-4B",
    "9B": "unsloth/Qwen3.5-9B",
    "27B": "unsloth/Qwen3.5-27B",
}

# GGUF quantization methods used for both the local export and the Hub push.
GGUF_QUANTS = ("q4_k_m", "q8_0")


def _build_parser():
    """Return the argument parser describing the script's command line."""
    parser = argparse.ArgumentParser(description="QMD fine-tuning with Unsloth")
    parser.add_argument(
        "--model",
        required=True,
        choices=list(MODEL_MAP.keys()),
        help="Model size to train",
    )
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--max-seq-len", type=int, default=512)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--data", type=str, default="data/train/train.jsonl")
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output directory (default: outputs/qwen3.5-{size})",
    )
    parser.add_argument(
        "--push-hub",
        type=str,
        default=None,
        help="Push to HF hub (e.g. tobil/qmd-query-expansion-qwen3.5-0.8B)",
    )
    parser.add_argument("--no-gguf", action="store_true")
    parser.add_argument("--no-eval", action="store_true")
    parser.add_argument("--dry-run", action="store_true")
    return parser


def _print_banner(args, model_name, output_dir):
    """Print the resolved run configuration before any heavy work starts."""
    print(f"{'='*60}")
    print("QMD Query Expansion — Unsloth SFT")
    print(f" Base model: {model_name}")
    print(f" Output: {output_dir}")
    print(f" Data: {args.data}")
    print(f" Epochs: {args.epochs}")
    print(f" Batch: {args.batch_size} x {args.grad_accum} accum")
    print(f" LR: {args.lr}")
    print(f" LoRA rank: {args.lora_rank}")
    print(f" Max seq len: {args.max_seq_len}")
    print(f"{'='*60}")


def main(argv=None):
    """Run the fine-tuning pipeline selected on the command line.

    The steps are: parse arguments, print the configuration, then (unless
    ``--dry-run`` is given) load the base model, attach LoRA adapters, train
    with ``SFTTrainer``, save the result, optionally export/push GGUF files,
    and optionally launch ``eval.py`` on the output directory.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None``, ``sys.argv[1:]`` is used, which is the
            behaviour of calling ``main()`` with no arguments.

    Returns:
        None.
    """
    args = _build_parser().parse_args(argv)

    model_name = MODEL_MAP[args.model]
    output_dir = args.output or f"outputs/qwen3.5-{args.model}"

    _print_banner(args, model_name, output_dir)

    if args.dry_run:
        print("Dry run — exiting.")
        return

    # --- Imports (heavy) ---
    # Deferred until after the dry-run check so that argument validation and
    # --dry-run work without the training dependencies being imported.
    import os

    import torch
    from unsloth import FastLanguageModel
    from datasets import load_dataset
    from trl import SFTTrainer, SFTConfig

    # --- Load model ---
    print(f"\nLoading {model_name}...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=args.max_seq_len,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
    )

    # --- LoRA ---
    # lora_alpha is tied to the rank, so alpha / r stays at 1 for any
    # --lora-rank value.
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=args.lora_rank,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        max_seq_length=args.max_seq_len,
    )

    # --- Dataset ---
    print(f"Loading dataset from {args.data}...")
    dataset = load_dataset("json", data_files=args.data, split="train")
    dataset = dataset.shuffle(seed=42)
    # Fixed seed keeps the 90/10 train/eval split reproducible across runs.
    split = dataset.train_test_split(test_size=0.1, seed=42)
    train_ds = split["train"]
    eval_ds = split["test"]
    print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    # --- Tracking ---
    # Tracking is opt-in: it is enabled only when HF_TOKEN is set and the
    # trackio package can be imported; otherwise reporting stays disabled.
    report_to = "none"
    if os.environ.get("HF_TOKEN"):
        try:
            import trackio
        except ImportError:
            pass
        else:
            report_to = "trackio"
            os.environ.setdefault("TRACKIO_PROJECT", "qmd-query-expansion")

    # --- Trainer ---
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        args=SFTConfig(
            output_dir=output_dir,
            max_seq_length=args.max_seq_len,
            num_train_epochs=args.epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            learning_rate=args.lr,
            warmup_ratio=0.03,
            lr_scheduler_type="cosine",
            logging_steps=10,
            save_strategy="steps",
            save_steps=200,
            save_total_limit=3,
            eval_strategy="steps",
            eval_steps=200,
            bf16=True,
            optim="adamw_8bit",
            seed=3407,
            dataset_num_proc=4,
            report_to=report_to,
            run_name=f"sft-qwen3.5-{args.model}",
        ),
    )

    print("\nStarting training...")
    stats = trainer.train()
    print("\nTraining complete!")
    print(f" Total steps: {stats.global_step}")
    print(f" Final loss: {stats.training_loss:.4f}")

    # --- Save ---
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Adapter saved to {output_dir}")

    # --- GGUF export ---
    # Best-effort: a failed quantization is reported and the run continues.
    if not args.no_gguf:
        print("\nExporting GGUF quantizations...")
        gguf_dir = f"{output_dir}/gguf"
        for quant in GGUF_QUANTS:
            print(f" {quant}...")
            try:
                model.save_pretrained_gguf(
                    gguf_dir, tokenizer, quantization_method=quant
                )
                print(f" ✓ {quant} saved")
            except Exception as e:
                print(f" ✗ {quant} failed: {e}")

    # --- Push to Hub ---
    if args.push_hub:
        print(f"\nPushing to {args.push_hub}...")
        model.push_to_hub_merged(args.push_hub, tokenizer, save_method="lora")
        if not args.no_gguf:
            for quant in GGUF_QUANTS:
                try:
                    model.push_to_hub_gguf(
                        args.push_hub, tokenizer, quantization_method=quant
                    )
                except Exception as e:
                    print(f" GGUF push {quant} failed: {e}")

    # --- Eval ---
    if not args.no_eval:
        print("\nRunning evaluation...")
        import subprocess

        # eval.py runs with its working directory switched to this script's
        # directory, so a relative output_dir (resolved against the caller's
        # working directory during training) must be made absolute first or
        # the child process would look in the wrong place.
        eval_target = str(Path(output_dir).resolve())
        completed = subprocess.run(
            [sys.executable, "eval.py", eval_target],
            cwd=str(Path(__file__).parent),
        )
        # Evaluation stays best-effort, but a failure is no longer silent.
        if completed.returncode != 0:
            print(f"Evaluation exited with status {completed.returncode}")

    print(f"\n{'='*60}")
    print(f"Done! Model at: {output_dir}")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()