Răsfoiți Sursa

Fix TUI to load GRPO models with SFT base first

GRPO adapters were trained on merged SFT weights, so they need SFT
loaded and merged first before applying the GRPO adapter.

Updated MODELS config to include sft_base path for GRPO models,
and load_model() now handles the SFT -> merge -> GRPO flow.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 4 luni în urmă
părinte
comite
2648512b7c
1 a modificat fișierele cu 29 adăugiri și 15 ștergeri
  1. 29 15
      finetune/tui.py

+ 29 - 15
finetune/tui.py

@@ -35,11 +35,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 # CONFIGURATION
 # ═══════════════════════════════════════════════════════════════════════════════
 
+# Model configs: (name, path, version, sft_base)
+# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
 MODELS = {
-    "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3"),
-    "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3"),
-    "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3"),
-    "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3"),
+    "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
+    "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
+    "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
+    "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
 }
 
 BASE_MODEL = "Qwen/Qwen3-0.6B"
@@ -210,13 +212,14 @@ def show_model_menu(current: Optional[str] = None) -> str:
     console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
     console.print(f"[{DIM}]{'─' * 50}[/]")
 
-    for key, (name, path, version) in MODELS.items():
+    for key, (name, path, version, sft_base) in MODELS.items():
         marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
-        console.print(f"  {marker} [{CYAN}]{key}[/] │ {name} [{DIM}]({version})[/]")
+        sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
+        console.print(f"  {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
         console.print(f"      [{DIM}]{path}[/]")
 
     console.print(f"[{DIM}]{'─' * 50}[/]")
-    return prompt("  Enter choice (1-3): ", style=Style.from_dict({'': CYAN})).strip()
+    return prompt("  Enter choice (1-4): ", style=Style.from_dict({'': CYAN})).strip()
 
 
 def render_expansion(expansion: str, scores: dict) -> Panel:
@@ -328,8 +331,12 @@ class LoadedModel:
     version: str  # "v1" or "v3" - determines prompt template
 
 
-def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
-    """Load model with progress indicator."""
+def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
+    """Load model with progress indicator.
+
+    For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
+    merged into the base model, then the GRPO adapter is applied on top.
+    """
     with Progress(
         SpinnerColumn(spinner_name="dots", style=CYAN),
         TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
@@ -342,14 +349,21 @@ def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
             tokenizer.pad_token = tokenizer.eos_token
 
         progress.update(task, description="base model")
-        base = AutoModelForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             BASE_MODEL,
             torch_dtype=torch.bfloat16,
             device_map="auto",
         )
 
+        # For GRPO models: load SFT first, merge, then apply GRPO
+        if sft_base:
+            progress.update(task, description="SFT adapter")
+            model = PeftModel.from_pretrained(model, sft_base)
+            progress.update(task, description="merging SFT")
+            model = model.merge_and_unload()
+
         progress.update(task, description="adapter")
-        model = PeftModel.from_pretrained(base, model_path)
+        model = PeftModel.from_pretrained(model, model_path)
         model.eval()
 
     return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version)
@@ -428,11 +442,11 @@ def main():
         console.print(f"[{RED}]Invalid choice. Exiting.[/]")
         return
 
-    model_name, model_path, model_version = MODELS[choice]
+    model_name, model_path, model_version, sft_base = MODELS[choice]
     console.print()
 
     try:
-        loaded = load_model(model_path, model_name, model_version)
+        loaded = load_model(model_path, model_name, model_version, sft_base)
     except Exception as e:
         console.print(f"[{RED}]Failed to load model: {e}[/]")
         return
@@ -469,10 +483,10 @@ def main():
             show_banner()
             choice = show_model_menu(loaded.path)
             if choice in MODELS:
-                new_name, new_path, new_version = MODELS[choice]
+                new_name, new_path, new_version, new_sft_base = MODELS[choice]
                 if new_path != loaded.path:
                     console.print()
-                    loaded = load_model(new_path, new_name, new_version)
+                    loaded = load_model(new_path, new_name, new_version, new_sft_base)
                     console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
             console.print()
             continue