tui.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "rich>=13.0.0",
  6. # "transformers>=4.45.0",
  7. # "peft>=0.7.0",
  8. # "torch",
  9. # "prompt_toolkit>=3.0.0",
  10. # ]
  11. # ///
  12. """
  13. QMD Query Expansion Model Tester
  14. A cyberpunk-styled TUI for testing finetuned query expansion models.
  15. """
from collections import deque
from dataclasses import dataclass
from typing import Any, Optional

import torch
from peft import PeftModel
from prompt_toolkit import prompt
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.styles import Style
from rich.columns import Columns
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from rich.text import Text
from transformers import AutoModelForCausalLM, AutoTokenizer
  31. # ═══════════════════════════════════════════════════════════════════════════════
  32. # CONFIGURATION
  33. # ═══════════════════════════════════════════════════════════════════════════════
# Model registry, keyed by the choice string the user types at the menu.
# Each value is a tuple: (name, path, version, sft_base)
#   name     - label shown in the menu and in status messages
#   path     - adapter location handed to PeftModel.from_pretrained()
#   version  - prompt format selector; generate_expansion() uses the tokenizer
#              chat template when this is "v3" and PROMPT_TEMPLATE_V1 otherwise
#   sft_base - None, or the SFT adapter to load first
# sft_base is required for GRPO models - they must load SFT first, merge, then apply GRPO
MODELS = {
    "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3", None),
    "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
    "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3", None),
    "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3", "tobil/qmd-query-expansion-0.6B-v4"),
}
# Base checkpoint that load_model() loads the tokenizer and weights from
# before applying any adapter.
BASE_MODEL = "Qwen/Qwen3-0.6B"
# v1 used simple format (before proper chat template)
PROMPT_TEMPLATE_V1 = """Expand this search query:
{query}"""
# v3+ uses tokenizer.apply_chat_template() - see generate_expansion()
# Cyberpunk color palette (hex colors interpolated into rich style strings)
CYAN = "#00ffff"
MAGENTA = "#ff00ff"
PURPLE = "#bd93f9"
DIM = "#6272a4"
BG = "#1a0a2e"  # NOTE(review): not referenced in the visible part of this file
GREEN = "#50fa7b"
YELLOW = "#f1fa8c"
RED = "#ff5555"
# Single rich console shared by every function that prints to the terminal.
console = Console()
  57. # ═══════════════════════════════════════════════════════════════════════════════
  58. # SCORING (from evaluate_model.py)
  59. # ═══════════════════════════════════════════════════════════════════════════════
