Просмотр исходного кода

Make TUI model list dynamic from HuggingFace Hub

- Fetch available qmd-query-expansion models from tobil/ on Hub
- Auto-detect model size (0.6B, 1.7B, 4B) and use correct base model
- Group models by type (SFT vs GRPO) in menu
- Skip GGUF repos in model listing

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 3 месяца назад
Родитель
Коммит
3ea85eff50
1 изменённый файл: 102 добавлено и 17 удалено
  1. 102 17
      finetune/tui.py

+ 102 - 17
finetune/tui.py

@@ -7,6 +7,7 @@
 #     "peft>=0.7.0",
 #     "peft>=0.7.0",
 #     "torch",
 #     "torch",
 #     "prompt_toolkit>=3.0.0",
 #     "prompt_toolkit>=3.0.0",
+#     "huggingface_hub>=0.20.0",
 # ]
 # ]
 # ///
 # ///
 """
 """
@@ -17,8 +18,10 @@ A cyberpunk-styled TUI for testing finetuned query expansion models.
 from collections import deque
 from collections import deque
 from dataclasses import dataclass
 from dataclasses import dataclass
 from typing import Optional
 from typing import Optional
+import re
 
 
 import torch
 import torch
+from huggingface_hub import HfApi
 from peft import PeftModel
 from peft import PeftModel
 from prompt_toolkit import prompt
 from prompt_toolkit import prompt
 from prompt_toolkit.history import InMemoryHistory
 from prompt_toolkit.history import InMemoryHistory
@@ -35,16 +38,82 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 # CONFIGURATION
 # CONFIGURATION
 # ═══════════════════════════════════════════════════════════════════════════════
 # ═══════════════════════════════════════════════════════════════════════════════
 
 
-# Model configs: (name, path, version, sft_base)
-# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
-MODELS = {
-    "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
-    "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
-    "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
-    "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
+# Base models by size
+BASE_MODELS = {
+    "0.6B": "Qwen/Qwen3-0.6B",
+    "1.7B": "Qwen/Qwen3-1.7B",
+    "4B": "Qwen/Qwen3-4B",
 }
 }
 
 
