tui.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "rich>=13.0.0",
  6. # "transformers>=4.45.0",
  7. # "peft>=0.7.0",
  8. # "torch",
  9. # "prompt_toolkit>=3.0.0",
  10. # "huggingface_hub>=0.20.0",
  11. # ]
  12. # ///
  13. """
  14. QMD Query Expansion Model Tester
  15. A cyberpunk-styled TUI for testing finetuned query expansion models.
  16. """
import re
from collections import deque
from dataclasses import dataclass
from typing import Any, Optional

import torch
from huggingface_hub import HfApi
from peft import PeftModel
from prompt_toolkit import prompt
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.styles import Style
from rich.columns import Columns
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from rich.text import Text
from transformers import AutoModelForCausalLM, AutoTokenizer
  34. # ═══════════════════════════════════════════════════════════════════════════════
  35. # CONFIGURATION
  36. # ═══════════════════════════════════════════════════════════════════════════════
  37. # Base models by size
  38. BASE_MODELS = {
  39. "0.6B": "Qwen/Qwen3-0.6B",
  40. "1.7B": "Qwen/Qwen3-1.7B",
  41. "4B": "Qwen/Qwen3-4B",
  42. }
  43. def get_model_size(model_id: str) -> str:
  44. """Extract model size from model ID (e.g., '0.6B', '1.7B', '4B')."""
  45. match = re.search(r'(\d+\.?\d*B)', model_id)
  46. return match.group(1) if match else "0.6B"
  47. def fetch_available_models() -> dict:
  48. """Dynamically fetch available qmd-query-expansion models from Hub."""
  49. api = HfApi()
  50. models = {}
  51. idx = 1
  52. try:
  53. # Search for all qmd-query-expansion models
  54. hub_models = list(api.list_models(author="tobil", search="qmd-query-expansion"))
  55. # Group by size and type (SFT vs GRPO)
  56. sft_models = []
  57. grpo_models = []
  58. for m in hub_models:
  59. model_id = m.id
  60. # Skip GGUF repos
  61. if "gguf" in model_id.lower():
  62. continue
  63. if "grpo" in model_id.lower():
  64. grpo_models.append(model_id)
  65. elif "sft" in model_id.lower() or not any(x in model_id.lower() for x in ["grpo", "gguf"]):
  66. sft_models.append(model_id)
  67. # Sort by size (0.6B, 1.7B, 4B)
  68. def size_sort_key(m):
  69. size = get_model_size(m)
  70. return {"0.6B": 0, "1.7B": 1, "4B": 2}.get(size, 3)
  71. sft_models.sort(key=size_sort_key)
  72. grpo_models.sort(key=size_sort_key)
  73. # Add SFT models
  74. for model_id in sft_models:
  75. size = get_model_size(model_id)
  76. models[str(idx)] = (f"SFT {size}", model_id, "v3", None, size)
  77. idx += 1
  78. # Add GRPO models (need to find matching SFT base)
  79. for model_id in grpo_models:
  80. size = get_model_size(model_id)
  81. # Find matching SFT model
  82. sft_base = None
  83. for sft in sft_models:
  84. if get_model_size(sft) == size:
  85. sft_base = sft
  86. break
  87. models[str(idx)] = (f"GRPO {size}", model_id, "v3", sft_base, size)
  88. idx += 1
  89. except Exception as e:
  90. # Fallback to default models if Hub fetch fails
  91. models = {
  92. "1": ("SFT 0.6B", "tobil/qmd-query-expansion-0.6B-v4", "v3", None, "0.6B"),
  93. "2": ("GRPO 0.6B", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4", "0.6B"),
  94. }
  95. return models
  96. # Will be populated on startup
  97. MODELS = {}
  98. # v1 used simple format (before proper chat template)
  99. PROMPT_TEMPLATE_V1 = """Expand this search query:
  100. {query}"""
  101. # v3+ uses tokenizer.apply_chat_template() - see generate_expansion()
  102. # Cyberpunk color palette
  103. CYAN = "#00ffff"
  104. MAGENTA = "#ff00ff"
  105. PURPLE = "#bd93f9"
  106. DIM = "#6272a4"
  107. BG = "#1a0a2e"
  108. GREEN = "#50fa7b"
  109. YELLOW = "#f1fa8c"
  110. RED = "#ff5555"
# Shared Rich console used for all output in this file.
console = Console()
  112. # ═══════════════════════════════════════════════════════════════════════════════
  113. # SCORING (from evaluate_model.py)
  114. # ═══════════════════════════════════════════════════════════════════════════════
  115. STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or',
  116. 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  117. def parse_expansion(text: str) -> dict:
  118. lines = text.strip().split("\n")
  119. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  120. for line in lines:
  121. line = line.strip()
  122. if not line:
  123. continue
  124. if line.startswith("lex:"):
  125. result["lex"].append(line[4:].strip())
  126. elif line.startswith("vec:"):
  127. result["vec"].append(line[4:].strip())
  128. elif line.startswith("hyde:"):
  129. result["hyde"].append(line[5:].strip())
  130. else:
  131. result["invalid"].append(line)
  132. return result
  133. def score_expansion(query: str, expansion: str) -> dict:
  134. """Score an expansion based on format, diversity, hyde, quality."""
  135. text = expansion.strip()
  136. if not text:
  137. return {"total": 0, "percentage": 0, "rating": "Failed", "format": 0,
  138. "diversity": 0, "hyde": 0, "quality": 0, "parsed": {"lex": [], "vec": [], "hyde": [], "invalid": []}}
  139. parsed = parse_expansion(expansion)
  140. # Check for continuation - but only if NO valid lines were found
  141. has_valid = parsed["lex"] or parsed["vec"] or parsed["hyde"]
  142. if not has_valid:
  143. return {"total": 0, "percentage": 0, "rating": "Failed", "format": 0,
  144. "diversity": 0, "hyde": 0, "quality": 0, "is_continuation": True,
  145. "parsed": parsed}
  146. # Format score (0-30)
  147. format_score = 0
  148. if parsed["lex"]:
  149. format_score += 10
  150. if parsed["vec"]:
  151. format_score += 10
  152. if not parsed["invalid"]:
  153. format_score += 10
  154. else:
  155. format_score += max(0, 10 - len(parsed["invalid"]) * 5)
  156. # Diversity score (0-30)
  157. diversity_score = 0
  158. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  159. if types_present >= 2:
  160. diversity_score += 10
  161. total_exp = len(parsed["lex"]) + len(parsed["vec"])
  162. if total_exp >= 2:
  163. diversity_score += 5
  164. diversity_score += 10 # Base diversity points
  165. diversity_score += 5 # Non-echo points
  166. # Hyde score (0-20)
  167. hyde_score = 0
  168. if parsed["hyde"]:
  169. hyde_text = parsed["hyde"][0]
  170. hyde_score += 5 # Present
  171. hyde_len = len(hyde_text)
  172. if 50 <= hyde_len <= 200:
  173. hyde_score += 5
  174. elif hyde_len < 50:
  175. hyde_score += 2
  176. if "\n" not in hyde_text:
  177. hyde_score += 5
  178. hyde_score += 5 # No repetition (simplified)
  179. # Quality score (0-20)
  180. quality_score = 5 # Base relevance (reduced to make room for key term check)
  181. # Lex must preserve key terms from query
  182. stopwords = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of', 'and', 'or', 'with', 'my'}
  183. key_terms = set(query.lower().split()) - stopwords
  184. if parsed["lex"] and key_terms:
  185. lex_with_terms = sum(1 for l in parsed["lex"] if key_terms & set(l.lower().split()))
  186. if lex_with_terms == len(parsed["lex"]):
  187. quality_score += 5
  188. elif lex_with_terms > 0:
  189. quality_score += 2
  190. if parsed["lex"] and parsed["vec"]:
  191. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  192. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  193. if avg_lex <= avg_vec:
  194. quality_score += 5
  195. if parsed["vec"] and all(" " in v and len(v) > 15 for v in parsed["vec"]):
  196. quality_score += 5
  197. total = format_score + diversity_score + hyde_score + min(20, quality_score)
  198. max_possible = 100 if parsed["hyde"] else 80
  199. percentage = total / max_possible * 100
  200. if percentage >= 80:
  201. rating = "Excellent"
  202. elif percentage >= 60:
  203. rating = "Good"
  204. elif percentage >= 40:
  205. rating = "Acceptable"
  206. elif percentage >= 20:
  207. rating = "Poor"
  208. else:
  209. rating = "Failed"
  210. return {
  211. "format": format_score,
  212. "diversity": diversity_score,
  213. "hyde": hyde_score,
  214. "quality": quality_score,
  215. "total": total,
  216. "max_possible": max_possible,
  217. "percentage": percentage,
  218. "rating": rating,
  219. "parsed": parsed,
  220. }
  221. # ═══════════════════════════════════════════════════════════════════════════════
  222. # UI COMPONENTS
  223. # ═══════════════════════════════════════════════════════════════════════════════
# Title banner in Rich markup; printed by show_banner().
BANNER = """
[bold cyan]╔═══════════════════════════════════════════════════════════════════════════╗[/]
[bold cyan]║[/] [bold magenta]░▀▄░░░░░░░░░░░▄▀░[/] [bold white]Q M D Q U E R Y E X P A N D E R[/] [bold magenta]░▀▄░░░░░░░░░░░▄▀░[/] [bold cyan]║[/]
[bold cyan]╚═══════════════════════════════════════════════════════════════════════════╝[/]
"""
def show_banner() -> None:
    """Print the title banner to the shared console."""
    console.print(BANNER)
  231. def show_model_menu(current: Optional[str] = None) -> str:
  232. """Display model selection menu."""
  233. console.print()
  234. console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
  235. console.print(f"[{DIM}]{'─' * 50}[/]")
  236. for key, model_info in MODELS.items():
  237. name, path, version, sft_base = model_info[:4]
  238. marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
  239. sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
  240. console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
  241. console.print(f" [{DIM}]{path}[/]")
  242. console.print(f"[{DIM}]{'─' * 50}[/]")
  243. return prompt(" Enter choice (1-4): ", style=Style.from_dict({'': CYAN})).strip()
  244. def render_expansion(expansion: str, scores: dict) -> Panel:
  245. """Render the expansion output with syntax highlighting."""
  246. parsed = scores.get("parsed", parse_expansion(expansion))
  247. content = Text()
  248. # Lex lines
  249. for lex in parsed["lex"]:
  250. content.append("lex: ", style=f"bold {CYAN}")
  251. content.append(f"{lex}\n", style="white")
  252. # Vec lines
  253. for vec in parsed["vec"]:
  254. content.append("vec: ", style=f"bold {MAGENTA}")
  255. content.append(f"{vec}\n", style="white")
  256. # Hyde lines
  257. for hyde in parsed["hyde"]:
  258. content.append("hyde: ", style=f"bold {PURPLE}")
  259. content.append(f"{hyde}\n", style=f"italic {DIM}")
  260. # Invalid lines
  261. for inv in parsed["invalid"]:
  262. content.append(f"[invalid] {inv}\n", style=f"dim {RED}")
  263. return Panel(
  264. content,
  265. title=f"[bold {CYAN}]◈ EXPANSION[/]",
  266. border_style=CYAN,
  267. padding=(0, 1),
  268. )
  269. def render_scores(scores: dict) -> Panel:
  270. """Render score breakdown as a compact table."""
  271. rating = scores["rating"]
  272. rating_color = {
  273. "Excellent": GREEN,
  274. "Good": CYAN,
  275. "Acceptable": YELLOW,
  276. "Poor": RED,
  277. "Failed": RED,
  278. }.get(rating, DIM)
  279. table = Table(show_header=False, box=None, padding=(0, 1))
  280. table.add_column("metric", style=DIM)
  281. table.add_column("score", justify="right")
  282. table.add_column("bar", width=12)
  283. def score_bar(val: int, max_val: int) -> str:
  284. filled = int(val / max_val * 10)
  285. return f"[{CYAN}]{'█' * filled}[/][{DIM}]{'░' * (10 - filled)}[/]"
  286. table.add_row("Format", f"[white]{scores['format']}[/]/30", score_bar(scores['format'], 30))
  287. table.add_row("Diversity", f"[white]{scores['diversity']}[/]/30", score_bar(scores['diversity'], 30))
  288. table.add_row("Hyde", f"[white]{scores['hyde']}[/]/20", score_bar(scores['hyde'], 20))
  289. table.add_row("Quality", f"[white]{scores['quality']}[/]/20", score_bar(scores['quality'], 20))
  290. table.add_row("", "", "")
  291. table.add_row(
  292. f"[bold]TOTAL[/]",
  293. f"[bold white]{scores['total']}[/]/{scores.get('max_possible', 80)}",
  294. f"[bold {rating_color}]{rating}[/]"
  295. )
  296. return Panel(
  297. table,
  298. title=f"[bold {MAGENTA}]◈ SCORES[/]",
  299. border_style=MAGENTA,
  300. padding=(0, 1),
  301. )
  302. def render_history(history: deque) -> Panel:
  303. """Render recent query history."""
  304. content = Text()
  305. for i, (query, rating) in enumerate(history):
  306. rating_color = {
  307. "Excellent": GREEN, "Good": CYAN, "Acceptable": YELLOW,
  308. "Poor": RED, "Failed": RED,
  309. }.get(rating, DIM)
  310. content.append(f" [{DIM}]{i+1}.[/] {query[:40]}")
  311. if len(query) > 40:
  312. content.append(f"[{DIM}]...[/]")
  313. content.append(f" [{rating_color}]●[/]\n")
  314. if not history:
  315. content.append(f" [{DIM}]No queries yet[/]")
  316. return Panel(
  317. content,
  318. title=f"[bold {PURPLE}]◈ HISTORY[/]",
  319. border_style=PURPLE,
  320. padding=(0, 1),
  321. )
  322. # ═══════════════════════════════════════════════════════════════════════════════
  323. # MODEL LOADING
  324. # ═══════════════════════════════════════════════════════════════════════════════
  325. @dataclass
  326. class LoadedModel:
  327. model: any
  328. tokenizer: any
  329. name: str
  330. path: str
  331. version: str # "v1" or "v3" - determines prompt template
  332. def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None, size: str = "0.6B") -> LoadedModel:
  333. """Load model with progress indicator.
  334. For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
  335. merged into the base model, then the GRPO adapter is applied on top.
  336. """
  337. base_model = BASE_MODELS.get(size, BASE_MODELS["0.6B"])
  338. with Progress(
  339. SpinnerColumn(spinner_name="dots", style=CYAN),
  340. TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
  341. console=console,
  342. transient=True,
  343. ) as progress:
  344. task = progress.add_task("tokenizer", total=None)
  345. tokenizer = AutoTokenizer.from_pretrained(base_model)
  346. if tokenizer.pad_token is None:
  347. tokenizer.pad_token = tokenizer.eos_token
  348. progress.update(task, description=f"base model ({size})")
  349. model = AutoModelForCausalLM.from_pretrained(
  350. base_model,
  351. torch_dtype=torch.bfloat16,
  352. device_map="auto",
  353. )
  354. # For GRPO models: load SFT first, merge, then apply GRPO
  355. if sft_base:
  356. progress.update(task, description="SFT adapter")
  357. model = PeftModel.from_pretrained(model, sft_base)
  358. progress.update(task, description="merging SFT")
  359. model = model.merge_and_unload()
  360. progress.update(task, description="adapter")
  361. model = PeftModel.from_pretrained(model, model_path)
  362. model.eval()
  363. return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version)
