tui.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "rich>=13.0.0",
  6. # "transformers>=4.45.0",
  7. # "peft>=0.7.0",
  8. # "torch",
  9. # "prompt_toolkit>=3.0.0",
  10. # ]
  11. # ///
  12. """
  13. QMD Query Expansion Model Tester
  14. A cyberpunk-styled TUI for testing finetuned query expansion models.
  15. """
from collections import deque
from dataclasses import dataclass
from typing import Any, Optional

import torch
from peft import PeftModel
from prompt_toolkit import prompt
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.styles import Style
from rich.columns import Columns
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
from rich.text import Text
from transformers import AutoModelForCausalLM, AutoTokenizer
  31. # ═══════════════════════════════════════════════════════════════════════════════
  32. # CONFIGURATION
  33. # ═══════════════════════════════════════════════════════════════════════════════
  34. MODELS = {
  35. "1": ("SFT (Hub)", "tobil/qmd-query-expansion-0.6B-v4", "v3"),
  36. "2": ("GRPO (Hub)", "tobil/qmd-query-expansion-0.6B-v4-grpo", "v3"),
  37. "3": ("SFT (local)", "./qmd-query-expansion-0.6B-v4", "v3"),
  38. "4": ("GRPO (local)", "./qmd-query-expansion-0.6B-v4-grpo", "v3"),
  39. }
  40. BASE_MODEL = "Qwen/Qwen3-0.6B"
  41. # v1 used simple format (before proper chat template)
  42. PROMPT_TEMPLATE_V1 = """Expand this search query:
  43. {query}"""
  44. # v3+ uses tokenizer.apply_chat_template() - see generate_expansion()
  45. # Cyberpunk color palette
  46. CYAN = "#00ffff"
  47. MAGENTA = "#ff00ff"
  48. PURPLE = "#bd93f9"
  49. DIM = "#6272a4"
  50. BG = "#1a0a2e"
  51. GREEN = "#50fa7b"
  52. YELLOW = "#f1fa8c"
  53. RED = "#ff5555"
# Shared rich Console used by the rendering helpers and status messages below.
console = Console()
  55. # ═══════════════════════════════════════════════════════════════════════════════
  56. # SCORING (from evaluate_model.py)
  57. # ═══════════════════════════════════════════════════════════════════════════════
  58. STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or',
  59. 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  60. def parse_expansion(text: str) -> dict:
  61. lines = text.strip().split("\n")
  62. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  63. for line in lines:
  64. line = line.strip()
  65. if not line:
  66. continue
  67. if line.startswith("lex:"):
  68. result["lex"].append(line[4:].strip())
  69. elif line.startswith("vec:"):
  70. result["vec"].append(line[4:].strip())
  71. elif line.startswith("hyde:"):
  72. result["hyde"].append(line[5:].strip())
  73. else:
  74. result["invalid"].append(line)
  75. return result
  76. def score_expansion(query: str, expansion: str) -> dict:
  77. """Score an expansion based on format, diversity, hyde, quality."""
  78. text = expansion.strip()
  79. if not text:
  80. return {"total": 0, "percentage": 0, "rating": "Failed", "format": 0,
  81. "diversity": 0, "hyde": 0, "quality": 0, "parsed": {"lex": [], "vec": [], "hyde": [], "invalid": []}}
  82. parsed = parse_expansion(expansion)
  83. # Check for continuation - but only if NO valid lines were found
  84. has_valid = parsed["lex"] or parsed["vec"] or parsed["hyde"]
  85. if not has_valid:
  86. return {"total": 0, "percentage": 0, "rating": "Failed", "format": 0,
  87. "diversity": 0, "hyde": 0, "quality": 0, "is_continuation": True,
  88. "parsed": parsed}
  89. # Format score (0-30)
  90. format_score = 0
  91. if parsed["lex"]:
  92. format_score += 10
  93. if parsed["vec"]:
  94. format_score += 10
  95. if not parsed["invalid"]:
  96. format_score += 10
  97. else:
  98. format_score += max(0, 10 - len(parsed["invalid"]) * 5)
  99. # Diversity score (0-30)
  100. diversity_score = 0
  101. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  102. if types_present >= 2:
  103. diversity_score += 10
  104. total_exp = len(parsed["lex"]) + len(parsed["vec"])
  105. if total_exp >= 2:
  106. diversity_score += 5
  107. diversity_score += 10 # Base diversity points
  108. diversity_score += 5 # Non-echo points
  109. # Hyde score (0-20)
  110. hyde_score = 0
  111. if parsed["hyde"]:
  112. hyde_text = parsed["hyde"][0]
  113. hyde_score += 5 # Present
  114. hyde_len = len(hyde_text)
  115. if 50 <= hyde_len <= 200:
  116. hyde_score += 5
  117. elif hyde_len < 50:
  118. hyde_score += 2
  119. if "\n" not in hyde_text:
  120. hyde_score += 5
  121. hyde_score += 5 # No repetition (simplified)
  122. # Quality score (0-20)
  123. quality_score = 5 # Base relevance (reduced to make room for key term check)
  124. # Lex must preserve key terms from query
  125. stopwords = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of', 'and', 'or', 'with', 'my'}
  126. key_terms = set(query.lower().split()) - stopwords
  127. if parsed["lex"] and key_terms:
  128. lex_with_terms = sum(1 for l in parsed["lex"] if key_terms & set(l.lower().split()))
  129. if lex_with_terms == len(parsed["lex"]):
  130. quality_score += 5
  131. elif lex_with_terms > 0:
  132. quality_score += 2
  133. if parsed["lex"] and parsed["vec"]:
  134. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  135. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  136. if avg_lex <= avg_vec:
  137. quality_score += 5
  138. if parsed["vec"] and all(" " in v and len(v) > 15 for v in parsed["vec"]):
  139. quality_score += 5
  140. total = format_score + diversity_score + hyde_score + min(20, quality_score)
  141. max_possible = 100 if parsed["hyde"] else 80
  142. percentage = total / max_possible * 100
  143. if percentage >= 80:
  144. rating = "Excellent"
  145. elif percentage >= 60:
  146. rating = "Good"
  147. elif percentage >= 40:
  148. rating = "Acceptable"
  149. elif percentage >= 20:
  150. rating = "Poor"
  151. else:
  152. rating = "Failed"
  153. return {
  154. "format": format_score,
  155. "diversity": diversity_score,
  156. "hyde": hyde_score,
  157. "quality": quality_score,
  158. "total": total,
  159. "max_possible": max_possible,
  160. "percentage": percentage,
  161. "rating": rating,
  162. "parsed": parsed,
  163. }
  164. # ═══════════════════════════════════════════════════════════════════════════════
  165. # UI COMPONENTS
  166. # ═══════════════════════════════════════════════════════════════════════════════
