Bladeren bron

Add 1.7B and 4B GRPO training and GGUF conversion scripts

Training scripts for GRPO fine-tuning:
- train_1.7B_grpo.py: GRPO training for Qwen3-1.7B
- train_4B_grpo.py: GRPO training for Qwen3-4B

GGUF conversion scripts:
- convert_1.7B_gguf.py: Merge SFT+GRPO adapters and convert to GGUF
- convert_4B_gguf.py: Merge SFT+GRPO adapters and convert to GGUF

All scripts use PEP 723 inline dependencies for HuggingFace Jobs.

Models published:
- tobil/qmd-query-expansion-1.7B-sft
- tobil/qmd-query-expansion-1.7B-grpo
- tobil/qmd-query-expansion-1.7B-gguf
- tobil/qmd-query-expansion-4B-sft
- tobil/qmd-query-expansion-4B-grpo
- tobil/qmd-query-expansion-4B-gguf

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 4 maanden geleden
bovenliggende
commit
8a1c4cdab0
4 gewijzigde bestanden met toevoegingen van 1368 en 0 verwijderingen
  1. 282 0
      finetune/convert_1.7B_gguf.py
  2. 282 0
      finetune/convert_4B_gguf.py
  3. 402 0
      finetune/train_1.7B_grpo.py
  4. 402 0
      finetune/train_4B_grpo.py

+ 282 - 0
finetune/convert_1.7B_gguf.py