# Common English function words.
# NOTE(review): nothing in the visible part of this file reads this set;
# score_expansion() builds its own, different, local stopword set - confirm
# whether this constant is still needed.
STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or',
             'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  62. def parse_expansion(text: str) -> dict:
  63. lines = text.strip().split("\n")
  64. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  65. for line in lines:
  66. line = line.strip()
  67. if not line:
  68. continue
  69. if line.startswith("lex:"):
  70. result["lex"].append(line[4:].strip())
  71. elif line.startswith("vec:"):
  72. result["vec"].append(line[4:].strip())
  73. elif line.startswith("hyde:"):
  74. result["hyde"].append(line[5:].strip())
  75. else:
  76. result["invalid"].append(line)
  77. return result
  78. def score_expansion(query: str, expansion: str) -> dict:
  79. """Score an expansion based on format, diversity, hyde, quality."""
  80. text = expansion.strip()
  81. if not text:
  82. return {"total": 0, "percentage": 0, "rating": "Failed", "format": 0,
  83. "diversity": 0, "hyde": 0, "quality": 0, "parsed": {"lex": [], "vec": [], "hyde": [], "invalid": []}}
  84. parsed = parse_expansion(expansion)
  85. # Check for continuation - but only if NO valid lines were found
  86. has_valid = parsed["lex"] or parsed["vec"] or parsed["hyde"]
  87. if not has_valid:
  88. return {"total": 0, "percentage": 0, "rating": "Failed", "format": 0,
  89. "diversity": 0, "hyde": 0, "quality": 0, "is_continuation": True,
  90. "parsed": parsed}
  91. # Format score (0-30)
  92. format_score = 0
  93. if parsed["lex"]:
  94. format_score += 10
  95. if parsed["vec"]:
  96. format_score += 10
  97. if not parsed["invalid"]:
  98. format_score += 10
  99. else:
  100. format_score += max(0, 10 - len(parsed["invalid"]) * 5)
  101. # Diversity score (0-30)
  102. diversity_score = 0
  103. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  104. if types_present >= 2:
  105. diversity_score += 10
  106. total_exp = len(parsed["lex"]) + len(parsed["vec"])
  107. if total_exp >= 2:
  108. diversity_score += 5
  109. diversity_score += 10 # Base diversity points
  110. diversity_score += 5 # Non-echo points
  111. # Hyde score (0-20)
  112. hyde_score = 0
  113. if parsed["hyde"]:
  114. hyde_text = parsed["hyde"][0]
  115. hyde_score += 5 # Present
  116. hyde_len = len(hyde_text)
  117. if 50 <= hyde_len <= 200:
  118. hyde_score += 5
  119. elif hyde_len < 50:
  120. hyde_score += 2
  121. if "\n" not in hyde_text:
  122. hyde_score += 5
  123. hyde_score += 5 # No repetition (simplified)
  124. # Quality score (0-20)
  125. quality_score = 5 # Base relevance (reduced to make room for key term check)
  126. # Lex must preserve key terms from query
  127. stopwords = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of', 'and', 'or', 'with', 'my'}
  128. key_terms = set(query.lower().split()) - stopwords
  129. if parsed["lex"] and key_terms:
  130. lex_with_terms = sum(1 for l in parsed["lex"] if key_terms & set(l.lower().split()))
  131. if lex_with_terms == len(parsed["lex"]):
  132. quality_score += 5
  133. elif lex_with_terms > 0:
  134. quality_score += 2
  135. if parsed["lex"] and parsed["vec"]:
  136. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  137. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  138. if avg_lex <= avg_vec:
  139. quality_score += 5
  140. if parsed["vec"] and all(" " in v and len(v) > 15 for v in parsed["vec"]):
  141. quality_score += 5
  142. total = format_score + diversity_score + hyde_score + min(20, quality_score)
  143. max_possible = 100 if parsed["hyde"] else 80
  144. percentage = total / max_possible * 100
  145. if percentage >= 80:
  146. rating = "Excellent"
  147. elif percentage >= 60:
  148. rating = "Good"
  149. elif percentage >= 40:
  150. rating = "Acceptable"
  151. elif percentage >= 20:
  152. rating = "Poor"
  153. else:
  154. rating = "Failed"
  155. return {
  156. "format": format_score,
  157. "diversity": diversity_score,
  158. "hyde": hyde_score,
  159. "quality": quality_score,
  160. "total": total,
  161. "max_possible": max_possible,
  162. "percentage": percentage,
  163. "rating": rating,
  164. "parsed": parsed,
  165. }
  166. # ═══════════════════════════════════════════════════════════════════════════════
  167. # UI COMPONENTS
  168. # ═══════════════════════════════════════════════════════════════════════════════
# Title banner written in rich console markup; printed by show_banner().
BANNER = """
[bold cyan]╔═══════════════════════════════════════════════════════════════════════════╗[/]
[bold cyan]║[/] [bold magenta]░▀▄░░░░░░░░░░░▄▀░[/] [bold white]Q M D Q U E R Y E X P A N D E R[/] [bold magenta]░▀▄░░░░░░░░░░░▄▀░[/] [bold cyan]║[/]
[bold cyan]╚═══════════════════════════════════════════════════════════════════════════╝[/]
"""