# Banner in rich console markup; main() prints it after each screen clear.
BANNER = """
[bold cyan]╔═══════════════════════════════════════════════════════════════════════════╗[/]
[bold cyan]║[/] [bold magenta]░▀▄░░░░░░░░░░░▄▀░[/] [bold white]Q M D Q U E R Y E X P A N D E R[/] [bold magenta]░▀▄░░░░░░░░░░░▄▀░[/] [bold cyan]║[/]
[bold cyan]╚═══════════════════════════════════════════════════════════════════════════╝[/]
"""


def show_banner():
    """Print BANNER to the shared console."""
    console.print(BANNER)
  174. def show_model_menu(current: Optional[str] = None) -> str:
  175. """Display model selection menu."""
  176. console.print()
  177. console.print(f"[bold {CYAN}]◆ SELECT MODEL[/]")
  178. console.print(f"[{DIM}]{'─' * 50}[/]")
  179. for key, (name, path, version) in MODELS.items():
  180. marker = "[bold green]●[/]" if path == current else f"[{DIM}]○[/]"
  181. console.print(f" {marker} [{CYAN}]{key}[/] │ {name} [{DIM}]({version})[/]")
  182. console.print(f" [{DIM}]{path}[/]")
  183. console.print(f"[{DIM}]{'─' * 50}[/]")
  184. return prompt(" Enter choice (1-3): ", style=Style.from_dict({'': CYAN})).strip()
  185. def render_expansion(expansion: str, scores: dict) -> Panel:
  186. """Render the expansion output with syntax highlighting."""
  187. parsed = scores.get("parsed", parse_expansion(expansion))
  188. content = Text()
  189. # Lex lines
  190. for lex in parsed["lex"]:
  191. content.append("lex: ", style=f"bold {CYAN}")
  192. content.append(f"{lex}\n", style="white")
  193. # Vec lines
  194. for vec in parsed["vec"]:
  195. content.append("vec: ", style=f"bold {MAGENTA}")
  196. content.append(f"{vec}\n", style="white")
  197. # Hyde lines
  198. for hyde in parsed["hyde"]:
  199. content.append("hyde: ", style=f"bold {PURPLE}")
  200. content.append(f"{hyde}\n", style=f"italic {DIM}")
  201. # Invalid lines
  202. for inv in parsed["invalid"]:
  203. content.append(f"[invalid] {inv}\n", style=f"dim {RED}")
  204. return Panel(
  205. content,
  206. title=f"[bold {CYAN}]◈ EXPANSION[/]",
  207. border_style=CYAN,
  208. padding=(0, 1),
  209. )
  210. def render_scores(scores: dict) -> Panel:
  211. """Render score breakdown as a compact table."""
  212. rating = scores["rating"]
  213. rating_color = {
  214. "Excellent": GREEN,
  215. "Good": CYAN,
  216. "Acceptable": YELLOW,
  217. "Poor": RED,
  218. "Failed": RED,
  219. }.get(rating, DIM)
  220. table = Table(show_header=False, box=None, padding=(0, 1))
  221. table.add_column("metric", style=DIM)
  222. table.add_column("score", justify="right")
  223. table.add_column("bar", width=12)
  224. def score_bar(val: int, max_val: int) -> str:
  225. filled = int(val / max_val * 10)
  226. return f"[{CYAN}]{'█' * filled}[/][{DIM}]{'░' * (10 - filled)}[/]"
  227. table.add_row("Format", f"[white]{scores['format']}[/]/30", score_bar(scores['format'], 30))
  228. table.add_row("Diversity", f"[white]{scores['diversity']}[/]/30", score_bar(scores['diversity'], 30))
  229. table.add_row("Hyde", f"[white]{scores['hyde']}[/]/20", score_bar(scores['hyde'], 20))
  230. table.add_row("Quality", f"[white]{scores['quality']}[/]/20", score_bar(scores['quality'], 20))
  231. table.add_row("", "", "")
  232. table.add_row(
  233. f"[bold]TOTAL[/]",
  234. f"[bold white]{scores['total']}[/]/{scores.get('max_possible', 80)}",
  235. f"[bold {rating_color}]{rating}[/]"
  236. )
  237. return Panel(
  238. table,
  239. title=f"[bold {MAGENTA}]◈ SCORES[/]",
  240. border_style=MAGENTA,
  241. padding=(0, 1),
  242. )
  243. def render_history(history: deque) -> Panel:
  244. """Render recent query history."""
  245. content = Text()
  246. for i, (query, rating) in enumerate(history):
  247. rating_color = {
  248. "Excellent": GREEN, "Good": CYAN, "Acceptable": YELLOW,
  249. "Poor": RED, "Failed": RED,
  250. }.get(rating, DIM)
  251. content.append(f" [{DIM}]{i+1}.[/] {query[:40]}")
  252. if len(query) > 40:
  253. content.append(f"[{DIM}]...[/]")
  254. content.append(f" [{rating_color}]●[/]\n")
  255. if not history:
  256. content.append(f" [{DIM}]No queries yet[/]")
  257. return Panel(
  258. content,
  259. title=f"[bold {PURPLE}]◈ HISTORY[/]",
  260. border_style=PURPLE,
  261. padding=(0, 1),
  262. )
  263. # ═══════════════════════════════════════════════════════════════════════════════
  264. # MODEL LOADING
  265. # ═══════════════════════════════════════════════════════════════════════════════
  266. @dataclass
  267. class LoadedModel:
  268. model: any
  269. tokenizer: any
  270. name: str
  271. path: str
  272. version: str # "v1" or "v3" - determines prompt template
  273. def load_model(model_path: str, model_name: str, version: str) -> LoadedModel:
  274. """Load model with progress indicator."""
  275. with Progress(
  276. SpinnerColumn(spinner_name="dots", style=CYAN),
  277. TextColumn(f"[{CYAN}]Loading {{task.description}}...[/]"),
  278. console=console,
  279. transient=True,
  280. ) as progress:
  281. task = progress.add_task("tokenizer", total=None)
  282. tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  283. if tokenizer.pad_token is None:
  284. tokenizer.pad_token = tokenizer.eos_token
  285. progress.update(task, description="base model")
  286. base = AutoModelForCausalLM.from_pretrained(
  287. BASE_MODEL,
  288. torch_dtype=torch.bfloat16,
  289. device_map="auto",
  290. )
  291. progress.update(task, description="adapter")
  292. model = PeftModel.from_pretrained(base, model_path)
  293. model.eval()
  294. return LoadedModel(model=model, tokenizer=tokenizer, name=model_name, path=model_path, version=version)
  295. DEBUG = False # Set to True for debug output
  296. def generate_expansion(loaded: LoadedModel, query: str) -> str:
  297. """Generate expansion using proper Qwen3 chat template."""
  298. if loaded.version == "v3":
  299. # Use tokenizer's chat template with /no_think to disable thinking mode
  300. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  301. prompt_text = loaded.tokenizer.apply_chat_template(
  302. messages,
  303. tokenize=False,
  304. add_generation_prompt=True
  305. )
  306. else:
  307. # v1 used simple format
  308. prompt_text = PROMPT_TEMPLATE_V1.format(query=query)
  309. if DEBUG:
  310. console.print(f"[{DIM}]─── DEBUG: Prompt ───[/]")
  311. console.print(f"[{DIM}]{repr(prompt_text)}[/]")
  312. inputs = loaded.tokenizer(prompt_text, return_tensors="pt").to(loaded.model.device)
  313. with torch.no_grad():
  314. outputs = loaded.model.generate(
  315. **inputs,
  316. max_new_tokens=200,
  317. temperature=0.7,
  318. do_sample=True,
  319. pad_token_id=loaded.tokenizer.pad_token_id,
  320. eos_token_id=loaded.tokenizer.eos_token_id,
  321. )
  322. full_output = loaded.tokenizer.decode(outputs[0], skip_special_tokens=True)
  323. if DEBUG:
  324. console.print(f"[{DIM}]─── DEBUG: Full output ───[/]")
  325. console.print(f"[{DIM}]{repr(full_output[:500])}[/]")
  326. # Extract assistant response (skip_special_tokens leaves "user\n...\nassistant\n...")
  327. if "\nassistant\n" in full_output:
  328. expansion = full_output.split("\nassistant\n")[-1].strip()
  329. elif "assistant\n" in full_output:
  330. expansion = full_output.split("assistant\n")[-1].strip()
  331. else:
  332. expansion = full_output[len(prompt_text):].strip()
  333. # Remove any <think> tags that might remain
  334. if expansion.startswith("<think>"):
  335. think_end = expansion.find("</think>")
  336. if think_end != -1:
  337. expansion = expansion[think_end + 8:].strip()
  338. if DEBUG:
  339. console.print(f"[{DIM}]─── DEBUG: Expansion ───[/]")
  340. console.print(f"[{DIM}]{repr(expansion[:300])}[/]")
  341. return expansion
  342. # ═══════════════════════════════════════════════════════════════════════════════
  343. # MAIN LOOP
  344. # ═══════════════════════════════════════════════════════════════════════════════
  345. def main():
  346. console.clear()
  347. show_banner()
  348. # Model selection
  349. choice = show_model_menu()
  350. if choice not in MODELS:
  351. console.print(f"[{RED}]Invalid choice. Exiting.[/]")
  352. return
  353. model_name, model_path, model_version = MODELS[choice]
  354. console.print()
  355. try:
  356. loaded = load_model(model_path, model_name, model_version)
  357. except Exception as e:
  358. console.print(f"[{RED}]Failed to load model: {e}[/]")
  359. return
  360. console.print(f"[{GREEN}]✓ Model loaded: {model_name}[/]")
  361. console.print()
  362. # Query history
  363. history: deque = deque(maxlen=5)
  364. input_history = InMemoryHistory()
  365. # Main loop
  366. console.print(f"[{DIM}]Enter queries to expand. Type 'quit' to exit, 'model' to switch models.[/]")
  367. console.print()
  368. while True:
  369. try:
  370. query = prompt(
  371. f"[{CYAN}]❯[/] ",
  372. history=input_history,
  373. style=Style.from_dict({'': 'ansicyan'}),
  374. ).strip()
  375. except (KeyboardInterrupt, EOFError):
  376. break
  377. if not query:
  378. continue
  379. if query.lower() == 'quit':
  380. break
  381. if query.lower() == 'model':
  382. console.clear()
  383. show_banner()
  384. choice = show_model_menu(loaded.path)
  385. if choice in MODELS:
  386. new_name, new_path, new_version = MODELS[choice]
  387. if new_path != loaded.path:
  388. console.print()
  389. loaded = load_model(new_path, new_name, new_version)
  390. console.print(f"[{GREEN}]✓ Switched to: {new_name}[/]")
  391. console.print()
  392. continue
  393. if query.lower() == 'history':
  394. console.print(render_history(history))
  395. continue
  396. # Generate expansion
  397. with Progress(
  398. SpinnerColumn(spinner_name="dots", style=MAGENTA),
  399. TextColumn(f"[{MAGENTA}]Expanding...[/]"),
  400. console=console,
  401. transient=True,
  402. ) as progress:
  403. progress.add_task("gen", total=None)
  404. expansion = generate_expansion(loaded, query)
  405. # Score it
  406. scores = score_expansion(query, expansion)
  407. # Add to history
  408. history.appendleft((query, scores["rating"]))
  409. # Display results
  410. console.print()
  411. console.print(f"[{DIM}]Query: [/][bold white]{query}[/]")
  412. console.print()
  413. # Side-by-side layout
  414. console.print(Columns([
  415. render_expansion(expansion, scores),
  416. render_scores(scores),
  417. ], equal=True, expand=True))
  418. console.print()
  419. console.print(f"\n[{CYAN}]◆ Goodbye![/]\n")
  420. if __name__ == "__main__":
  421. main()