@@ -0,0 +1,282 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "transformers>=4.36.0",
+#     "peft>=0.7.0",
+#     "torch>=2.0.0",
+#     "accelerate>=0.24.0",
+#     "huggingface_hub>=0.20.0",
+#     "sentencepiece>=0.1.99",
+#     "protobuf>=3.20.0",
+#     "numpy",
+#     "gguf",
+# ]
+# ///
+"""
+GGUF Conversion for QMD Query Expansion 1.7B Model
+
+Loads base model, applies SFT adapter, then GRPO adapter, merges all,
+and converts to GGUF format for use with Ollama/llama.cpp/LM Studio.
+"""
+
+import os
+import sys
+import subprocess
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
+from huggingface_hub import HfApi, login
+
+# Configuration
+BASE_MODEL = "Qwen/Qwen3-1.7B"
+SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
+GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
+OUTPUT_REPO = "tobil/qmd-query-expansion-1.7B-gguf"
+
+def run_command(cmd, description):
+    """Run a command with error handling."""
+    print(f"   {description}...")
+    try:
+        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
+        return True
+    except subprocess.CalledProcessError as e:
+        print(f"   ❌ Command failed: {' '.join(cmd)}")
+        if e.stderr:
+            print(f"   STDERR: {e.stderr[:500]}")
+        return False
+    except FileNotFoundError:
+        print(f"   ❌ Command not found: {cmd[0]}")
+        return False
+
+
+print("🔄 QMD Query Expansion 1.7B GGUF Conversion")
+print("=" * 60)
+
+# Install build tools
+print("\n📦 Installing build dependencies...")
+subprocess.run(["apt-get", "update", "-qq"], capture_output=True)
+subprocess.run(["apt-get", "install", "-y", "-qq", "build-essential", "cmake", "git"], capture_output=True)
+print("   ✅ Build tools ready")
+
+# Login to HuggingFace
+hf_token = os.environ.get("HF_TOKEN")
+if hf_token:
+    print("\n🔐 Logging in to HuggingFace...")
+    login(token=hf_token)
+    print("   ✅ Logged in")
+
+# Step 1: Load base model
+print(f"\n🔧 Step 1: Loading base model {BASE_MODEL}...")
+base_model = AutoModelForCausalLM.from_pretrained(
+    BASE_MODEL,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    trust_remote_code=True,
+)
+print("   ✅ Base model loaded")
+
+# Step 2: Load and merge SFT adapter
+print(f"\n🔧 Step 2: Loading SFT adapter {SFT_MODEL}...")
+model = PeftModel.from_pretrained(base_model, SFT_MODEL)
+print("   Merging SFT adapter...")
+model = model.merge_and_unload()
+print("   ✅ SFT merged")
+
+# Step 3: Load and merge GRPO adapter
+print(f"\n🔧 Step 3: Loading GRPO adapter {GRPO_MODEL}...")
+model = PeftModel.from_pretrained(model, GRPO_MODEL)
+print("   Merging GRPO adapter...")
+merged_model = model.merge_and_unload()
+print("   ✅ GRPO merged - final model ready")
+
+# Load tokenizer
+print("\n📝 Loading tokenizer...")
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
+print("   ✅ Tokenizer loaded")
+
+# Step 4: Save merged model
+print("\n💾 Step 4: Saving merged model to disk...")
+merged_dir = "/tmp/merged_model"
+merged_model.save_pretrained(merged_dir, safe_serialization=True)
+tokenizer.save_pretrained(merged_dir)
+print(f"   ✅ Saved to {merged_dir}")
+
+# Step 5: Setup llama.cpp
+print("\n📥 Step 5: Setting up llama.cpp...")
+if not os.path.exists("/tmp/llama.cpp"):
+    run_command(
+        ["git", "clone", "--depth", "1", "https://github.com/ggerganov/llama.cpp.git", "/tmp/llama.cpp"],
+        "Cloning llama.cpp"
+    )
+
+# Install Python deps
+subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-r", "/tmp/llama.cpp/requirements.txt"], capture_output=True)
+subprocess.run([sys.executable, "-m", "pip", "install", "-q", "sentencepiece", "protobuf"], capture_output=True)
+print("   ✅ llama.cpp ready")
+
+# Step 6: Convert to GGUF (FP16)
+print("\n🔄 Step 6: Converting to GGUF format (FP16)...")
+gguf_output_dir = "/tmp/gguf_output"
+os.makedirs(gguf_output_dir, exist_ok=True)
+
+model_name = "qmd-query-expansion-1.7B"
+gguf_file = f"{gguf_output_dir}/{model_name}-f16.gguf"
+
+convert_script = "/tmp/llama.cpp/convert_hf_to_gguf.py"
+if not run_command(
+    [sys.executable, convert_script, merged_dir, "--outfile", gguf_file, "--outtype", "f16"],
+    "Converting to FP16 GGUF"
+):
+    print("   ❌ Conversion failed!")
+    sys.exit(1)
+
+size_mb = os.path.getsize(gguf_file) / (1024 * 1024)
+print(f"   ✅ FP16 GGUF created: {size_mb:.1f} MB")
+
+# Step 7: Build quantize tool
+print("\n⚙️  Step 7: Building quantize tool...")
+os.makedirs("/tmp/llama.cpp/build", exist_ok=True)
+
+run_command(
+    ["cmake", "-B", "/tmp/llama.cpp/build", "-S", "/tmp/llama.cpp", "-DGGML_CUDA=OFF"],
+    "Configuring with CMake"
+)
+run_command(
+    ["cmake", "--build", "/tmp/llama.cpp/build", "--target", "llama-quantize", "-j", "4"],
+    "Building llama-quantize"
+)
+
+quantize_bin = "/tmp/llama.cpp/build/bin/llama-quantize"
+print("   ✅ Quantize tool built")
+
+# Step 8: Create quantized versions
+print("\n⚙️  Step 8: Creating quantized versions...")
+quant_formats = [
+    ("Q4_K_M", "4-bit medium (recommended)"),
+    ("Q5_K_M", "5-bit medium"),
+    ("Q8_0", "8-bit"),
+]
+
+quantized_files = []
+for quant_type, description in quant_formats:
+    print(f"   Creating {quant_type} ({description})...")
+    quant_file = f"{gguf_output_dir}/{model_name}-{quant_type.lower()}.gguf"
+
+    if run_command([quantize_bin, gguf_file, quant_file, quant_type], f"Quantizing to {quant_type}"):
+        size_mb = os.path.getsize(quant_file) / (1024 * 1024)
+        print(f"   ✅ {quant_type}: {size_mb:.1f} MB")
+        quantized_files.append((quant_file, quant_type))
+    else:
+        print(f"   ⚠️  Skipping {quant_type}")
+
+# Step 9: Upload to Hub
+print("\n☁️  Step 9: Uploading to Hugging Face Hub...")
+api = HfApi()
+
+print(f"   Creating repository: {OUTPUT_REPO}")
+api.create_repo(repo_id=OUTPUT_REPO, repo_type="model", exist_ok=True)
+
+# Upload F16
+print("   Uploading FP16...")
+api.upload_file(
+    path_or_fileobj=gguf_file,
+    path_in_repo=f"{model_name}-f16.gguf",
+    repo_id=OUTPUT_REPO,
+)
+print("   ✅ FP16 uploaded")
+
+# Upload quantized versions
+for quant_file, quant_type in quantized_files:
+    print(f"   Uploading {quant_type}...")
+    api.upload_file(
+        path_or_fileobj=quant_file,
+        path_in_repo=f"{model_name}-{quant_type.lower()}.gguf",
+        repo_id=OUTPUT_REPO,
+    )
+    print(f"   ✅ {quant_type} uploaded")
+
+# Create README
+print("\n📝 Creating README...")
+readme_content = f"""---
+base_model: {BASE_MODEL}
+tags:
+- gguf
+- llama.cpp
+- quantized
+- query-expansion
+- qmd
+---
+
+# QMD Query Expansion 1.7B (GGUF)
+
+GGUF conversion of the QMD Query Expansion model for use with Ollama, llama.cpp, and LM Studio.
+
+## Model Details
+
+- **Base Model:** {BASE_MODEL}
+- **SFT Adapter:** {SFT_MODEL}
+- **GRPO Adapter:** {GRPO_MODEL}
+- **Task:** Query expansion for hybrid search (lex/vec/hyde format)
+
+## Available Quantizations
+
+| File | Quant | Description |
+|------|-------|-------------|
+| {model_name}-f16.gguf | F16 | Full precision |
+| {model_name}-q8_0.gguf | Q8_0 | 8-bit |
+| {model_name}-q5_k_m.gguf | Q5_K_M | 5-bit medium |
+| {model_name}-q4_k_m.gguf | Q4_K_M | 4-bit medium (recommended) |
+
+## Usage
+
+### With Ollama
+
+```bash
+# Download
+huggingface-cli download {OUTPUT_REPO} {model_name}-q4_k_m.gguf --local-dir .
+
+# Create Modelfile
+echo 'FROM ./{model_name}-q4_k_m.gguf' > Modelfile
+
+# Create and run
+ollama create qmd-expand -f Modelfile
+ollama run qmd-expand
+```
+
+### Prompt Format
+
+Use Qwen3 chat format with `/no_think`:
+
+```
+<|im_start|>user
+/no_think Expand this search query: your query here<|im_end|>
+<|im_start|>assistant
+```
+
+### Expected Output
+
+```
+lex: keyword variation 1
+lex: keyword variation 2
+vec: natural language reformulation
+hyde: Hypothetical document passage answering the query.
+```
+
+## License
+
+Apache 2.0 (inherited from Qwen3)
+"""
+
+api.upload_file(
+    path_or_fileobj=readme_content.encode(),
+    path_in_repo="README.md",
+    repo_id=OUTPUT_REPO,
+)
+print("   ✅ README uploaded")
+
+print("\n" + "=" * 60)
+print("✅ GGUF Conversion Complete!")
+print(f"📦 Repository: https://huggingface.co/{OUTPUT_REPO}")
+print("=" * 60)

