train_hf_job.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "unsloth",
  6. # "transformers>=4.45.0",
  7. # "datasets",
  8. # "trl>=0.12.0",
  9. # "torch",
  10. # "huggingface_hub",
  11. # ]
  12. # ///
  13. """
  14. Train QMD query expansion model using LoRA on HuggingFace Jobs.
  15. This script is designed to run on HuggingFace Jobs infrastructure.
  16. Uses Unsloth for efficient LoRA finetuning.
  17. Usage:
  18. # Local test
  19. python train_hf_job.py --model Qwen/Qwen3-0.6B --data data/train --dry-run
  20. # HuggingFace Jobs (via huggingface-skills)
  21. # See hugging-face-model-trainer skill for deployment
  22. """
  23. import argparse
  24. import os
  25. from pathlib import Path
  26. def main():
  27. parser = argparse.ArgumentParser(description="Train QMD query expansion model")
  28. parser.add_argument("--model", type=str, default="Qwen/Qwen3-0.6B", help="Base model")
  29. parser.add_argument("--data", type=str, default="data/train", help="Training data directory")
  30. parser.add_argument("--output", type=str, default="models/qmd-expansion", help="Output directory")
  31. parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
  32. parser.add_argument("--batch-size", type=int, default=4, help="Batch size")
  33. parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")
  34. parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank")
  35. parser.add_argument("--max-seq-length", type=int, default=512, help="Max sequence length")
  36. parser.add_argument("--dry-run", action="store_true", help="Print config and exit")
  37. parser.add_argument("--push-to-hub", type=str, help="Push to HuggingFace Hub repo")
  38. args = parser.parse_args()
  39. config = {
  40. "model": args.model,
  41. "data": args.data,
  42. "output": args.output,
  43. "epochs": args.epochs,
  44. "batch_size": args.batch_size,
  45. "learning_rate": args.lr,
  46. "lora_rank": args.lora_rank,
  47. "lora_alpha": args.lora_rank * 2,
  48. "max_seq_length": args.max_seq_length,
  49. }
  50. if args.dry_run:
  51. print("Training configuration:")
  52. for k, v in config.items():
  53. print(f" {k}: {v}")
  54. return
  55. # Import heavy dependencies only when needed
  56. from unsloth import FastLanguageModel
  57. from datasets import load_dataset
  58. from trl import SFTTrainer, SFTConfig
  59. import torch
  60. print(f"Loading base model: {args.model}")
  61. # Load model with Unsloth
  62. model, tokenizer = FastLanguageModel.from_pretrained(
  63. model_name=args.model,
  64. max_seq_length=args.max_seq_length,
  65. dtype=None, # Auto-detect
  66. load_in_4bit=True, # QLoRA
  67. )
  68. # Configure LoRA
  69. model = FastLanguageModel.get_peft_model(
  70. model,
  71. r=args.lora_rank,
  72. target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
  73. "gate_proj", "up_proj", "down_proj"],
  74. lora_alpha=args.lora_rank * 2,
  75. lora_dropout=0,
  76. bias="none",
  77. use_gradient_checkpointing="unsloth",
  78. random_state=42,
  79. )
  80. # Load dataset
  81. data_path = Path(args.data)
  82. if (data_path / "train_chat.jsonl").exists():
  83. dataset = load_dataset("json", data_files=str(data_path / "train_chat.jsonl"))["train"]
  84. print(f"Loaded {len(dataset)} training examples (chat format)")
  85. else:
  86. dataset = load_dataset("json", data_files=str(data_path / "train.jsonl"))["train"]
  87. print(f"Loaded {len(dataset)} training examples")
  88. # Format function for chat template
  89. def format_chat(example):
  90. messages = example.get("messages", [])
  91. if messages:
  92. text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
  93. else:
  94. text = example.get("text", "")
  95. return {"text": text}
  96. dataset = dataset.map(format_chat)
  97. # Training config
  98. output_dir = Path(args.output)
  99. output_dir.mkdir(parents=True, exist_ok=True)
  100. training_args = SFTConfig(
  101. output_dir=str(output_dir),
  102. num_train_epochs=args.epochs,
  103. per_device_train_batch_size=args.batch_size,
  104. gradient_accumulation_steps=4,
  105. learning_rate=args.lr,
  106. weight_decay=0.01,
  107. warmup_ratio=0.03,
  108. lr_scheduler_type="cosine",
  109. logging_steps=10,
  110. save_strategy="epoch",
  111. bf16=torch.cuda.is_bf16_supported(),
  112. fp16=not torch.cuda.is_bf16_supported(),
  113. optim="adamw_8bit",
  114. seed=42,
  115. max_seq_length=args.max_seq_length,
  116. dataset_text_field="text",
  117. packing=False,
  118. )
  119. # Create trainer
  120. trainer = SFTTrainer(
  121. model=model,
  122. tokenizer=tokenizer,
  123. train_dataset=dataset,
  124. args=training_args,
  125. )
  126. # Train
  127. print("Starting training...")
  128. trainer.train()
  129. # Save
  130. print(f"Saving model to {output_dir}")
  131. model.save_pretrained(output_dir)
  132. tokenizer.save_pretrained(output_dir)
  133. # Push to hub if requested
  134. if args.push_to_hub:
  135. print(f"Pushing to HuggingFace Hub: {args.push_to_hub}")
  136. model.push_to_hub(args.push_to_hub)
  137. tokenizer.push_to_hub(args.push_to_hub)
  138. print("Training complete!")
  139. if __name__ == "__main__":
  140. main()