-BASE_MODEL = "Qwen/Qwen3-0.6B"
+
+def get_model_size(model_id: str) -> str:
+    """Extract the parameter-count tag from a model ID.
+
+    The tag is a number (optionally with a decimal part) followed by ``B``,
+    for example ``0.6B``, ``1.7B`` or ``4B``. Matching is case-insensitive
+    and the result is always normalised to an upper-case ``B`` so that it
+    can be used as a ``BASE_MODELS`` key.
+
+    Args:
+        model_id: Hub repo ID or local path, e.g.
+            ``"tobil/qmd-query-expansion-1.7B-v4"``.
+
+    Returns:
+        The size tag, or ``"0.6B"`` when the ID carries no recognisable tag.
+    """
+    # The lookahead keeps a digit followed by a word starting with "b"
+    # (e.g. "v4beta") from being misread as a size tag.
+    match = re.search(r'(\d+(?:\.\d+)?)b(?![a-z0-9])', model_id, re.IGNORECASE)
+    return f"{match.group(1)}B" if match else "0.6B"
+
+
+def fetch_available_models() -> dict:
+    """Build the model menu from the qmd-query-expansion repos on the Hub.
+
+    Lists repos by author ``tobil`` matching ``qmd-query-expansion``, drops
+    GGUF repos, and splits the rest into GRPO models (ID contains ``grpo``)
+    and SFT models (everything else). Menu keys are numbered from ``"1"``:
+    SFT models first, then GRPO models, each group ordered by size.
+
+    Returns:
+        Dict mapping menu key to a tuple
+        ``(label, model_id, version, sft_base, size)``. ``sft_base`` is
+        ``None`` for SFT models and for GRPO models that have no same-size
+        SFT model on the Hub. If the Hub request fails, a built-in 0.6B
+        SFT/GRPO pair is returned instead. The dict is empty when the Hub
+        answers but lists no matching repos.
+    """
+    size_rank = {"0.6B": 0, "1.7B": 1, "4B": 2}
+
+    def _size_sort_key(model_id: str) -> int:
+        # Sizes outside the known set sort after the known ones.
+        return size_rank.get(get_model_size(model_id), len(size_rank))
+
+    def _find_sft_base(grpo_id: str, sft_ids: list) -> Optional[str]:
+        # Only SFT models of the same size share a base model with the GRPO one.
+        same_size = [s for s in sft_ids if get_model_size(s) == get_model_size(grpo_id)]
+        # Prefer the SFT repo the GRPO repo is named after
+        # ("<sft-id>-grpo"), so that several SFT versions of one size are
+        # not all collapsed onto whichever happens to be listed first.
+        named_after = [s for s in same_size if grpo_id.lower().startswith(s.lower())]
+        if named_after:
+            return max(named_after, key=len)
+        return same_size[0] if same_size else None
+
+    api = HfApi()
+    try:
+        # Consume the listing inside the try block so that errors raised
+        # while iterating over the results are caught as well.
+        hub_ids = [
+            m.id for m in api.list_models(author="tobil", search="qmd-query-expansion")
+        ]
+    except Exception:
+        # Deliberate best-effort: the TUI should still start when the Hub is
+        # unreachable. Record the cause at debug level instead of dropping it.
+        import logging
+
+        logging.getLogger(__name__).debug(
+            "Hub model listing failed; using built-in defaults", exc_info=True
+        )
+        return {
+            "1": ("SFT 0.6B", "tobil/qmd-query-expansion-0.6B-v4", "v3", None, "0.6B"),
+            "2": ("GRPO 0.6B", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4", "0.6B"),
+        }
+
+    sft_models = []
+    grpo_models = []
+    for model_id in hub_ids:
+        lowered = model_id.lower()
+        # GGUF repos are skipped entirely.
+        if "gguf" in lowered:
+            continue
+        if "grpo" in lowered:
+            grpo_models.append(model_id)
+        else:
+            sft_models.append(model_id)
+
+    # Order each group by size (0.6B, 1.7B, 4B); the sort is stable, so the
+    # Hub's ordering is kept within one size.
+    sft_models.sort(key=_size_sort_key)
+    grpo_models.sort(key=_size_sort_key)
+
+    models = {}
+    for model_id in sft_models:
+        size = get_model_size(model_id)
+        models[str(len(models) + 1)] = (f"SFT {size}", model_id, "v3", None, size)
+
+    # GRPO entries carry the SFT model they are paired with as sft_base.
+    for model_id in grpo_models:
+        size = get_model_size(model_id)
+        sft_base = _find_sft_base(model_id, sft_models)
+        models[str(len(models) + 1)] = (f"GRPO {size}", model_id, "v3", sft_base, size)
+
+    return models
+
+
+# Will be populated on startup
+MODELS = {}
 
 
 # v1 used simple format (before proper chat template)
 # v1 used simple format (before proper chat template)
 PROMPT_TEMPLATE_V1 = """Expand this search query:
 PROMPT_TEMPLATE_V1 = """Expand this search query:
@@ -212,7 +281,8 @@ def show_model_menu(current: Optional[str] = None) -> str:
     console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
     console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
     console.print(f"[{DIM}]{'─' * 50}[/]")
     console.print(f"[{DIM}]{'─' * 50}[/]")
 
 
-    for key, (name, path, version, sft_base) in MODELS.items():
+    for key, model_info in MODELS.items():
+        name, path, version, sft_base = model_info[:4]
         marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
         marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
         sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
         sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
         console.print(f"  {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
         console.print(f"  {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
@@ -331,12 +401,14 @@ class LoadedModel:
     version: str  # "v1" or "v3" - determines prompt template
     version: str  # "v1" or "v3" - determines prompt template
 
 
 
 
-def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
+def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None, size: str = "0.6B") -> LoadedModel:
     """Load model with progress indicator.
     """Load model with progress indicator.
 
 
     For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
     For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
     merged into the base model, then the GRPO adapter is applied on top.
     merged into the base model, then the GRPO adapter is applied on top.
     """
     """
+    base_model = BASE_MODELS.get(size, BASE_MODELS["0.6B"])
+
     with Progress(
     with Progress(
         SpinnerColumn(spinner_name="dots", style=CYAN),
         SpinnerColumn(spinner_name="dots", style=CYAN),
         TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
         TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
@@ -344,13 +416,13 @@ def load_model(model_path: str, model_name: str, version: str, sft_base: Optiona
         transient=True,
         transient=True,
     ) as progress:
     ) as progress:
         task = progress.add_task("tokenizer", total=None)
         task = progress.add_task("tokenizer", total=None)
-        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+        tokenizer = AutoTokenizer.from_pretrained(base_model)
         if tokenizer.pad_token is None:
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
             tokenizer.pad_token = tokenizer.eos_token
 
 
-        progress.update(task, description="base model")
+        progress.update(task, description=f"base model ({size})")
         model = AutoModelForCausalLM.from_pretrained(
         model = AutoModelForCausalLM.from_pretrained(
-            BASE_MODEL,
+            base_model,
             torch_dtype=torch.bfloat16,
             torch_dtype=torch.bfloat16,
             device_map="auto",
             device_map="auto",
         )
         )
@@ -433,20 +505,31 @@ def generate_expansion(loaded: LoadedModel, query: str) -> str:
 # ═══════════════════════════════════════════════════════════════════════════════
 # ═══════════════════════════════════════════════════════════════════════════════
 
 
 def main():
 def main():
+    global MODELS
     console.clear()
     console.clear()
     show_banner()
     show_banner()
 
 
+    # Fetch available models from Hub
+    console.print(f"[{DIM}]Fetching available models...[/]")
+    MODELS = fetch_available_models()
+
+    if not MODELS:
+        console.print(f"[{RED}]No models found. Exiting.[/]")
+        return
+
     # Model selection
     # Model selection
     choice = show_model_menu()
     choice = show_model_menu()
     if choice not in MODELS:
     if choice not in MODELS:
         console.print(f"[{RED}]Invalid choice. Exiting.[/]")
         console.print(f"[{RED}]Invalid choice. Exiting.[/]")
         return
         return
 
 
-    model_name, model_path, model_version, sft_base = MODELS[choice]
+    model_info = MODELS[choice]
+    model_name, model_path, model_version, sft_base = model_info[:4]
+    model_size = model_info[4] if len(model_info) > 4 else get_model_size(model_path)
     console.print()
     console.print()
 
 
     try:
     try:
-        loaded = load_model(model_path, model_name, model_version, sft_base)
+        loaded = load_model(model_path, model_name, model_version, sft_base, model_size)
     except Exception as e:
     except Exception as e:
         console.print(f"[{RED}]Failed to load model: {e}[/]")
         console.print(f"[{RED}]Failed to load model: {e}[/]")
         return
         return
@@ -483,10 +566,12 @@ def main():
             show_banner()
             show_banner()
             choice = show_model_menu(loaded.path)
             choice = show_model_menu(loaded.path)
             if choice in MODELS:
             if choice in MODELS:
-                new_name, new_path, new_version, new_sft_base = MODELS[choice]
+                new_info = MODELS[choice]
+                new_name, new_path, new_version, new_sft_base = new_info[:4]
+                new_size = new_info[4] if len(new_info) > 4 else get_model_size(new_path)
                 if new_path != loaded.path:
                 if new_path != loaded.path:
                     console.print()
                     console.print()
-                    loaded = load_model(new_path, new_name, new_version, new_sft_base)
+                    loaded = load_model(new_path, new_name, new_version, new_sft_base, new_size)
                     console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
                     console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
             console.print()
             console.print()
             continue
             continue