+ 282 - 0
finetune/convert_4B_gguf.py

@@ -0,0 +1,282 @@
+#!/usr/bin/env python3
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "transformers>=4.36.0",
+#     "peft>=0.7.0",
+#     "torch>=2.0.0",
+#     "accelerate>=0.24.0",
+#     "huggingface_hub>=0.20.0",
+#     "sentencepiece>=0.1.99",
+#     "protobuf>=3.20.0",
+#     "numpy",
+#     "gguf",
+# ]
+# ///
+"""
+GGUF Conversion for QMD Query Expansion 4B Model
+
+Loads base model, applies SFT adapter, then GRPO adapter, merges all,
+and converts to GGUF format for use with Ollama/llama.cpp/LM Studio.
+"""
+
+import os
+import sys
+import subprocess
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from peft import PeftModel
+from huggingface_hub import HfApi, login
+
+# Configuration
+BASE_MODEL = "Qwen/Qwen3-4B"
+SFT_MODEL = "tobil/qmd-query-expansion-4B-sft"
+GRPO_MODEL = "tobil/qmd-query-expansion-4B-grpo"
+OUTPUT_REPO = "tobil/qmd-query-expansion-4B-gguf"
+
+def run_command(cmd, description):
+    """Run a command with error handling."""
+    print(f"   {description}...")
+    try:
+        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
+        return True
+    except subprocess.CalledProcessError as e:
+        print(f"   ❌ Command failed: {' '.join(cmd)}")
+        if e.stderr:
+            print(f"   STDERR: {e.stderr[:500]}")
+        return False
+    except FileNotFoundError:
+        print(f"   ❌ Command not found: {cmd[0]}")
+        return False
+
+
+print("🔄 QMD Query Expansion 4B GGUF Conversion")
+print("=" * 60)
+
+# Install build tools
+print("\n📦 Installing build dependencies...")
+subprocess.run(["apt-get", "update", "-qq"], capture_output=True)
+subprocess.run(["apt-get", "install", "-y", "-qq", "build-essential", "cmake", "git"], capture_output=True)
+print("   ✅ Build tools ready")
+
+# Login to HuggingFace
+hf_token = os.environ.get("HF_TOKEN")
+if hf_token:
+    print("\n🔐 Logging in to HuggingFace...")
+    login(token=hf_token)
+    print("   ✅ Logged in")
+
+# Step 1: Load base model
+print(f"\n🔧 Step 1: Loading base model {BASE_MODEL}...")
+base_model = AutoModelForCausalLM.from_pretrained(
+    BASE_MODEL,
+    torch_dtype=torch.bfloat16,
+    device_map="auto",
+    trust_remote_code=True,
+)
+print("   ✅ Base model loaded")
+
+# Step 2: Load and merge SFT adapter
+print(f"\n🔧 Step 2: Loading SFT adapter {SFT_MODEL}...")
+model = PeftModel.from_pretrained(base_model, SFT_MODEL)
+print("   Merging SFT adapter...")
+model = model.merge_and_unload()
+print("   ✅ SFT merged")
+
+# Step 3: Load and merge GRPO adapter
+print(f"\n🔧 Step 3: Loading GRPO adapter {GRPO_MODEL}...")
+model = PeftModel.from_pretrained(model, GRPO_MODEL)
+print("   Merging GRPO adapter...")
+merged_model = model.merge_and_unload()
+print("   ✅ GRPO merged - final model ready")
+
+# Load tokenizer
+print("\n📝 Loading tokenizer...")
+tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
+print("   ✅ Tokenizer loaded")
+
+# Step 4: Save merged model
+print("\n💾 Step 4: Saving merged model to disk...")
+merged_dir = "/tmp/merged_model"
+merged_model.save_pretrained(merged_dir, safe_serialization=True)
+tokenizer.save_pretrained(merged_dir)
+print(f"   ✅ Saved to {merged_dir}")
+
+# Step 5: Setup llama.cpp
+print("\n📥 Step 5: Setting up llama.cpp...")
+if not os.path.exists("/tmp/llama.cpp"):
+    run_command(
+        ["git", "clone", "--depth", "1", "https://github.com/ggerganov/llama.cpp.git", "/tmp/llama.cpp"],
+        "Cloning llama.cpp"
+    )
+
+# Install Python deps
+subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-r", "/tmp/llama.cpp/requirements.txt"], capture_output=True)
+subprocess.run([sys.executable, "-m", "pip", "install", "-q", "sentencepiece", "protobuf"], capture_output=True)
+print("   ✅ llama.cpp ready")
+
+# Step 6: Convert to GGUF (FP16)
+print("\n🔄 Step 6: Converting to GGUF format (FP16)...")
+gguf_output_dir = "/tmp/gguf_output"
+os.makedirs(gguf_output_dir, exist_ok=True)
+
+model_name = "qmd-query-expansion-4B"
+gguf_file = f"{gguf_output_dir}/{model_name}-f16.gguf"
+
+convert_script = "/tmp/llama.cpp/convert_hf_to_gguf.py"
+if not run_command(
+    [sys.executable, convert_script, merged_dir, "--outfile", gguf_file, "--outtype", "f16"],
+    "Converting to FP16 GGUF"
+):
+    print("   ❌ Conversion failed!")
+    sys.exit(1)
+
+size_mb = os.path.getsize(gguf_file) / (1024 * 1024)
+print(f"   ✅ FP16 GGUF created: {size_mb:.1f} MB")
+
+# Step 7: Build quantize tool
+print("\n⚙️  Step 7: Building quantize tool...")
+os.makedirs("/tmp/llama.cpp/build", exist_ok=True)
+
+run_command(
+    ["cmake", "-B", "/tmp/llama.cpp/build", "-S", "/tmp/llama.cpp", "-DGGML_CUDA=OFF"],
+    "Configuring with CMake"
+)
+run_command(
+    ["cmake", "--build", "/tmp/llama.cpp/build", "--target", "llama-quantize", "-j", "4"],
+    "Building llama-quantize"
+)
+
+quantize_bin = "/tmp/llama.cpp/build/bin/llama-quantize"
+print("   ✅ Quantize tool built")
+
+# Step 8: Create quantized versions
+print("\n⚙️  Step 8: Creating quantized versions...")
+quant_formats = [
+    ("Q4_K_M", "4-bit medium (recommended)"),
+    ("Q5_K_M", "5-bit medium"),
+    ("Q8_0", "8-bit"),
+]
+
+quantized_files = []
+for quant_type, description in quant_formats:
+    print(f"   Creating {quant_type} ({description})...")
+    quant_file = f"{gguf_output_dir}/{model_name}-{quant_type.lower()}.gguf"
+
+    if run_command([quantize_bin, gguf_file, quant_file, quant_type], f"Quantizing to {quant_type}"):
+        size_mb = os.path.getsize(quant_file) / (1024 * 1024)
+        print(f"   ✅ {quant_type}: {size_mb:.1f} MB")
+        quantized_files.append((quant_file, quant_type))
+    else:
+        print(f"   ⚠️  Skipping {quant_type}")
+
+# Step 9: Upload to Hub
+print("\n☁️  Step 9: Uploading to Hugging Face Hub...")
+api = HfApi()
+
+print(f"   Creating repository: {OUTPUT_REPO}")
+api.create_repo(repo_id=OUTPUT_REPO, repo_type="model", exist_ok=True)
+
+# Upload F16
+print("   Uploading FP16...")
+api.upload_file(
+    path_or_fileobj=gguf_file,
+    path_in_repo=f"{model_name}-f16.gguf",
+    repo_id=OUTPUT_REPO,
+)
+print("   ✅ FP16 uploaded")
+
+# Upload quantized versions
+for quant_file, quant_type in quantized_files:
+    print(f"   Uploading {quant_type}...")
+    api.upload_file(
+        path_or_fileobj=quant_file,
+        path_in_repo=f"{model_name}-{quant_type.lower()}.gguf",
+        repo_id=OUTPUT_REPO,
+    )
+    print(f"   ✅ {quant_type} uploaded")
+
+# Create README
+print("\n📝 Creating README...")
+readme_content = f"""---
+base_model: {BASE_MODEL}
+tags:
+- gguf
+- llama.cpp
+- quantized
+- query-expansion
+- qmd
+---
+
+# QMD Query Expansion 4B (GGUF)
+
+GGUF conversion of the QMD Query Expansion model for use with Ollama, llama.cpp, and LM Studio.
+
+## Model Details
+
+- **Base Model:** {BASE_MODEL}
+- **SFT Adapter:** {SFT_MODEL}
+- **GRPO Adapter:** {GRPO_MODEL}
+- **Task:** Query expansion for hybrid search (lex/vec/hyde format)
+
+## Available Quantizations
+
+| File | Quant | Description |
+|------|-------|-------------|
+| {model_name}-f16.gguf | F16 | Full precision |
+| {model_name}-q8_0.gguf | Q8_0 | 8-bit |
+| {model_name}-q5_k_m.gguf | Q5_K_M | 5-bit medium |
+| {model_name}-q4_k_m.gguf | Q4_K_M | 4-bit medium (recommended) |
+
+## Usage
+
+### With Ollama
+
+```bash
+# Download
+huggingface-cli download {OUTPUT_REPO} {model_name}-q4_k_m.gguf --local-dir .
+
+# Create Modelfile
+echo 'FROM ./{model_name}-q4_k_m.gguf' > Modelfile
+
+# Create and run
+ollama create qmd-expand-4b -f Modelfile
+ollama run qmd-expand-4b
+```
+
+### Prompt Format
+
+Use Qwen3 chat format with `/no_think`:
+
+```
+<|im_start|>user
+/no_think Expand this search query: your query here<|im_end|>
+<|im_start|>assistant
+```
+
+### Expected Output
+
+```
+lex: keyword variation 1
+lex: keyword variation 2
+vec: natural language reformulation
+hyde: Hypothetical document passage answering the query.
+```
+
+## License
+
+Apache 2.0 (inherited from Qwen3)
+"""
+
+api.upload_file(
+    path_or_fileobj=readme_content.encode(),
+    path_in_repo="README.md",
+    repo_id=OUTPUT_REPO,
+)
+print("   ✅ README uploaded")
+
+print("\n" + "=" * 60)
+print("✅ GGUF Conversion Complete!")
+print(f"📦 Repository: https://huggingface.co/{OUTPUT_REPO}")
+print("=" * 60)

