train_unsloth.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. #!/usr/bin/env python3
"""
QMD Query Expansion fine-tuning with Unsloth (Qwen3.5 support).

Usage:
    python train_unsloth.py --model 0.8B
    python train_unsloth.py --model 2B
    python train_unsloth.py --model 4B --epochs 3

Requires: pip install unsloth unsloth_zoo
"""
  10. import argparse
  11. import json
  12. import sys
  13. from pathlib import Path
  14. MODEL_MAP = {
  15. "0.8B": "unsloth/Qwen3.5-0.8B",
  16. "2B": "unsloth/Qwen3.5-2B",
  17. "4B": "unsloth/Qwen3.5-4B",
  18. "9B": "unsloth/Qwen3.5-9B",
  19. "27B": "unsloth/Qwen3.5-27B",
  20. }
  21. def main():
  22. parser = argparse.ArgumentParser(description="QMD fine-tuning with Unsloth")
  23. parser.add_argument("--model", required=True, choices=list(MODEL_MAP.keys()),
  24. help="Model size to train")
  25. parser.add_argument("--epochs", type=int, default=5)
  26. parser.add_argument("--batch-size", type=int, default=4)
  27. parser.add_argument("--grad-accum", type=int, default=4)
  28. parser.add_argument("--lr", type=float, default=2e-4)
  29. parser.add_argument("--max-seq-len", type=int, default=512)
  30. parser.add_argument("--lora-rank", type=int, default=16)
  31. parser.add_argument("--data", type=str, default="data/train/train.jsonl")
  32. parser.add_argument("--output", type=str, default=None,
  33. help="Output directory (default: outputs/qwen3.5-{size})")
  34. parser.add_argument("--push-hub", type=str, default=None,
  35. help="Push to HF hub (e.g. tobil/qmd-query-expansion-qwen3.5-0.8B)")
  36. parser.add_argument("--no-gguf", action="store_true")
  37. parser.add_argument("--no-eval", action="store_true")
  38. parser.add_argument("--dry-run", action="store_true")
  39. args = parser.parse_args()
  40. model_name = MODEL_MAP[args.model]
  41. output_dir = args.output or f"outputs/qwen3.5-{args.model}"
  42. print(f"{'='*60}")
  43. print(f"QMD Query Expansion — Unsloth SFT")
  44. print(f" Base model: {model_name}")
  45. print(f" Output: {output_dir}")
  46. print(f" Data: {args.data}")
  47. print(f" Epochs: {args.epochs}")
  48. print(f" Batch: {args.batch_size} x {args.grad_accum} accum")
  49. print(f" LR: {args.lr}")
  50. print(f" LoRA rank: {args.lora_rank}")
  51. print(f" Max seq len: {args.max_seq_len}")
  52. print(f"{'='*60}")
  53. if args.dry_run:
  54. print("Dry run — exiting.")
  55. return
  56. # --- Imports (heavy) ---
  57. import os
  58. import torch
  59. from unsloth import FastLanguageModel
  60. from datasets import load_dataset
  61. from trl import SFTTrainer, SFTConfig
  62. # --- Load model ---
  63. print(f"\nLoading {model_name}...")
  64. model, tokenizer = FastLanguageModel.from_pretrained(
  65. model_name=model_name,
  66. max_seq_length=args.max_seq_len,
  67. load_in_4bit=False,
  68. load_in_16bit=True,
  69. full_finetuning=False,
  70. )
  71. # --- LoRA ---
  72. model = FastLanguageModel.get_peft_model(
  73. model,
  74. r=args.lora_rank,
  75. target_modules=[
  76. "q_proj", "k_proj", "v_proj", "o_proj",
  77. "gate_proj", "up_proj", "down_proj",
  78. ],
  79. lora_alpha=args.lora_rank,
  80. lora_dropout=0,
  81. bias="none",
  82. use_gradient_checkpointing="unsloth",
  83. random_state=3407,
  84. max_seq_length=args.max_seq_len,
  85. )
  86. # --- Dataset ---
  87. print(f"Loading dataset from {args.data}...")
  88. dataset = load_dataset("json", data_files=args.data, split="train")
  89. dataset = dataset.shuffle(seed=42)
  90. split = dataset.train_test_split(test_size=0.1, seed=42)
  91. train_ds = split["train"]
  92. eval_ds = split["test"]
  93. print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")
  94. # --- Tracking ---
  95. report_to = "none"
  96. if os.environ.get("HF_TOKEN"):
  97. try:
  98. import trackio
  99. report_to = "trackio"
  100. os.environ.setdefault("TRACKIO_PROJECT", "qmd-query-expansion")
  101. except ImportError:
  102. pass
  103. # --- Trainer ---
  104. trainer = SFTTrainer(
  105. model=model,
  106. tokenizer=tokenizer,
  107. train_dataset=train_ds,
  108. eval_dataset=eval_ds,
  109. args=SFTConfig(
  110. output_dir=output_dir,
  111. max_seq_length=args.max_seq_len,
  112. num_train_epochs=args.epochs,
  113. per_device_train_batch_size=args.batch_size,
  114. gradient_accumulation_steps=args.grad_accum,
  115. learning_rate=args.lr,
  116. warmup_ratio=0.03,
  117. lr_scheduler_type="cosine",
  118. logging_steps=10,
  119. save_strategy="steps",
  120. save_steps=200,
  121. save_total_limit=3,
  122. eval_strategy="steps",
  123. eval_steps=200,
  124. bf16=True,
  125. optim="adamw_8bit",
  126. seed=3407,
  127. dataset_num_proc=4,
  128. report_to=report_to,
  129. run_name=f"sft-qwen3.5-{args.model}",
  130. ),
  131. )
  132. print("\nStarting training...")
  133. stats = trainer.train()
  134. print(f"\nTraining complete!")
  135. print(f" Total steps: {stats.global_step}")
  136. print(f" Final loss: {stats.training_loss:.4f}")
  137. # --- Save ---
  138. trainer.save_model(output_dir)
  139. tokenizer.save_pretrained(output_dir)
  140. print(f"Adapter saved to {output_dir}")
  141. # --- GGUF export ---
  142. if not args.no_gguf:
  143. print("\nExporting GGUF quantizations...")
  144. gguf_dir = f"{output_dir}/gguf"
  145. for quant in ["q4_k_m", "q8_0"]:
  146. print(f" {quant}...")
  147. try:
  148. model.save_pretrained_gguf(
  149. gguf_dir, tokenizer, quantization_method=quant
  150. )
  151. print(f" ✓ {quant} saved")
  152. except Exception as e:
  153. print(f" ✗ {quant} failed: {e}")
  154. # --- Push to Hub ---
  155. if args.push_hub:
  156. print(f"\nPushing to {args.push_hub}...")
  157. model.push_to_hub_merged(args.push_hub, tokenizer, save_method="lora")
  158. if not args.no_gguf:
  159. for quant in ["q4_k_m", "q8_0"]:
  160. try:
  161. model.push_to_hub_gguf(args.push_hub, tokenizer, quantization_method=quant)
  162. except Exception as e:
  163. print(f" GGUF push {quant} failed: {e}")
  164. # --- Eval ---
  165. if not args.no_eval:
  166. print("\nRunning evaluation...")
  167. import subprocess
  168. subprocess.run(
  169. [sys.executable, "eval.py", output_dir],
  170. cwd=str(Path(__file__).parent),
  171. )
  172. print(f"\n{'='*60}")
  173. print(f"Done! Model at: {output_dir}")
  174. print(f"{'='*60}")
  175. if __name__ == "__main__":
  176. main()