Эх сурвалжийг харах

Make TUI model list dynamic from HuggingFace Hub

- Fetch available qmd-query-expansion models from tobil/ on Hub
- Auto-detect model size (0.6B, 1.7B, 4B) and use correct base model
- Group models by type (SFT vs GRPO) in menu
- Skip GGUF repos in model listing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 3 сар өмнө
parent
commit
3ea85eff50
1 өөрчлөгдсөн 102 нэмэгдсэн , 17 устгасан
  1. 102 17
      finetune/tui.py

+ 102 - 17
finetune/tui.py

@@ -7,6 +7,7 @@
 #     "peft>=0.7.0",
 #     "torch",
 #     "prompt_toolkit>=3.0.0",
+#     "huggingface_hub>=0.20.0",
 # ]
 # ///
 """
@@ -17,8 +18,10 @@ A cyberpunk-styled TUI for testing finetuned query expansion models.
 from collections import deque
 from dataclasses import dataclass
 from typing import Optional
+import re
 
 import torch
+from huggingface_hub import HfApi
 from peft import PeftModel
 from prompt_toolkit import prompt
 from prompt_toolkit.history import InMemoryHistory
@@ -35,16 +38,82 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 # CONFIGURATION
 # ═══════════════════════════════════════════════════════════════════════════════
 
-# Model configs: (name, path, version, sft_base)
-# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
-MODELS = {
-    "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
-    "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
-    "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
-    "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
+# Base models by size
+BASE_MODELS = {
+    "0.6B": "Qwen/Qwen3-0.6B",
+    "1.7B": "Qwen/Qwen3-1.7B",
+    "4B": "Qwen/Qwen3-4B",
 }
 
-BASE_MODEL = "Qwen/Qwen3-0.6B"
+
+def get_model_size(model_id: str) -> str:
+    """Extract model size from model ID (e.g., '0.6B', '1.7B', '4B')."""
+    match = re.search(r'(\d+\.?\d*B)', model_id)
+    return match.group(1) if match else "0.6B"
+
+
+def fetch_available_models() -> dict:
+    """Dynamically fetch available qmd-query-expansion models from Hub."""
+    api = HfApi()
+    models = {}
+    idx = 1
+
+    try:
+        # Search for all qmd-query-expansion models
+        hub_models = list(api.list_models(author="tobil", search="qmd-query-expansion"))
+
+        # Group by size and type (SFT vs GRPO)
+        sft_models = []
+        grpo_models = []
+
+        for m in hub_models:
+            model_id = m.id
+            # Skip GGUF repos
+            if "gguf" in model_id.lower():
+                continue
+            if "grpo" in model_id.lower():
+                grpo_models.append(model_id)
+            elif "sft" in model_id.lower() or not any(x in model_id.lower() for x in ["grpo", "gguf"]):
+                sft_models.append(model_id)
+
+        # Sort by size (0.6B, 1.7B, 4B)
+        def size_sort_key(m):
+            size = get_model_size(m)
+            return {"0.6B": 0, "1.7B": 1, "4B": 2}.get(size, 3)
+
+        sft_models.sort(key=size_sort_key)
+        grpo_models.sort(key=size_sort_key)
+
+        # Add SFT models
+        for model_id in sft_models:
+            size = get_model_size(model_id)
+            models[str(idx)] = (f"SFT {size}", model_id, "v3", None, size)
+            idx += 1
+
+        # Add GRPO models (need to find matching SFT base)
+        for model_id in grpo_models:
+            size = get_model_size(model_id)
+            # Find matching SFT model
+            sft_base = None
+            for sft in sft_models:
+                if get_model_size(sft) == size:
+                    sft_base = sft
+                    break
+            models[str(idx)] = (f"GRPO {size}", model_id, "v3", sft_base, size)
+            idx += 1
+
+    except Exception as e:
+        # Fallback to default models if Hub fetch fails
+        models = {
+            "1": ("SFT 0.6B", "tobil/qmd-query-expansion-0.6B-v4", "v3", None, "0.6B"),
+            "2": ("GRPO 0.6B", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4", "0.6B"),
+        }
+
+    return models
+
+
+# Will be populated on startup
+MODELS = {}
 
 # v1 used simple format (before proper chat template)
 PROMPT_TEMPLATE_V1 = """Expand this search query:
@@ -212,7 +281,8 @@ def show_model_menu(current: Optional[str] = None) -> str:
     console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
     console.print(f"[{DIM}]{'─' * 50}[/]")
 
-    for key, (name, path, version, sft_base) in MODELS.items():
+    for key, model_info in MODELS.items():
+        name, path, version, sft_base = model_info[:4]
         marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
         sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
         console.print(f"  {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
@@ -331,12 +401,14 @@ class LoadedModel:
     version: str  # "v1" or "v3" - determines prompt template
 
 
-def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
+def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None, size: str = "0.6B") -> LoadedModel:
     """Load model with progress indicator.
 
     For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
     merged into the base model, then the GRPO adapter is applied on top.
     """
+    base_model = BASE_MODELS.get(size, BASE_MODELS["0.6B"])
+
     with Progress(
         SpinnerColumn(spinner_name="dots", style=CYAN),
         TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
@@ -344,13 +416,13 @@ def load_model(model_path: str, model_name: str, version: str, sft_base: Optiona
         transient=True,
     ) as progress:
         task = progress.add_task("tokenizer", total=None)
-        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
 
-        progress.update(task, description="base model")
+        progress.update(task, description=f"base model ({size})")
         model = AutoModelForCausalLM.from_pretrained(
-            BASE_MODEL,
+            base_model,
             torch_dtype=torch.bfloat16,
             device_map="auto",
         )
@@ -433,20 +505,31 @@ def generate_expansion(loaded: LoadedModel, query: str) -> str:
 # ═══════════════════════════════════════════════════════════════════════════════
 
 def main():
+    global MODELS
     console.clear()
     show_banner()
 
+    # Fetch available models from Hub
+    console.print(f"[{DIM}]Fetching available models...[/]")
+    MODELS = fetch_available_models()
+
+    if not MODELS:
+        console.print(f"[{RED}]No models found. Exiting.[/]")
+        return
+
     # Model selection
     choice = show_model_menu()
     if choice not in MODELS:
         console.print(f"[{RED}]Invalid choice. Exiting.[/]")
         return
 
-    model_name, model_path, model_version, sft_base = MODELS[choice]
+    model_info = MODELS[choice]
+    model_name, model_path, model_version, sft_base = model_info[:4]
+    model_size = model_info[4] if len(model_info) > 4 else get_model_size(model_path)
     console.print()
 
     try:
-        loaded = load_model(model_path, model_name, model_version, sft_base)
+        loaded = load_model(model_path, model_name, model_version, sft_base, model_size)
     except Exception as e:
         console.print(f"[{RED}]Failed to load model: {e}[/]")
         return
@@ -483,10 +566,12 @@ def main():
             show_banner()
             choice = show_model_menu(loaded.path)
             if choice in MODELS:
-                new_name, new_path, new_version, new_sft_base = MODELS[choice]
+                new_info = MODELS[choice]
+                new_name, new_path, new_version, new_sft_base = new_info[:4]
+                new_size = new_info[4] if len(new_info) > 4 else get_model_size(new_path)
                 if new_path != loaded.path:
                     console.print()
-                    loaded = load_model(new_path, new_name, new_version, new_sft_base)
+                    loaded = load_model(new_path, new_name, new_version, new_sft_base, new_size)
                     console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
             console.print()
             continue