# Set to True to print the prompt, the raw model output and the extracted
# expansion in generate_expansion().
DEBUG = False
  365. def generate_expansion(loaded: LoadedModel, query: str) -> str:
  366. """Generate expansion using proper Qwen3 chat template."""
  367. if loaded.version == "v3":
  368. # Use tokenizer's chat template with /no_think to disable thinking mode
  369. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  370. prompt_text = loaded.tokenizer.apply_chat_template(
  371. messages,
  372. tokenize=False,
  373. add_generation_prompt=True
  374. )
  375. else:
  376. # v1 used simple format
  377. prompt_text = PROMPT_TEMPLATE_V1.format(query=query)
  378. if DEBUG:
  379. console.print(f"[{DIM}]─── DEBUG: Prompt ───[/]")
  380. console.print(f"[{DIM}]{repr(prompt_text)}[/]")
  381. inputs = loaded.tokenizer(prompt_text, return_tensors="pt").to(loaded.model.device)
  382. with torch.no_grad():
  383. outputs = loaded.model.generate(
  384. **inputs,
  385. max_new_tokens=200,
  386. temperature=0.7,
  387. do_sample=True,
  388. pad_token_id=loaded.tokenizer.pad_token_id,
  389. eos_token_id=loaded.tokenizer.eos_token_id,
  390. )
  391. full_output = loaded.tokenizer.decode(outputs[0], skip_special_tokens=True)
  392. if DEBUG:
  393. console.print(f"[{DIM}]─── DEBUG: Full output ───[/]")
  394. console.print(f"[{DIM}]{repr(full_output[:500])}[/]")
  395. # Extract assistant response (skip_special_tokens leaves "user\n...\nassistant\n...")
  396. if "\nassistant\n" in full_output:
  397. expansion = full_output.split("\nassistant\n")[-1].strip()
  398. elif "assistant\n" in full_output:
  399. expansion = full_output.split("assistant\n")[-1].strip()
  400. else:
  401. expansion = full_output[len(prompt_text):].strip()
  402. # Remove any <think> tags that might remain
  403. if expansion.startswith("<think>"):
  404. think_end = expansion.find("</think>")
  405. if think_end != -1:
  406. expansion = expansion[think_end + 8:].strip()
  407. if DEBUG:
  408. console.print(f"[{DIM}]─── DEBUG: Expansion ───[/]")
  409. console.print(f"[{DIM}]{repr(expansion[:300])}[/]")
  410. return expansion
  411. # ═══════════════════════════════════════════════════════════════════════════════
  412. # MAIN LOOP
  413. # ═══════════════════════════════════════════════════════════════════════════════
  414. def main():
  415. global MODELS
  416. console.clear()
  417. show_banner()
  418. # Fetch available models from Hub
  419. console.print(f"[{DIM}]Fetching available models...[/]")
  420. MODELS = fetch_available_models()
  421. if not MODELS:
  422. console.print(f"[{RED}]No models found. Exiting.[/]")
  423. return
  424. # Model selection
  425. choice = show_model_menu()
  426. if choice not in MODELS:
  427. console.print(f"[{RED}]Invalid choice. Exiting.[/]")
  428. return
  429. model_info = MODELS[choice]
  430. model_name, model_path, model_version, sft_base = model_info[:4]
  431. model_size = model_info[4] if len(model_info) > 4 else get_model_size(model_path)
  432. console.print()
  433. try:
  434. loaded = load_model(model_path, model_name, model_version, sft_base, model_size)
  435. except Exception as e:
  436. console.print(f"[{RED}]Failed to load model: {e}[/]")
  437. return
  438. console.print(f"[{GREEN}]✓ Model loaded: {model_name}[/]")
  439. console.print()
  440. # Query history
  441. history: deque = deque(maxlen=5)
  442. input_history = InMemoryHistory()
  443. # Main loop
  444. console.print(f"[{DIM}]Enter queries to expand. Type 'quit' to exit, 'model' to switch models.[/]")
  445. console.print()
  446. while True:
  447. try:
  448. query = prompt(
  449. f"[{CYAN}]❯[/] ",
  450. history=input_history,
  451. style=Style.from_dict({'': 'ansicyan'}),
  452. ).strip()
  453. except (KeyboardInterrupt, EOFError):
  454. break
  455. if not query:
  456. continue
  457. if query.lower() == 'quit':
  458. break
  459. if query.lower() == 'model':
  460. console.clear()
  461. show_banner()
  462. choice = show_model_menu(loaded.path)
  463. if choice in MODELS:
  464. new_info = MODELS[choice]
  465. new_name, new_path, new_version, new_sft_base = new_info[:4]
  466. new_size = new_info[4] if len(new_info) > 4 else get_model_size(new_path)
  467. if new_path != loaded.path:
  468. console.print()
  469. loaded = load_model(new_path, new_name, new_version, new_sft_base, new_size)
  470. console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
  471. console.print()
  472. continue
  473. if query.lower() == 'history':
  474. console.print(render_history(history))
  475. continue
  476. # Generate expansion
  477. with Progress(
  478. SpinnerColumn(spinner_name="dots", style=MAGENTA),
  479. TextColumn(f"[{MAGENTA}]Expanding...[/]"),
  480. console=console,
  481. transient=True,
  482. ) as progress:
  483. progress.add_task("gen", total=None)
  484. expansion = generate_expansion(loaded, query)
  485. # Score it
  486. scores = score_expansion(query, expansion)
  487. # Add to history
  488. history.appendleft((query, scores["rating"]))
  489. # Display results
  490. console.print()
  491. console.print(f"[{DIM}]Query: [/][bold white]{query}[/]")
  492. console.print()
  493. # Side-by-side layout
  494. console.print(Columns([
  495. render_expansion(expansion, scores),
  496. render_scores(scores),
  497. ], equal=True, expand=True))
  498. console.print()
  499. console.print(f"\n[{CYAN}]◆ Goodbye![/]\n")
# Run the TUI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()