|
|
@@ -35,11 +35,13 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
# CONFIGURATION
|
|
|
# ═══════════════════════════════════════════════════════════════════════════════
|
|
|
|
|
|
+# Model configs: (name, path, version, sft_base)
|
|
|
+# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
|
|
|
MODELS = {
|
|
|
- "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3"),
|
|
|
- "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3"),
|
|
|
- "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3"),
|
|
|
- "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3"),
|
|
|
+ "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
|
|
|
+ "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
|
|
|
+ "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
|
|
|
+ "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
|
|
|
}
|
|
|
|
|
|
BASE_MODEL = "Qwen/Qwen3-0.6B"
|
|
|
@@ -210,13 +212,14 @@ def show_model_menu(current: Optional[str] = None) -> str:
|
|
|
console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
|
|
|
console.print(f"[{DIM}]{'─' * 50}[/]")
|
|
|
|
|
|
- for key, (name, path, version) in MODELS.items():
|
|
|
+ for key, (name, path, version, sft_base) in MODELS.items():
|
|
|
marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
|
|
|
- console.print(f" {marker} [{CYAN}]{key}[/] │ {name} [{DIM}]({version})[/]")
|
|
|
+ sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
|
|
|
+ console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
|
|
|
console.print(f" [{DIM}]{path}[/]")
|
|
|
|
|
|
console.print(f"[{DIM}]{'─' * 50}[/]")
|
|
|
- return prompt(" Enter choice (1-3): ", style=Style.from_dict({'': CYAN})).strip()
|
|
|
+ return prompt(" Enter choice (1-4): ", style=Style.from_dict({'': CYAN})).strip()
|
|
|
|
|
|
|
|
|
def render_expansion(expansion: str, scores: dict) -> Panel:
|
|
|
@@ -328,8 +331,12 @@ class LoadedModel:
|
|
|
version: str # "v1" or "v3" - determines prompt template
|
|
|
|
|
|
|
|
|
-def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
|
|
|
- """Load model with progress indicator."""
|
|
|
+def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
|
|
|
+ """Load model with progress indicator.
|
|
|
+
|
|
|
+ For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
|
|
|
+ merged into the base model, then the GRPO adapter is applied on top.
|
|
|
+ """
|
|
|
with Progress(
|
|
|
SpinnerColumn(spinner_name="dots", style=CYAN),
|
|
|
TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
|
|
|
@@ -342,14 +349,21 @@ def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
|
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
progress.update(task, description="base model")
|
|
|
- base = AutoModelForCausalLM.from_pretrained(
|
|
|
+ model = AutoModelForCausalLM.from_pretrained(
|
|
|
BASE_MODEL,
|
|
|
torch_dtype=torch.bfloat16,
|
|
|
device_map="auto",
|
|
|
)
|
|
|
|
|
|
+ # For GRPO models: load SFT first, merge, then apply GRPO
|
|
|
+ if sft_base:
|
|
|
+ progress.update(task, description="SFT adapter")
|
|
|
+ model = PeftModel.from_pretrained(model, sft_base)
|
|
|
+ progress.update(task, description="merging SFT")
|
|
|
+ model = model.merge_and_unload()
|
|
|
+
|
|
|
progress.update(task, description="adapter")
|
|
|
- model = PeftModel.from_pretrained(base, model_path)
|
|
|
+ model = PeftModel.from_pretrained(model, model_path)
|
|
|
model.eval()
|
|
|
|
|
|
return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version)
|
|
|
@@ -428,11 +442,11 @@ def main():
|
|
|
console.print(f"[{RED}]Invalid choice. Exiting.[/]")
|
|
|
return
|
|
|
|
|
|
- model_name, model_path, model_version = MODELS[choice]
|
|
|
+ model_name, model_path, model_version, sft_base = MODELS[choice]
|
|
|
console.print()
|
|
|
|
|
|
try:
|
|
|
- loaded = load_model(model_path, model_name, model_version)
|
|
|
+ loaded = load_model(model_path, model_name, model_version, sft_base)
|
|
|
except Exception as e:
|
|
|
console.print(f"[{RED}]Failed to load model: {e}[/]")
|
|
|
return
|
|
|
@@ -469,10 +483,10 @@ def main():
|
|
|
show_banner()
|
|
|
choice = show_model_menu(loaded.path)
|
|
|
if choice in MODELS:
|
|
|
- new_name, new_path, new_version = MODELS[choice]
|
|
|
+ new_name, new_path, new_version, new_sft_base = MODELS[choice]
|
|
|
if new_path != loaded.path:
|
|
|
console.print()
|
|
|
- loaded = load_model(new_path, new_name, new_version)
|
|
|
+ loaded = load_model(new_path, new_name, new_version, new_sft_base)
|
|
|
console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
|
|
|
console.print()
|
|
|
continue
|