+ 402 - 0
finetune/train_1.7B_grpo.py

@@ -0,0 +1,402 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "trl>=0.12.0",
+#     "peft>=0.7.0",
+#     "transformers>=4.45.0",
+#     "accelerate>=0.24.0",
+#     "huggingface_hub>=0.20.0",
+#     "trackio",
+#     "datasets",
+#     "bitsandbytes",
+# ]
+# ///
+"""
+GRPO training for Qwen3-1.7B query expansion model.
+Trains on top of merged SFT weights with reward function.
+"""
+
+import os
+import re
+from collections import Counter
+
+import torch
+import trackio
+from datasets import load_dataset
+from huggingface_hub import login
+from peft import LoraConfig, PeftModel, get_peft_model
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import GRPOTrainer, GRPOConfig
+
+# ==================== REWARD FUNCTION ====================
+
+STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
+KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
+                      'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
+                      'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}
+
+GENERIC_LEX_PHRASES = {
+    'find information about', 'search for', 'look up', 'get information',
+    'learn about', 'information on', 'details about', 'find out about',
+    'what is', 'how to', 'guide to', 'help with'
+}
+
+
+def extract_named_entities(query: str) -> set:
+    """Extract named entities from query using simple heuristics."""
+    entities = set()
+    words = query.split()
+    prev_was_entity = False
+
+    for i, word in enumerate(words):
+        clean = word.strip('.,!?:;()[]"\'')
+        if not clean:
+            prev_was_entity = False
+            continue
+
+        is_entity = False
+
+        if clean.isupper() and len(clean) >= 2:
+            entities.add(clean.lower())
+            is_entity = True
+        elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
+            entities.add(clean.lower())
+            is_entity = True
+        elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
+            entities.add(clean.lower())
+            is_entity = True
+        elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
+            entities.add(clean.lower())
+            is_entity = True
+        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
+            entities.add(clean.lower())
+            is_entity = True
+
+        prev_was_entity = is_entity
+
+    return entities
+
+
+def get_key_terms(query: str) -> set:
+    words = set(query.lower().split())
+    return words - KEY_TERM_STOPWORDS
+
+
+def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
+    key_terms = get_key_terms(query)
+    if not key_terms:
+        return True
+    lex_words = set(lex_line.lower().split())
+    return bool(key_terms & lex_words)
+
+
+def lex_preserves_entities(lex_line: str, entities: set) -> bool:
+    if not entities:
+        return True
+    lex_lower = lex_line.lower()
+    return any(entity in lex_lower for entity in entities)
+
+
+def lex_is_generic(lex_line: str) -> bool:
+    lex_lower = lex_line.lower().strip()
+    for phrase in GENERIC_LEX_PHRASES:
+        if phrase in lex_lower or lex_lower.startswith(phrase.split()[0]):
+            remaining = lex_lower
+            for word in phrase.split():
+                remaining = remaining.replace(word, '', 1).strip()
+            if len(remaining) < 3:
+                return True
+    return False
+
+
+def parse_expansion(text: str) -> dict:
+    lines = text.strip().split("\n")
+    result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
+    for line in lines:
+        line = line.strip()
+        if not line:
+            continue
+        if line.startswith("lex:"):
+            result["lex"].append(line[4:].strip())
+        elif line.startswith("vec:"):
+            result["vec"].append(line[4:].strip())
+        elif line.startswith("hyde:"):
+            result["hyde"].append(line[5:].strip())
+        else:
+            result["invalid"].append(line)
+    return result
+
+
+def edit_distance_simple(a: str, b: str) -> int:
+    words_a = set(a.lower().split())
+    words_b = set(b.lower().split())
+    return len(words_a ^ words_b)
+
+
+def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
+    a, b = a.lower().strip(), b.lower().strip()
+    if a == b:
+        return False
+    if a in b or b in a:
+        return False
+    return edit_distance_simple(a, b) >= min_distance
+
+
+def echoes_query(expansion: str, query: str) -> bool:
+    exp = expansion.lower().strip()
+    q = query.lower().strip()
+    if exp == q:
+        return True
+    if q in exp and len(exp) < len(q) + 10:
+        return True
+    return False
+
+
+def word_repetition_penalty(text: str) -> int:
+    words = re.findall(r'\b\w+\b', text.lower())
+    counts = Counter(words)
+    penalty = 0
+    for word, count in counts.items():
+        if count >= 3 and word not in STOPWORDS and len(word) > 2:
+            penalty += (count - 2) * 2
+    return penalty
+
+
+def score_expansion(query: str, expansion: str) -> float:
+    """Score expansion. Returns 0.0-1.0 for RL reward."""
+    text = expansion.strip()
+
+    # HARD FAIL: Chat template artifacts
+    if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
+                                        '\nassistant\n', '\nuser\n', '<|endoftext|>']):
+        return 0.0
+
+    # HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
+    for line in text.split("\n"):
+        line = line.strip()
+        if not line:
+            continue
+        if not line.startswith(("lex:", "vec:", "hyde:")):
+            return 0.0
+
+    parsed = parse_expansion(expansion)
+
+    # FORMAT (0-30)
+    format_score = 0
+    if parsed["lex"]:
+        format_score += 10
+    if parsed["vec"]:
+        format_score += 10
+    format_score += 10
+
+    # DIVERSITY (0-30)
+    diversity_score = 0
+    types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
+    if types_present >= 2:
+        diversity_score += 10
+    total_expansions = len(parsed["lex"]) + len(parsed["vec"])
+    if total_expansions >= 2:
+        diversity_score += 5
+
+    lex_score = 5
+    for i, a in enumerate(parsed["lex"]):
+        for b in parsed["lex"][i+1:]:
+            if not is_diverse(a, b, 2):
+                lex_score -= 2
+    diversity_score += max(0, lex_score)
+
+    vec_score = 5
+    for i, a in enumerate(parsed["vec"]):
+        for b in parsed["vec"][i+1:]:
+            if not is_diverse(a, b, 3):
+                vec_score -= 2
+    diversity_score += max(0, vec_score)
+
+    echo_score = 5
+    for exp in parsed["lex"] + parsed["vec"]:
+        if echoes_query(exp, query):
+            echo_score -= 3
+    diversity_score += max(0, echo_score)
+
+    # HYDE (0-20)
+    hyde_score = 0
+    if parsed["hyde"]:
+        hyde_text = parsed["hyde"][0]
+        hyde_score += 5
+        hyde_len = len(hyde_text)
+        if 50 <= hyde_len <= 200:
+            hyde_score += 5
+        elif hyde_len < 50:
+            hyde_score += 2
+        if "\n" not in hyde_text:
+            hyde_score += 5
+        rep_penalty = word_repetition_penalty(hyde_text)
+        hyde_score += max(0, 5 - rep_penalty)
+
+    # QUALITY (0-20)
+    quality_score = 5
+    if parsed["lex"] and parsed["vec"]:
+        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
+        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
+        if avg_lex <= avg_vec:
+            quality_score += 5
+    if parsed["vec"]:
+        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
+        if natural == len(parsed["vec"]):
+            quality_score += 5
+        else:
+            quality_score += 2
+    if parsed["lex"]:
+        lex_with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
+        if lex_with_terms == len(parsed["lex"]):
+            quality_score += 5
+        elif lex_with_terms > 0:
+            quality_score += 2
+
+    # NAMED ENTITY PRESERVATION
+    entity_score = 0
+    entities = extract_named_entities(query)
+    if entities and parsed["lex"]:
+        lex_with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
+        if lex_with_entities == len(parsed["lex"]):
+            entity_score += 15
+        elif lex_with_entities > 0:
+            entity_score += 5
+        else:
+            entity_score -= 30
+
+        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
+        entity_score -= generic_count * 15
+
+        if parsed["vec"]:
+            vec_with_entities = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
+            if vec_with_entities > 0:
+                entity_score += 5
+    elif not entities:
+        entity_score = 10
+
+    total = format_score + diversity_score + hyde_score + quality_score + entity_score
+    max_possible = 120 if parsed["hyde"] else 100
+    return max(0.0, min(1.0, total / max_possible))
+
+
+def extract_query_from_prompt(prompt: str) -> str:
+    if "Expand this search query:" in prompt:
+        return prompt.split("Expand this search query:")[-1].strip()
+    return prompt.strip()
+
+
+class QMDRewardFunction:
+    __name__ = "qmd_scoring_reward"
+
+    def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
+        rewards = []
+        for i, completion in enumerate(completions):
+            query = ""
+            if prompts and i < len(prompts):
+                query = extract_query_from_prompt(prompts[i])
+            score = score_expansion(query, completion)
+            rewards.append(score)
+        return rewards
+
+
+# ==================== MAIN ====================
+
+def main():
+    # Config
+    SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
+    BASE_MODEL = "Qwen/Qwen3-1.7B"
+    OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
+    DATASET = "tobil/qmd-query-expansion-train-v2"
+
+    # Login
+    hf_token = os.environ.get("HF_TOKEN")
+    if hf_token:
+        print("Logging in to HuggingFace Hub...")
+        login(token=hf_token)
+
+    # Load dataset
+    print("Loading dataset...")
+    dataset = load_dataset(DATASET, split="train")
+
+    def extract_prompt(example):
+        return {"prompt": example["messages"][0]["content"]}
+
+    dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
+    dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
+    print(f"Using {len(dataset)} prompts for GRPO")
+
+    # Load tokenizer and model
+    print(f"Loading tokenizer from {BASE_MODEL}...")
+    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    print(f"Loading SFT model from {SFT_MODEL}...")
+    base_model = AutoModelForCausalLM.from_pretrained(
+        BASE_MODEL,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+    model = PeftModel.from_pretrained(base_model, SFT_MODEL)
+    model = model.merge_and_unload()
+    print("Model loaded and LoRA merged.")
+
+    # Add LoRA for GRPO
+    grpo_lora_config = LoraConfig(
+        r=4,
+        lora_alpha=8,
+        lora_dropout=0.05,
+        bias="none",
+        task_type="CAUSAL_LM",
+        target_modules=["q_proj", "v_proj"],
+    )
+    model = get_peft_model(model, grpo_lora_config)
+    model.print_trainable_parameters()
+
+    # GRPO config
+    config = GRPOConfig(
+        output_dir="qmd-query-expansion-1.7B-grpo",
+        push_to_hub=True,
+        hub_model_id=OUTPUT_MODEL,
+
+        num_generations=4,
+        max_completion_length=200,
+
+        num_train_epochs=1,
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=8,
+        learning_rate=5e-7,
+        max_grad_norm=0.5,
+        max_steps=200,
+
+        logging_steps=10,
+        save_strategy="epoch",
+
+        report_to="trackio",
+        project="qmd-query-expansion",
+        run_name="qwen3-1.7b-grpo",
+    )
+
+    # Train
+    print("Initializing GRPO trainer...")
+    trainer = GRPOTrainer(
+        model=model,
+        processing_class=tokenizer,
+        args=config,
+        train_dataset=dataset,
+        reward_funcs=[QMDRewardFunction()],
+    )
+
+    print("Starting GRPO training...")
+    trainer.train()
+
+    print("Pushing to Hub...")
+    trainer.push_to_hub()
+
+    trackio.finish()
+    print(f"Complete! Model at: https://huggingface.co/{OUTPUT_MODEL}")
+
+
+if __name__ == "__main__":
+    main()

+ 402 - 0
finetune/train_4B_grpo.py

@@ -0,0 +1,402 @@
+# /// script
+# requires-python = ">=3.10"
+# dependencies = [
+#     "trl>=0.12.0",
+#     "peft>=0.7.0",
+#     "transformers>=4.45.0",
+#     "accelerate>=0.24.0",
+#     "huggingface_hub>=0.20.0",
+#     "trackio",
+#     "datasets",
+#     "bitsandbytes",
+# ]
+# ///
+"""
+GRPO training for Qwen3-4B query expansion model.
+Trains on top of merged SFT weights with reward function.
+"""
+
+import os
+import re
+from collections import Counter
+
+import torch
+import trackio
+from datasets import load_dataset
+from huggingface_hub import login
+from peft import LoraConfig, PeftModel, get_peft_model
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from trl import GRPOTrainer, GRPOConfig
+
+# ==================== REWARD FUNCTION ====================
+
+STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
+KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
+                      'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
+                      'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}
+
+GENERIC_LEX_PHRASES = {
+    'find information about', 'search for', 'look up', 'get information',
+    'learn about', 'information on', 'details about', 'find out about',
+    'what is', 'how to', 'guide to', 'help with'
+}
+
+
+def extract_named_entities(query: str) -> set:
+    """Extract named entities from query using simple heuristics."""
+    entities = set()
+    words = query.split()
+    prev_was_entity = False
+
+    for i, word in enumerate(words):
+        clean = word.strip('.,!?:;()[]"\'')
+        if not clean:
+            prev_was_entity = False
+            continue
+
+        is_entity = False
+
+        if clean.isupper() and len(clean) >= 2:
+            entities.add(clean.lower())
+            is_entity = True
+        elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
+            entities.add(clean.lower())
+            is_entity = True
+        elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
+            entities.add(clean.lower())
+            is_entity = True
+        elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
+            entities.add(clean.lower())
+            is_entity = True
+        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
+            entities.add(clean.lower())
+            is_entity = True
+
+        prev_was_entity = is_entity
+
+    return entities
+
+
+def get_key_terms(query: str) -> set:
+    words = set(query.lower().split())
+    return words - KEY_TERM_STOPWORDS
+
+
+def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
+    key_terms = get_key_terms(query)
+    if not key_terms:
+        return True
+    lex_words = set(lex_line.lower().split())
+    return bool(key_terms & lex_words)
+
+
+def lex_preserves_entities(lex_line: str, entities: set) -> bool:
+    if not entities:
+        return True
+    lex_lower = lex_line.lower()
+    return any(entity in lex_lower for entity in entities)
+
+
+def lex_is_generic(lex_line: str) -> bool:
+    lex_lower = lex_line.lower().strip()
+    for phrase in GENERIC_LEX_PHRASES:
+        if phrase in lex_lower or lex_lower.startswith(phrase.split()[0]):
+            remaining = lex_lower
+            for word in phrase.split():
+                remaining = remaining.replace(word, '', 1).strip()
+            if len(remaining) < 3:
+                return True
+    return False
+
+
+def parse_expansion(text: str) -> dict:
+    lines = text.strip().split("\n")
+    result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
+    for line in lines:
+        line = line.strip()
+        if not line:
+            continue
+        if line.startswith("lex:"):
+            result["lex"].append(line[4:].strip())
+        elif line.startswith("vec:"):
+            result["vec"].append(line[4:].strip())
+        elif line.startswith("hyde:"):
+            result["hyde"].append(line[5:].strip())
+        else:
+            result["invalid"].append(line)
+    return result
+
+
+def edit_distance_simple(a: str, b: str) -> int:
+    words_a = set(a.lower().split())
+    words_b = set(b.lower().split())
+    return len(words_a ^ words_b)
+
+
+def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
+    a, b = a.lower().strip(), b.lower().strip()
+    if a == b:
+        return False
+    if a in b or b in a:
+        return False
+    return edit_distance_simple(a, b) >= min_distance
+
+
+def echoes_query(expansion: str, query: str) -> bool:
+    exp = expansion.lower().strip()
+    q = query.lower().strip()
+    if exp == q:
+        return True
+    if q in exp and len(exp) < len(q) + 10:
+        return True
+    return False
+
+
+def word_repetition_penalty(text: str) -> int:
+    words = re.findall(r'\b\w+\b', text.lower())
+    counts = Counter(words)
+    penalty = 0
+    for word, count in counts.items():
+        if count >= 3 and word not in STOPWORDS and len(word) > 2:
+            penalty += (count - 2) * 2
+    return penalty
+
+
+def score_expansion(query: str, expansion: str) -> float:
+    """Score expansion. Returns 0.0-1.0 for RL reward."""
+    text = expansion.strip()
+
+    # HARD FAIL: Chat template artifacts
+    if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
+                                        '\nassistant\n', '\nuser\n', '<|endoftext|>']):
+        return 0.0
+
+    # HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
+    for line in text.split("\n"):
+        line = line.strip()
+        if not line:
+            continue
+        if not line.startswith(("lex:", "vec:", "hyde:")):
+            return 0.0
+
+    parsed = parse_expansion(expansion)
+
+    # FORMAT (0-30)
+    format_score = 0
+    if parsed["lex"]:
+        format_score += 10
+    if parsed["vec"]:
+        format_score += 10
+    format_score += 10
+
+    # DIVERSITY (0-30)
+    diversity_score = 0
+    types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
+    if types_present >= 2:
+        diversity_score += 10
+    total_expansions = len(parsed["lex"]) + len(parsed["vec"])
+    if total_expansions >= 2:
+        diversity_score += 5
+
+    lex_score = 5
+    for i, a in enumerate(parsed["lex"]):
+        for b in parsed["lex"][i+1:]:
+            if not is_diverse(a, b, 2):
+                lex_score -= 2
+    diversity_score += max(0, lex_score)
+
+    vec_score = 5
+    for i, a in enumerate(parsed["vec"]):
+        for b in parsed["vec"][i+1:]:
+            if not is_diverse(a, b, 3):
+                vec_score -= 2
+    diversity_score += max(0, vec_score)
+
+    echo_score = 5
+    for exp in parsed["lex"] + parsed["vec"]:
+        if echoes_query(exp, query):
+            echo_score -= 3
+    diversity_score += max(0, echo_score)
+
+    # HYDE (0-20)
+    hyde_score = 0
+    if parsed["hyde"]:
+        hyde_text = parsed["hyde"][0]
+        hyde_score += 5
+        hyde_len = len(hyde_text)
+        if 50 <= hyde_len <= 200:
+            hyde_score += 5
+        elif hyde_len < 50:
+            hyde_score += 2
+        if "\n" not in hyde_text:
+            hyde_score += 5
+        rep_penalty = word_repetition_penalty(hyde_text)
+        hyde_score += max(0, 5 - rep_penalty)
+
+    # QUALITY (0-20)
+    quality_score = 5
+    if parsed["lex"] and parsed["vec"]:
+        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
+        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
+        if avg_lex <= avg_vec:
+            quality_score += 5
+    if parsed["vec"]:
+        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
+        if natural == len(parsed["vec"]):
+            quality_score += 5
+        else:
+            quality_score += 2
+    if parsed["lex"]:
+        lex_with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
+        if lex_with_terms == len(parsed["lex"]):
+            quality_score += 5
+        elif lex_with_terms > 0:
+            quality_score += 2
+
+    # NAMED ENTITY PRESERVATION
+    entity_score = 0
+    entities = extract_named_entities(query)
+    if entities and parsed["lex"]:
+        lex_with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
+        if lex_with_entities == len(parsed["lex"]):
+            entity_score += 15
+        elif lex_with_entities > 0:
+            entity_score += 5
+        else:
+            entity_score -= 30
+
+        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
+        entity_score -= generic_count * 15
+
+        if parsed["vec"]:
+            vec_with_entities = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
+            if vec_with_entities > 0:
+                entity_score += 5
+    elif not entities:
+        entity_score = 10
+
+    total = format_score + diversity_score + hyde_score + quality_score + entity_score
+    max_possible = 120 if parsed["hyde"] else 100
+    return max(0.0, min(1.0, total / max_possible))
+
+
+def extract_query_from_prompt(prompt: str) -> str:
+    if "Expand this search query:" in prompt:
+        return prompt.split("Expand this search query:")[-1].strip()
+    return prompt.strip()
+
+
+class QMDRewardFunction:
+    __name__ = "qmd_scoring_reward"
+
+    def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
+        rewards = []
+        for i, completion in enumerate(completions):
+            query = ""
+            if prompts and i < len(prompts):
+                query = extract_query_from_prompt(prompts[i])
+            score = score_expansion(query, completion)
+            rewards.append(score)
+        return rewards
+
+
+# ==================== MAIN ====================
+
+def main():
+    # Config
+    SFT_MODEL = "tobil/qmd-query-expansion-4B-sft"
+    BASE_MODEL = "Qwen/Qwen3-4B"
+    OUTPUT_MODEL = "tobil/qmd-query-expansion-4B-grpo"
+    DATASET = "tobil/qmd-query-expansion-train-v2"
+
+    # Login
+    hf_token = os.environ.get("HF_TOKEN")
+    if hf_token:
+        print("Logging in to HuggingFace Hub...")
+        login(token=hf_token)
+
+    # Load dataset
+    print("Loading dataset...")
+    dataset = load_dataset(DATASET, split="train")
+
+    def extract_prompt(example):
+        return {"prompt": example["messages"][0]["content"]}
+
+    dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
+    dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
+    print(f"Using {len(dataset)} prompts for GRPO")
+
+    # Load tokenizer and model
+    print(f"Loading tokenizer from {BASE_MODEL}...")
+    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    print(f"Loading SFT model from {SFT_MODEL}...")
+    base_model = AutoModelForCausalLM.from_pretrained(
+        BASE_MODEL,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+    model = PeftModel.from_pretrained(base_model, SFT_MODEL)
+    model = model.merge_and_unload()
+    print("Model loaded and LoRA merged.")
+
+    # Add LoRA for GRPO
+    grpo_lora_config = LoraConfig(
+        r=4,
+        lora_alpha=8,
+        lora_dropout=0.05,
+        bias="none",
+        task_type="CAUSAL_LM",
+        target_modules=["q_proj", "v_proj"],
+    )
+    model = get_peft_model(model, grpo_lora_config)
+    model.print_trainable_parameters()
+
+    # GRPO config
+    config = GRPOConfig(
+        output_dir="qmd-query-expansion-4B-grpo",
+        push_to_hub=True,
+        hub_model_id=OUTPUT_MODEL,
+
+        num_generations=4,
+        max_completion_length=200,
+
+        num_train_epochs=1,
+        per_device_train_batch_size=1,  # Smaller for 4B model
+        gradient_accumulation_steps=16,  # Compensate with more accumulation
+        learning_rate=5e-7,
+        max_grad_norm=0.5,
+        max_steps=200,
+
+        logging_steps=10,
+        save_strategy="epoch",
+
+        report_to="trackio",
+        project="qmd-query-expansion",
+        run_name="qwen3-4b-grpo",
+    )
+
+    # Train
+    print("Initializing GRPO trainer...")
+    trainer = GRPOTrainer(
+        model=model,
+        processing_class=tokenizer,
+        args=config,
+        train_dataset=dataset,
+        reward_funcs=[QMDRewardFunction()],
+    )
+
+    print("Starting GRPO training...")
+    trainer.train()
+
+    print("Pushing to Hub...")
+    trainer.push_to_hub()
+
+    trackio.finish()
+    print(f"Complete! Model at: https://huggingface.co/{OUTPUT_MODEL}")
+
+
+if __name__ == "__main__":
+    main()