| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- #!/usr/bin/env python3
- """
- QMD Query Expansion fine-tuning with Unsloth (Qwen3.5 support).
- Usage:
- python train_unsloth.py --model 0.8B
- python train_unsloth.py --model 2B
- python train_unsloth.py --model 4B --epochs 3
- Requires: pip install unsloth unsloth_zoo
- """
- import argparse
- import json
- import sys
- from pathlib import Path
# Base-model identifier for each accepted --model size. Every size follows the
# same "unsloth/Qwen3.5-<size>" naming scheme, so the table is generated from
# the ordered tuple of sizes (dict order = order shown in --model choices).
MODEL_MAP = {
    size: f"unsloth/Qwen3.5-{size}"
    for size in ("0.8B", "2B", "4B", "9B", "27B")
}
# GGUF quantizations produced by the export step and pushed to the hub.
_GGUF_QUANTS = ("q4_k_m", "q8_0")


def _parse_args(argv=None):
    """Parse the command-line options for one fine-tuning run.

    Args:
        argv: Argument list without the program name; ``None`` means
            ``sys.argv[1:]`` (the argparse default).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="QMD fine-tuning with Unsloth")
    parser.add_argument("--model", required=True, choices=list(MODEL_MAP.keys()),
                        help="Model size to train")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--grad-accum", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--max-seq-len", type=int, default=512)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--data", type=str, default="data/train/train.jsonl")
    parser.add_argument("--output", type=str, default=None,
                        help="Output directory (default: outputs/qwen3.5-{size})")
    parser.add_argument("--push-hub", type=str, default=None,
                        help="Push to HF hub (e.g. tobil/qmd-query-expansion-qwen3.5-0.8B)")
    parser.add_argument("--no-gguf", action="store_true")
    parser.add_argument("--no-eval", action="store_true")
    parser.add_argument("--dry-run", action="store_true")
    return parser.parse_args(argv)


def _print_banner(model_name, output_dir, args):
    """Print the resolved run configuration between two 60-character rules."""
    rule = "=" * 60
    print(rule)
    print("QMD Query Expansion — Unsloth SFT")
    print(f" Base model: {model_name}")
    print(f" Output: {output_dir}")
    print(f" Data: {args.data}")
    print(f" Epochs: {args.epochs}")
    print(f" Batch: {args.batch_size} x {args.grad_accum} accum")
    print(f" LR: {args.lr}")
    print(f" LoRA rank: {args.lora_rank}")
    print(f" Max seq len: {args.max_seq_len}")
    print(rule)


def _select_report_to():
    """Return "trackio" when HF_TOKEN is set and trackio is importable, else "none"."""
    import os

    if not os.environ.get("HF_TOKEN"):
        return "none"
    try:
        import trackio  # noqa: F401  -- imported only to confirm it is installed
    except ImportError:
        return "none"
    os.environ.setdefault("TRACKIO_PROJECT", "qmd-query-expansion")
    return "trackio"


def _export_gguf(model, tokenizer, output_dir, quants):
    """Save one GGUF export per quantization under ``<output_dir>/gguf``.

    Export is best-effort: a failing quantization is reported and the loop
    continues with the next one.
    """
    print("\nExporting GGUF quantizations...")
    gguf_dir = f"{output_dir}/gguf"
    for quant in quants:
        print(f" {quant}...")
        try:
            model.save_pretrained_gguf(
                gguf_dir, tokenizer, quantization_method=quant
            )
            print(f" ✓ {quant} saved")
        except Exception as e:  # best-effort export; keep going on failure
            print(f" ✗ {quant} failed: {e}")


def _push_to_hub(model, tokenizer, repo_id, gguf_quants):
    """Push the LoRA adapter to ``repo_id``, then each GGUF quantization.

    The adapter push is allowed to raise; individual GGUF pushes are
    best-effort and only reported on failure.
    """
    print(f"\nPushing to {repo_id}...")
    model.push_to_hub_merged(repo_id, tokenizer, save_method="lora")
    for quant in gguf_quants:
        try:
            model.push_to_hub_gguf(repo_id, tokenizer, quantization_method=quant)
        except Exception as e:  # best-effort push; keep going on failure
            print(f" GGUF push {quant} failed: {e}")


def _run_eval(output_dir):
    """Run ``eval.py`` (next to this script) on ``output_dir``; report a non-zero exit."""
    import subprocess

    print("\nRunning evaluation...")
    # eval.py is launched with cwd = this script's directory, so a relative
    # output_dir (the default) would be resolved against the wrong base.
    # Pass an absolute path so the caller's working directory is honored.
    model_path = str(Path(output_dir).resolve())
    result = subprocess.run(
        [sys.executable, "eval.py", model_path],
        cwd=str(Path(__file__).parent),
    )
    if result.returncode != 0:
        print(f" ✗ eval.py exited with code {result.returncode}")


def main(argv=None):
    """Fine-tune a Qwen3.5 LoRA adapter for QMD query expansion.

    Trains on ``--data``, saves the adapter and tokenizer to the output
    directory, then (unless disabled by flags) exports GGUF files, pushes to
    the Hugging Face hub, and runs ``eval.py`` on the result.

    Args:
        argv: Optional argument list without the program name; ``None``
            parses ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    model_name = MODEL_MAP[args.model]
    output_dir = args.output or f"outputs/qwen3.5-{args.model}"
    _print_banner(model_name, output_dir, args)
    if args.dry_run:
        print("Dry run — exiting.")
        return

    # --- Imports (heavy) --- deferred so --dry-run / --help stay fast.
    import torch
    from unsloth import FastLanguageModel
    from datasets import load_dataset
    from trl import SFTTrainer, SFTConfig

    # --- Load model ---
    print(f"\nLoading {model_name}...")
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=args.max_seq_len,
        load_in_4bit=False,
        load_in_16bit=True,
        full_finetuning=False,
    )

    # --- LoRA ---
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=args.lora_rank,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        max_seq_length=args.max_seq_len,
    )

    # --- Dataset --- 90/10 train/eval split with a fixed seed.
    print(f"Loading dataset from {args.data}...")
    dataset = load_dataset("json", data_files=args.data, split="train")
    dataset = dataset.shuffle(seed=42)
    split = dataset.train_test_split(test_size=0.1, seed=42)
    train_ds = split["train"]
    eval_ds = split["test"]
    print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    # --- Precision --- bf16 when the GPU supports it, otherwise fall back to
    # fp16 instead of failing at trainer construction with bf16=True.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    # --- Trainer ---
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        args=SFTConfig(
            output_dir=output_dir,
            max_seq_length=args.max_seq_len,
            num_train_epochs=args.epochs,
            per_device_train_batch_size=args.batch_size,
            gradient_accumulation_steps=args.grad_accum,
            learning_rate=args.lr,
            warmup_ratio=0.03,
            lr_scheduler_type="cosine",
            logging_steps=10,
            save_strategy="steps",
            save_steps=200,
            save_total_limit=3,
            eval_strategy="steps",
            eval_steps=200,
            bf16=use_bf16,
            fp16=not use_bf16,
            optim="adamw_8bit",
            seed=3407,
            dataset_num_proc=4,
            report_to=_select_report_to(),
            run_name=f"sft-qwen3.5-{args.model}",
        ),
    )
    print("\nStarting training...")
    stats = trainer.train()
    print("\nTraining complete!")
    print(f" Total steps: {stats.global_step}")
    print(f" Final loss: {stats.training_loss:.4f}")

    # --- Save ---
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Adapter saved to {output_dir}")

    gguf_quants = () if args.no_gguf else _GGUF_QUANTS
    if gguf_quants:
        _export_gguf(model, tokenizer, output_dir, gguf_quants)
    if args.push_hub:
        _push_to_hub(model, tokenizer, args.push_hub, gguf_quants)
    if not args.no_eval:
        _run_eval(output_dir)

    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Done! Model at: {output_dir}")
    print(rule)


if __name__ == "__main__":
    main()
|