|
|
@@ -7,6 +7,7 @@
|
|
|
# "peft>=0.7.0",
|
|
|
# "torch",
|
|
|
# "prompt_toolkit>=3.0.0",
|
|
|
+# "huggingface_hub>=0.20.0",
|
|
|
# ]
|
|
|
# ///
|
|
|
"""
|
|
|
@@ -17,8 +18,10 @@ A cyberpunk-styled TUI for testing finetuned query expansion models.
|
|
|
from collections import deque
|
|
|
from dataclasses import dataclass
|
|
|
from typing import Optional
|
|
|
+import re
|
|
|
|
|
|
import torch
|
|
|
+from huggingface_hub import HfApi
|
|
|
from peft import PeftModel
|
|
|
from prompt_toolkit import prompt
|
|
|
from prompt_toolkit.history import InMemoryHistory
|
|
|
@@ -35,16 +38,82 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
# CONFIGURATION
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
-# Model configs: (name, path, version, sft_base)
|
|
|
-# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
|
|
|
-MODELS = {
|
|
|
- "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
|
|
|
- "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
|
|
|
- "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
|
|
|
- "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
|
|
|
+# Base models by size
|
|
|
+BASE_MODELS = {
|
|
|
+ "0.6B": "Qwen/Qwen3-0.6B",
|
|
|
+ "1.7B": "Qwen/Qwen3-1.7B",
|
|
|
+ "4B": "Qwen/Qwen3-4B",
|
|
|
}
|
|
|
|
|
|
-BASE_MODEL = "Qwen/Qwen3-0.6B"
|
|
|
+
|
|
|
+def get_model_size(model_id: str) -> str:
|
|
|
+ """Extract model size from model ID (e.g., '0.6B', '1.7B', '4B')."""
|
|
|
+ match = re.search(r'(\d+\.?\d*B)', model_id)
|
|
|
+ return match.group(1) if match else "0.6B"
|
|
|
+
|
|
|
+
|
|
|
+def fetch_available_models() -> dict:
|
|
|
+ """Dynamically fetch available qmd-query-expansion models from Hub."""
|
|
|
+ api = HfApi()
|
|
|
+ models = {}
|
|
|
+ idx = 1
|
|
|
+
|
|
|
+ try:
|
|
|
+ # Search for all qmd-query-expansion models
|
|
|
+ hub_models = list(api.list_models(author="tobil", search="qmd-query-expansion"))
|
|
|
+
|
|
|
+ # Group by size and type (SFT vs GRPO)
|
|
|
+ sft_models = []
|
|
|
+ grpo_models = []
|
|
|
+
|
|
|
+ for m in hub_models:
|
|
|
+ model_id = m.id
|
|
|
+ # Skip GGUF repos
|
|
|
+ if "gguf" in model_id.lower():
|
|
|
+ continue
|
|
|
+ if "grpo" in model_id.lower():
|
|
|
+ grpo_models.append(model_id)
|
|
|
+ elif "sft" in model_id.lower() or not any(x in model_id.lower() for x in ["grpo", "gguf"]):
|
|
|
+ sft_models.append(model_id)
|
|
|
+
|
|
|
+ # Sort by size (0.6B, 1.7B, 4B)
|
|
|
+ def size_sort_key(m):
|
|
|
+ size = get_model_size(m)
|
|
|
+ return {"0.6B": 0, "1.7B": 1, "4B": 2}.get(size, 3)
|
|
|
+
|
|
|
+ sft_models.sort(key=size_sort_key)
|
|
|
+ grpo_models.sort(key=size_sort_key)
|
|
|
+
|
|
|
+ # Add SFT models
|
|
|
+ for model_id in sft_models:
|
|
|
+ size = get_model_size(model_id)
|
|
|
+ models[str(idx)] = (f"SFT {size}", model_id, "v3", None, size)
|
|
|
+ idx += 1
|
|
|
+
|
|
|
+ # Add GRPO models (need to find matching SFT base)
|
|
|
+ for model_id in grpo_models:
|
|
|
+ size = get_model_size(model_id)
|
|
|
+ # Find matching SFT model
|
|
|
+ sft_base = None
|
|
|
+ for sft in sft_models:
|
|
|
+ if get_model_size(sft) == size:
|
|
|
+ sft_base = sft
|
|
|
+ break
|
|
|
+ models[str(idx)] = (f"GRPO {size}", model_id, "v3", sft_base, size)
|
|
|
+ idx += 1
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ # Fallback to default models if Hub fetch fails
|
|
|
+ models = {
|
|
|
+ "1": ("SFT 0.6B", "tobil/qmd-query-expansion-0.6B-v4", "v3", None, "0.6B"),
|
|
|
+ "2": ("GRPO 0.6B", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4", "0.6B"),
|
|
|
+ }
|
|
|
+
|
|
|
+ return models
|
|
|
+
|
|
|
+
|
|
|
+# Will be populated on startup
|
|
|
+MODELS = {}
|
|
|
|
|
|
# v1 used simple format (before proper chat template)
|
|
|
PROMPT_TEMPLATE_V1 = """Expand this search query:
|
|
|
@@ -212,7 +281,8 @@ def show_model_menu(current: Optional[str] = None) -> str:
|
|
|
console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
|
|
|
console.print(f"[{DIM}]{'─' * 50}[/]")
|
|
|
|
|
|
- for key, (name, path, version, sft_base) in MODELS.items():
|
|
|
+ for key, model_info in MODELS.items():
|
|
|
+ name, path, version, sft_base = model_info[:4]
|
|
|
marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
|
|
|
sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
|
|
|
console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
|
|
|
@@ -331,12 +401,14 @@ class LoadedModel:
|
|
|
version: str # "v1" or "v3" - determines prompt template
|
|
|
|
|
|
|
|
|
-def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
|
|
|
+def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None, size: str = "0.6B") -> LoadedModel:
|
|
|
"""Load model with progress indicator.
|
|
|
|
|
|
For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
|
|
|
merged into the base model, then the GRPO adapter is applied on top.
|
|
|
"""
|
|
|
+ base_model = BASE_MODELS.get(size, BASE_MODELS["0.6B"])
|
|
|
+
|
|
|
with Progress(
|
|
|
SpinnerColumn(spinner_name="dots", style=CYAN),
|
|
|
TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
|
|
|
@@ -344,13 +416,13 @@ def load_model(model_path: str, model_name: str, version: str, sft_base: Optiona
|
|
|
transient=True,
|
|
|
) as progress:
|
|
|
task = progress.add_task("tokenizer", total=None)
|
|
|
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
|
|
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
|
|
|
if tokenizer.pad_token is None:
|
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
- progress.update(task, description="base model")
|
|
|
+ progress.update(task, description=f"base model ({size})")
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
- BASE_MODEL,
|
|
|
+ base_model,
|
|
|
torch_dtype=torch.bfloat16,
|
|
|
device_map="auto",
|
|
|
)
|
|
|
@@ -433,20 +505,31 @@ def generate_expansion(loaded: LoadedModel, query: str) -> str:
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
def main():
|
|
|
+ global MODELS
|
|
|
console.clear()
|
|
|
show_banner()
|
|
|
|
|
|
+ # Fetch available models from Hub
|
|
|
+ console.print(f"[{DIM}]Fetching available models...[/]")
|
|
|
+ MODELS = fetch_available_models()
|
|
|
+
|
|
|
+ if not MODELS:
|
|
|
+ console.print(f"[{RED}]No models found. Exiting.[/]")
|
|
|
+ return
|
|
|
+
|
|
|
# Model selection
|
|
|
choice = show_model_menu()
|
|
|
if choice not in MODELS:
|
|
|
console.print(f"[{RED}]Invalid choice. Exiting.[/]")
|
|
|
return
|
|
|
|
|
|
- model_name, model_path, model_version, sft_base = MODELS[choice]
|
|
|
+ model_info = MODELS[choice]
|
|
|
+ model_name, model_path, model_version, sft_base = model_info[:4]
|
|
|
+ model_size = model_info[4] if len(model_info) > 4 else get_model_size(model_path)
|
|
|
console.print()
|
|
|
|
|
|
try:
|
|
|
- loaded = load_model(model_path, model_name, model_version, sft_base)
|
|
|
+ loaded = load_model(model_path, model_name, model_version, sft_base, model_size)
|
|
|
except Exception as e:
|
|
|
console.print(f"[{RED}]Failed to load model: {e}[/]")
|
|
|
return
|
|
|
@@ -483,10 +566,12 @@ def main():
|
|
|
show_banner()
|
|
|
choice = show_model_menu(loaded.path)
|
|
|
if choice in MODELS:
|
|
|
- new_name, new_path, new_version, new_sft_base = MODELS[choice]
|
|
|
+ new_info = MODELS[choice]
|
|
|
+ new_name, new_path, new_version, new_sft_base = new_info[:4]
|
|
|
+ new_size = new_info[4] if len(new_info) > 4 else get_model_size(new_path)
|
|
|
if new_path != loaded.path:
|
|
|
console.print()
|
|
|
- loaded = load_model(new_path, new_name, new_version, new_sft_base)
|
|
|
+ loaded = load_model(new_path, new_name, new_version, new_sft_base, new_size)
|
|
|
console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
|
|
|
console.print()
|
|
|
continue
|