def show_banner():
    """Print the title banner to the shared console."""
    console.print(BANNER)
  176. def show_model_menu(current: Optional[str] = None) -> str:
  177. """Display model selection menu."""
  178. console.print()
  179. console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
  180. console.print(f"[{DIM}]{'─' * 50}[/]")
  181. for key, (name, path, version, sft_base) in MODELS.items():
  182. marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
  183. sft_note = f" [{DIM}](+SFT)[/]" if sft_base else ""
  184. console.print(f" {marker} [{CYAN}]{key}[/] │ {name}{sft_note} [{DIM}]({version})[/]")
  185. console.print(f" [{DIM}]{path}[/]")
  186. console.print(f"[{DIM}]{'─' * 50}[/]")
  187. return prompt(" Enter choice (1-4): ", style=Style.from_dict({'': CYAN})).strip()
  188. def render_expansion(expansion: str, scores: dict) -> Panel:
  189. """Render the expansion output with syntax highlighting."""
  190. parsed = scores.get("parsed", parse_expansion(expansion))
  191. content = Text()
  192. # Lex lines
  193. for lex in parsed["lex"]:
  194. content.append("lex: ", style=f"bold {CYAN}")
  195. content.append(f"{lex}\n", style="white")
  196. # Vec lines
  197. for vec in parsed["vec"]:
  198. content.append("vec: ", style=f"bold {MAGENTA}")
  199. content.append(f"{vec}\n", style="white")
  200. # Hyde lines
  201. for hyde in parsed["hyde"]:
  202. content.append("hyde: ", style=f"bold {PURPLE}")
  203. content.append(f"{hyde}\n", style=f"italic {DIM}")
  204. # Invalid lines
  205. for inv in parsed["invalid"]:
  206. content.append(f"[invalid] {inv}\n", style=f"dim {RED}")
  207. return Panel(
  208. content,
  209. title=f"[bold {CYAN}]◈ EXPANSION[/]",
  210. border_style=CYAN,
  211. padding=(0, 1),
  212. )
  213. def render_scores(scores: dict) -> Panel:
  214. """Render score breakdown as a compact table."""
  215. rating = scores["rating"]
  216. rating_color = {
  217. "Excellent": GREEN,
  218. "Good": CYAN,
  219. "Acceptable": YELLOW,
  220. "Poor": RED,
  221. "Failed": RED,
  222. }.get(rating, DIM)
  223. table = Table(show_header=False, box=None, padding=(0, 1))
  224. table.add_column("metric", style=DIM)
  225. table.add_column("score", justify="right")
  226. table.add_column("bar", width=12)
  227. def score_bar(val: int, max_val: int) -> str:
  228. filled = int(val / max_val * 10)
  229. return f"[{CYAN}]{'█' * filled}[/][{DIM}]{'░' * (10 - filled)}[/]"
  230. table.add_row("Format", f"[white]{scores['format']}[/]/30", score_bar(scores['format'], 30))
  231. table.add_row("Diversity", f"[white]{scores['diversity']}[/]/30", score_bar(scores['diversity'], 30))
  232. table.add_row("Hyde", f"[white]{scores['hyde']}[/]/20", score_bar(scores['hyde'], 20))
  233. table.add_row("Quality", f"[white]{scores['quality']}[/]/20", score_bar(scores['quality'], 20))
  234. table.add_row("", "", "")
  235. table.add_row(
  236. f"[bold]TOTAL[/]",
  237. f"[bold white]{scores['total']}[/]/{scores.get('max_possible', 80)}",
  238. f"[bold {rating_color}]{rating}[/]"
  239. )
  240. return Panel(
  241. table,
  242. title=f"[bold {MAGENTA}]◈ SCORES[/]",
  243. border_style=MAGENTA,
  244. padding=(0, 1),
  245. )
  246. def render_history(history: deque) -> Panel:
  247. """Render recent query history."""
  248. content = Text()
  249. for i, (query, rating) in enumerate(history):
  250. rating_color = {
  251. "Excellent": GREEN, "Good": CYAN, "Acceptable": YELLOW,
  252. "Poor": RED, "Failed": RED,
  253. }.get(rating, DIM)
  254. content.append(f" [{DIM}]{i+1}.[/] {query[:40]}")
  255. if len(query) > 40:
  256. content.append(f"[{DIM}]...[/]")
  257. content.append(f" [{rating_color}]●[/]\n")
  258. if not history:
  259. content.append(f" [{DIM}]No queries yet[/]")
  260. return Panel(
  261. content,
  262. title=f"[bold {PURPLE}]◈ HISTORY[/]",
  263. border_style=PURPLE,
  264. padding=(0, 1),
  265. )
  266. # ═══════════════════════════════════════════════════════════════════════════════
  267. # MODEL LOADING
  268. # ═══════════════════════════════════════════════════════════════════════════════
  269. @dataclass
  270. class LoadedModel:
  271. model: any
  272. tokenizer: any
  273. name: str
  274. path: str
  275. version: str # "v1" or "v3" - determines prompt template
  276. def load_model(model_path: str, model_name: str, version: str, sft_base: Optional[str] = None) -> LoadedModel:
  277. """Load model with progress indicator.
  278. For GRPO models, sft_base must be provided - the SFT adapter is loaded first,
  279. merged into the base model, then the GRPO adapter is applied on top.
  280. """
  281. with Progress(
  282. SpinnerColumn(spinner_name="dots", style=CYAN),
  283. TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
  284. console=console,
  285. transient=True,
  286. ) as progress:
  287. task = progress.add_task("tokenizer", total=None)
  288. tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  289. if tokenizer.pad_token is None:
  290. tokenizer.pad_token = tokenizer.eos_token
  291. progress.update(task, description="base model")
  292. model = AutoModelForCausalLM.from_pretrained(
  293. BASE_MODEL,
  294. torch_dtype=torch.bfloat16,
  295. device_map="auto",
  296. )
  297. # For GRPO models: load SFT first, merge, then apply GRPO
  298. if sft_base:
  299. progress.update(task, description="SFT adapter")
  300. model = PeftModel.from_pretrained(model, sft_base)
  301. progress.update(task, description="merging SFT")
  302. model = model.merge_and_unload()
  303. progress.update(task, description="adapter")
  304. model = PeftModel.from_pretrained(model, model_path)
  305. model.eval()
  306. return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version)
  307. DEBUG = False # Set to True for debug output
  308. def generate_expansion(loaded: LoadedModel, query: str) -> str:
  309. """Generate expansion using proper Qwen3 chat template."""
  310. if loaded.version == "v3":
  311. # Use tokenizer's chat template with /no_think to disable thinking mode
  312. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  313. prompt_text = loaded.tokenizer.apply_chat_template(
  314. messages,
  315. tokenize=False,
  316. add_generation_prompt=True
  317. )
  318. else:
  319. # v1 used simple format
  320. prompt_text = PROMPT_TEMPLATE_V1.format(query=query)
  321. if DEBUG:
  322. console.print(f"[{DIM}]─── DEBUG: Prompt ───[/]")
  323. console.print(f"[{DIM}]{repr(prompt_text)}[/]")
  324. inputs = loaded.tokenizer(prompt_text, return_tensors="pt").to(loaded.model.device)
  325. with torch.no_grad():
  326. outputs = loaded.model.generate(
  327. **inputs,
  328. max_new_tokens=200,
  329. temperature=0.7,
  330. do_sample=True,
  331. pad_token_id=loaded.tokenizer.pad_token_id,
  332. eos_token_id=loaded.tokenizer.eos_token_id,
  333. )
  334. full_output = loaded.tokenizer.decode(outputs[0], skip_special_tokens=True)
  335. if DEBUG:
  336. console.print(f"[{DIM}]─── DEBUG: Full output ───[/]")
  337. console.print(f"[{DIM}]{repr(full_output[:500])}[/]")
  338. # Extract assistant response (skip_special_tokens leaves "user\n...\nassistant\n...")
  339. if "\nassistant\n" in full_output:
  340. expansion = full_output.split("\nassistant\n")[-1].strip()
  341. elif "assistant\n" in full_output:
  342. expansion = full_output.split("assistant\n")[-1].strip()
  343. else:
  344. expansion = full_output[len(prompt_text):].strip()
  345. # Remove any <think> tags that might remain
  346. if expansion.startswith("<think>"):
  347. think_end = expansion.find("</think>")
  348. if think_end != -1:
  349. expansion = expansion[think_end + 8:].strip()
  350. if DEBUG:
  351. console.print(f"[{DIM}]─── DEBUG: Expansion ───[/]")
  352. console.print(f"[{DIM}]{repr(expansion[:300])}[/]")
  353. return expansion
  354. # ═══════════════════════════════════════════════════════════════════════════════
  355. # MAIN LOOP
  356. # ═══════════════════════════════════════════════════════════════════════════════
  357. def main():
  358. console.clear()
  359. show_banner()
  360. # Model selection
  361. choice = show_model_menu()
  362. if choice not in MODELS:
  363. console.print(f"[{RED}]Invalid choice. Exiting.[/]")
  364. return
  365. model_name, model_path, model_version, sft_base = MODELS[choice]
  366. console.print()
  367. try:
  368. loaded = load_model(model_path, model_name, model_version, sft_base)
  369. except Exception as e:
  370. console.print(f"[{RED}]Failed to load model: {e}[/]")
  371. return
  372. console.print(f"[{GREEN}]✓ Model loaded: {model_name}[/]")
  373. console.print()
  374. # Query history
  375. history: deque = deque(maxlen=5)
  376. input_history = InMemoryHistory()
  377. # Main loop
  378. console.print(f"[{DIM}]Enter queries to expand. Type 'quit' to exit, 'model' to switch models.[/]")
  379. console.print()
  380. while True:
  381. try:
  382. query = prompt(
  383. f"[{CYAN}]❯[/] ",
  384. history=input_history,
  385. style=Style.from_dict({'': 'ansicyan'}),
  386. ).strip()
  387. except (KeyboardInterrupt, EOFError):
  388. break
  389. if not query:
  390. continue
  391. if query.lower() == 'quit':
  392. break
  393. if query.lower() == 'model':
  394. console.clear()
  395. show_banner()
  396. choice = show_model_menu(loaded.path)
  397. if choice in MODELS:
  398. new_name, new_path, new_version, new_sft_base = MODELS[choice]
  399. if new_path != loaded.path:
  400. console.print()
  401. loaded = load_model(new_path, new_name, new_version, new_sft_base)
  402. console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
  403. console.print()
  404. continue
  405. if query.lower() == 'history':
  406. console.print(render_history(history))
  407. continue
  408. # Generate expansion
  409. with Progress(
  410. SpinnerColumn(spinner_name="dots", style=MAGENTA),
  411. TextColumn(f"[{MAGENTA}]Expanding...[/]"),
  412. console=console,
  413. transient=True,
  414. ) as progress:
  415. progress.add_task("gen", total=None)
  416. expansion = generate_expansion(loaded, query)
  417. # Score it
  418. scores = score_expansion(query, expansion)
  419. # Add to history
  420. history.appendleft((query, scores["rating"]))
  421. # Display results
  422. console.print()
  423. console.print(f"[{DIM}]Query: [/][bold white]{query}[/]")
  424. console.print()
  425. # Side-by-side layout
  426. console.print(Columns([
  427. render_expansion(expansion, scores),
  428. render_scores(scores),
  429. ], equal=True, expand=True))
  430. console.print()
  431. console.print(f"\n[{CYAN}]◆ Goodbye![/]\n")
  432. if __name__ == "__main__":
  433. main()