Просмотр исходного кода

Fix TUI to load GRPO models with SFT base first

GRPO adapters were trained on merged SFT weights, so they need SFT
loaded and merged first before applying the GRPO adapter.

Updated MODELS config to include sft_base path for GRPO models,
and load_model() now handles the SFT -> merge -> GRPO flow.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 4 месяцев назад
Родитель
Сommit
2648512b7c
1 измененных файлов с 29 добавлено и 15 удалено
  1. 29 15
      finetune/tui.py

+ 29 - 15
finetune/tui.py

@@ -35,11 +35,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 # CONFIGURATION
 # CONFIGURATION
 # ═══════════════════════════════════════════════════════════════════════════════
 # ═══════════════════════════════════════════════════════════════════════════════
 
 
+# Model configs: (name, path, version, sft_base)
+# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
 MODELS = {
 MODELS = {
-    "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3"),
-    "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3"),
-    "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3"),
-    "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3"),
+    "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
+    "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
+    "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
+    "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
 }
 }
 
 
 BASE_MODEL = "Qwen/Qwen3-0.6B"
 BASE_MODEL = "Qwen/Qwen3-0.6B"
@@ -210,13 +212,14 @@ def show_model_menu(current: Optional[str] = None) -> str:
     console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
     console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
     console.print(f"[{DIM}]{'─' * 50}[/]")
     console.print(f"[{DIM}]{'─' * 50}[/]")
 
 
-    for key, (name, path, version) in MODELS.items():
+    for key, (name, path, version, sft_base) in MODELS.items():
         marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
         marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
-        console.print(f"  {marker} [{CYAN}]{key}[/] │ {name} [{DIM}]({version})[/]")
+        sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
+        console.print(f"  {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
         console.print(f"      [{DIM}]{path}[/]")
         console.print(f"      [{DIM}]{path}[/]")
 
 
     console.print(f"[{DIM}]{'─' * 50}[/]")
     console.print(f"[{DIM}]{'─' * 50}[/]")
-    return prompt("  Enter choice (1-3): ", style=Style.from_dict({'': CYAN})).strip()
+    return prompt("  Enter choice (1-4): ", style=Style.from_dict({'': CYAN})).strip()
 
 
 
 
 def render_expansion(expansion: str, scores: dict) -> Panel:
 def render_expansion(expansion: str, scores: dict) -> Panel:
@@ -328,8 +331,12 @@ class LoadedModel:
     version: str  # "v1" or "v3" - determines prompt template
     version: str  # "v1" or "v3" - determines prompt template
 
 
 
 
-def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
-    """Load model with progress indicator."""
+def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
+    """Load model with progress indicator.
+
+    For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
+    merged into the base model, then the GRPO adapter is applied on top.
+    """
     with Progress(
     with Progress(
         SpinnerColumn(spinner_name="dots", style=CYAN),
         SpinnerColumn(spinner_name="dots", style=CYAN),
         TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
         TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
@@ -342,14 +349,21 @@ def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
             tokenizer.pad_token = tokenizer.eos_token
             tokenizer.pad_token = tokenizer.eos_token
 
 
         progress.update(task, description="base model")
         progress.update(task, description="base model")
-        base = AutoModelForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             BASE_MODEL,
             BASE_MODEL,
             torch_dtype=torch.bfloat16,
             torch_dtype=torch.bfloat16,
             device_map="auto",
             device_map="auto",
         )
         )
 
 
+        # For GRPO models: load SFT first, merge, then apply GRPO
+        if sft_base:
+            progress.update(task, description="SFT adapter")
+            model = PeftModel.from_pretrained(model, sft_base)
+            progress.update(task, description="merging SFT")
+            model = model.merge_and_unload()
+
         progress.update(task, description="adapter")
         progress.update(task, description="adapter")
-        model = PeftModel.from_pretrained(base, model_path)
+        model = PeftModel.from_pretrained(model, model_path)
         model.eval()
         model.eval()
 
 
     return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version)
     return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version)
@@ -428,11 +442,11 @@ def main():
         console.print(f"[{RED}]Invalid choice. Exiting.[/]")
         console.print(f"[{RED}]Invalid choice. Exiting.[/]")
         return
         return
 
 
-    model_name, model_path, model_version = MODELS[choice]
+    model_name, model_path, model_version, sft_base = MODELS[choice]
     console.print()
     console.print()
 
 
     try:
     try:
-        loaded = load_model(model_path, model_name, model_version)
+        loaded = load_model(model_path, model_name, model_version, sft_base)
     except Exception as e:
     except Exception as e:
         console.print(f"[{RED}]Failed to load model: {e}[/]")
         console.print(f"[{RED}]Failed to load model: {e}[/]")
         return
         return
@@ -469,10 +483,10 @@ def main():
             show_banner()
             show_banner()
             choice = show_model_menu(loaded.path)
             choice = show_model_menu(loaded.path)
             if choice in MODELS:
             if choice in MODELS:
-                new_name, new_path, new_version = MODELS[choice]
+                new_name, new_path, new_version, new_sft_base = MODELS[choice]
                 if new_path != loaded.path:
                 if new_path != loaded.path:
                     console.print()
                     console.print()
-                    loaded = load_model(new_path, new_name, new_version)
+                    loaded = load_model(new_path, new_name, new_version, new_sft_base)
                     console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
                     console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
             console.print()
             console.print()
             continue
             continue