evaluate_model.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "huggingface_hub",
  8. # "accelerate",
  9. # ]
  10. # ///
  11. """
  12. Evaluate QMD query expansion model quality.
  13. See SCORING.md for detailed scoring criteria.
  14. """
  15. import json
  16. import re
  17. import torch
  18. from collections import Counter
  19. from transformers import AutoModelForCausalLM, AutoTokenizer
  20. from peft import PeftModel
# Test queries covering different QMD use cases.
# main() evaluates these when no --queries file is supplied.
TEST_QUERIES = [
    # Technical documentation
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",
    # Short/ambiguous queries
    "auth",
    "config",
    "setup",
    "api",
    # Personal notes / journals style
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",
    # Research / learning
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",
    # Error/debugging
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",
    # Complex queries
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
# Prompt is generated via tokenizer.apply_chat_template() - see generate_expansion()
# Don't manually construct <|im_start|> tags

# Common function words that word_repetition_penalty() never penalises,
# however often they repeat.
STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  54. def parse_expansion(text: str) -> dict:
  55. """Parse expansion into structured format."""
  56. lines = text.strip().split("\n")
  57. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  58. for line in lines:
  59. line = line.strip()
  60. if not line:
  61. continue
  62. if line.startswith("lex:"):
  63. result["lex"].append(line[4:].strip())
  64. elif line.startswith("vec:"):
  65. result["vec"].append(line[4:].strip())
  66. elif line.startswith("hyde:"):
  67. result["hyde"].append(line[5:].strip())
  68. else:
  69. result["invalid"].append(line)
  70. return result
  71. def edit_distance_simple(a: str, b: str) -> int:
  72. """Simple word-level edit distance."""
  73. words_a = set(a.lower().split())
  74. words_b = set(b.lower().split())
  75. return len(words_a ^ words_b) # Symmetric difference
  76. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  77. """Check if two strings are sufficiently different."""
  78. a, b = a.lower().strip(), b.lower().strip()
  79. if a == b:
  80. return False
  81. if a in b or b in a:
  82. return False
  83. return edit_distance_simple(a, b) >= min_distance
  84. def echoes_query(expansion: str, query: str) -> bool:
  85. """Check if expansion is just echoing the query."""
  86. exp = expansion.lower().strip()
  87. q = query.lower().strip()
  88. # Exact match or very close
  89. if exp == q:
  90. return True
  91. # Query is contained in expansion with little else
  92. if q in exp and len(exp) < len(q) + 10:
  93. return True
  94. return False
  95. def get_key_terms(query: str) -> set:
  96. """Extract key terms from query (excluding stopwords)."""
  97. stopwords = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  98. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we'}
  99. words = set(query.lower().split())
  100. return words - stopwords
  101. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  102. """Check if lex line contains at least one key term from query."""
  103. key_terms = get_key_terms(query)
  104. if not key_terms: # Very short query
  105. return True
  106. lex_words = set(lex_line.lower().split())
  107. return bool(key_terms & lex_words)
  108. def word_repetition_penalty(text: str) -> int:
  109. """Count penalty for repeated words (excluding stopwords)."""
  110. words = re.findall(r'\b\w+\b', text.lower())
  111. counts = Counter(words)
  112. penalty = 0
  113. for word, count in counts.items():
  114. if count >= 3 and word not in STOPWORDS and len(word) > 2:
  115. penalty += (count - 2) * 2
  116. return penalty
  117. def is_continuation(expansion: str) -> bool:
  118. """
  119. Detect if output is a continuation rather than proper expansion.
  120. A continuation is when the model continues the query as prose
  121. instead of outputting lex:/vec:/hyde: lines.
  122. """
  123. text = expansion.strip()
  124. if not text:
  125. return True
  126. # Check first non-empty line
  127. first_line = text.split("\n")[0].strip()
  128. # Valid outputs must start with a prefix
  129. valid_prefixes = ("lex:", "vec:", "hyde:")
  130. if first_line.startswith(valid_prefixes):
  131. return False
  132. # If first line doesn't have a valid prefix, it's a continuation
  133. # Exception: empty first line (check second)
  134. lines = [l.strip() for l in text.split("\n") if l.strip()]
  135. if lines and not lines[0].startswith(valid_prefixes):
  136. return True
  137. return False
  138. def score_expansion(query: str, expansion: str) -> dict:
  139. """
  140. Score an expansion based on SCORING.md criteria.
  141. Returns dict with score breakdown and total (0-100).
  142. """
  143. # HARD FAIL: Continuation detection
  144. if is_continuation(expansion):
  145. return {
  146. "format": 0,
  147. "diversity": 0,
  148. "hyde": 0,
  149. "quality": 0,
  150. "total": 0,
  151. "max_possible": 80,
  152. "percentage": 0,
  153. "rating": "Failed",
  154. "deductions": ["CONTINUATION DETECTED - output does not start with lex:/vec:/hyde:"],
  155. "parsed": {"lex": [], "vec": [], "hyde": [], "invalid": [expansion[:100]]},
  156. "is_continuation": True,
  157. }
  158. parsed = parse_expansion(expansion)
  159. scores = {
  160. "format": 0,
  161. "diversity": 0,
  162. "hyde": 0,
  163. "quality": 0,
  164. "deductions": [],
  165. }
  166. # === FORMAT (0-30) ===
  167. format_score = 0
  168. # Has at least one lex: line (+10)
  169. if parsed["lex"]:
  170. format_score += 10
  171. else:
  172. scores["deductions"].append("missing lex: (-10)")
  173. # Has at least one vec: line (+10)
  174. if parsed["vec"]:
  175. format_score += 10
  176. else:
  177. scores["deductions"].append("missing vec: (-10)")
  178. # All lines have valid prefix (+10, -5 per invalid)
  179. if not parsed["invalid"]:
  180. format_score += 10
  181. else:
  182. invalid_penalty = min(10, len(parsed["invalid"]) * 5)
  183. format_score += (10 - invalid_penalty)
  184. scores["deductions"].append(f"{len(parsed['invalid'])} invalid lines (-{invalid_penalty})")
  185. scores["format"] = max(0, format_score)
  186. # === DIVERSITY (0-30) ===
  187. diversity_score = 0
  188. # 2+ different types present (+10)
  189. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  190. if types_present >= 2:
  191. diversity_score += 10
  192. else:
  193. scores["deductions"].append("only one type present (-10)")
  194. # 2+ total expansions (+5)
  195. total_expansions = len(parsed["lex"]) + len(parsed["vec"])
  196. if total_expansions >= 2:
  197. diversity_score += 5
  198. else:
  199. scores["deductions"].append("fewer than 2 expansions (-5)")
  200. # Multiple lex: lines are diverse (+5, -2 per duplicate pair)
  201. lex_diverse_score = 5
  202. for i, a in enumerate(parsed["lex"]):
  203. for b in parsed["lex"][i+1:]:
  204. if not is_diverse(a, b, min_distance=2):
  205. lex_diverse_score -= 2
  206. scores["deductions"].append(f"lex duplicates: '{a[:20]}...' ~ '{b[:20]}...'")
  207. diversity_score += max(0, lex_diverse_score)
  208. # Multiple vec: lines are diverse (+5, -2 per duplicate pair)
  209. vec_diverse_score = 5
  210. for i, a in enumerate(parsed["vec"]):
  211. for b in parsed["vec"][i+1:]:
  212. if not is_diverse(a, b, min_distance=3):
  213. vec_diverse_score -= 2
  214. scores["deductions"].append(f"vec duplicates: '{a[:20]}...' ~ '{b[:20]}...'")
  215. diversity_score += max(0, vec_diverse_score)
  216. # lex/vec not identical to original query (+5, -5 per echo)
  217. echo_score = 5
  218. for exp in parsed["lex"] + parsed["vec"]:
  219. if echoes_query(exp, query):
  220. echo_score -= 5
  221. scores["deductions"].append(f"echoes query: '{exp[:30]}...'")
  222. diversity_score += max(0, echo_score)
  223. scores["diversity"] = max(0, diversity_score)
  224. # === HYDE QUALITY (0-20, optional bonus) ===
  225. hyde_score = 0
  226. if parsed["hyde"]:
  227. hyde_text = parsed["hyde"][0] # Only first hyde counts
  228. # Hyde present and well-formed (+5)
  229. hyde_score += 5
  230. # Hyde is concise: 50-200 chars (+5)
  231. hyde_len = len(hyde_text)
  232. if 50 <= hyde_len <= 200:
  233. hyde_score += 5
  234. elif hyde_len < 50:
  235. hyde_score += 2
  236. scores["deductions"].append(f"hyde too short ({hyde_len} chars)")
  237. else:
  238. scores["deductions"].append(f"hyde too long ({hyde_len} chars)")
  239. # Hyde has no newlines (+5)
  240. if "\n" not in hyde_text:
  241. hyde_score += 5
  242. else:
  243. scores["deductions"].append("hyde contains newlines")
  244. # Hyde has no excessive repetition (+5)
  245. rep_penalty = word_repetition_penalty(hyde_text)
  246. if rep_penalty == 0:
  247. hyde_score += 5
  248. else:
  249. hyde_score += max(0, 5 - rep_penalty)
  250. scores["deductions"].append(f"hyde repetition penalty (-{min(5, rep_penalty)})")
  251. scores["hyde"] = hyde_score
  252. # === QUALITY (0-20) ===
  253. quality_score = 10 # Base relevance (assume relevant unless obvious garbage)
  254. # Lex lines should be keyword-focused (shorter than vec on average)
  255. if parsed["lex"] and parsed["vec"]:
  256. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  257. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  258. if avg_lex <= avg_vec:
  259. quality_score += 5
  260. else:
  261. scores["deductions"].append("lex longer than vec (should be keywords)")
  262. else:
  263. quality_score += 2 # Partial credit
  264. # Vec lines should be natural language (contain spaces, longer)
  265. if parsed["vec"]:
  266. vec_natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  267. if vec_natural == len(parsed["vec"]):
  268. quality_score += 5
  269. else:
  270. quality_score += 2
  271. scores["deductions"].append("some vec lines too short/keyword-like")
  272. # Lex lines must preserve key terms from query (not be generic)
  273. if parsed["lex"]:
  274. lex_with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  275. if lex_with_terms == len(parsed["lex"]):
  276. quality_score += 5
  277. elif lex_with_terms > 0:
  278. quality_score += 2
  279. else:
  280. scores["deductions"].append("lex lines too generic - missing key terms from query")
  281. scores["quality"] = min(20, quality_score) # Cap at 20
  282. # === TOTAL ===
  283. scores["total"] = scores["format"] + scores["diversity"] + scores["hyde"] + scores["quality"]
  284. scores["max_possible"] = 100 if parsed["hyde"] else 80
  285. scores["percentage"] = scores["total"] / scores["max_possible"] * 100
  286. # Rating
  287. pct = scores["percentage"]
  288. if pct >= 80:
  289. scores["rating"] = "Excellent"
  290. elif pct >= 60:
  291. scores["rating"] = "Good"
  292. elif pct >= 40:
  293. scores["rating"] = "Acceptable"
  294. elif pct >= 20:
  295. scores["rating"] = "Poor"
  296. else:
  297. scores["rating"] = "Failed"
  298. scores["parsed"] = parsed
  299. scores["is_continuation"] = False
  300. return scores
  301. def load_model(model_name: str, base_model: str = "Qwen/Qwen3-0.6B"):
  302. """Load the finetuned model."""
  303. print(f"Loading tokenizer from {base_model}...")
  304. tokenizer = AutoTokenizer.from_pretrained(base_model)
  305. if tokenizer.pad_token is None:
  306. tokenizer.pad_token = tokenizer.eos_token
  307. print(f"Loading base model...")
  308. base = AutoModelForCausalLM.from_pretrained(
  309. base_model,
  310. torch_dtype=torch.bfloat16,
  311. device_map="auto",
  312. )
  313. print(f"Loading adapter from {model_name}...")
  314. model = PeftModel.from_pretrained(base, model_name)
  315. model.eval()
  316. return model, tokenizer
  317. def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200) -> str:
  318. """Generate query expansion using proper Qwen3 chat template."""
  319. # Use tokenizer's chat template with /no_think to disable thinking mode
  320. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  321. prompt = tokenizer.apply_chat_template(
  322. messages,
  323. tokenize=False,
  324. add_generation_prompt=True
  325. )
  326. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  327. with torch.no_grad():
  328. outputs = model.generate(
  329. **inputs,
  330. max_new_tokens=max_new_tokens,
  331. temperature=0.7,
  332. do_sample=True,
  333. pad_token_id=tokenizer.pad_token_id,
  334. eos_token_id=tokenizer.eos_token_id,
  335. )
  336. # Decode and extract expansion
  337. # skip_special_tokens=True strips <|im_start|> etc, leaving "user\n...\nassistant\n..."
  338. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  339. # Extract assistant response
  340. if "\nassistant\n" in full_output:
  341. expansion = full_output.split("\nassistant\n")[-1].strip()
  342. elif "assistant\n" in full_output:
  343. expansion = full_output.split("assistant\n")[-1].strip()
  344. else:
  345. # Fallback: strip prompt length
  346. expansion = full_output[len(prompt):].strip()
  347. # Remove any <think> tags that might remain
  348. if expansion.startswith("<think>"):
  349. # Find end of thinking block
  350. think_end = expansion.find("</think>")
  351. if think_end != -1:
  352. expansion = expansion[think_end + 8:].strip()
  353. return expansion
  354. def print_score_breakdown(scores: dict):
  355. """Pretty print score breakdown."""
  356. print(f" Score: {scores['total']}/{scores['max_possible']} ({scores['percentage']:.0f}%) - {scores['rating']}")
  357. print(f" Format: {scores['format']}/30")
  358. print(f" Diversity: {scores['diversity']}/30")
  359. print(f" Hyde: {scores['hyde']}/20")
  360. print(f" Quality: {scores['quality']}/20")
  361. if scores["deductions"]:
  362. print(f" Deductions:")
  363. for d in scores["deductions"][:5]: # Show top 5
  364. print(f" - {d}")
  365. if len(scores["deductions"]) > 5:
  366. print(f" ... and {len(scores['deductions']) - 5} more")
  367. def run_examples():
  368. """Run good and bad examples to demonstrate scoring."""
  369. print("=" * 70)
  370. print("SCORING EXAMPLES")
  371. print("=" * 70)
  372. # Good example
  373. good_expansion = """lex: react hooks tutorial
  374. lex: usestate useeffect
  375. vec: how to use react hooks in functional components
  376. vec: react hooks best practices guide
  377. hyde: React Hooks allow you to use state and lifecycle features in functional components without writing a class."""
  378. print("\n[GOOD EXAMPLE]")
  379. print(f"Query: react hooks")
  380. print(f"Output:\n{good_expansion}")
  381. scores = score_expansion("react hooks", good_expansion)
  382. print_score_breakdown(scores)
  383. # Bad example
  384. bad_expansion = """auth is an important concept that relates to authentication.
  385. The answer should be in Chinese.
  386. The answer should be in Chinese."""
  387. print("\n[BAD EXAMPLE]")
  388. print(f"Query: auth")
  389. print(f"Output:\n{bad_expansion}")
  390. scores = score_expansion("auth", bad_expansion)
  391. print_score_breakdown(scores)
  392. # Medium example - repetitive hyde
  393. medium_expansion = """lex: docker networking
  394. vec: docker networking
  395. hyde: Docker networking is an important concept. Docker networking is used for container communication. Docker networking configuration is essential."""
  396. print("\n[MEDIUM EXAMPLE - Repetitive]")
  397. print(f"Query: docker networking")
  398. print(f"Output:\n{medium_expansion}")
  399. scores = score_expansion("docker networking", medium_expansion)
  400. print_score_breakdown(scores)
  401. # Medium example - echoes query
  402. echo_expansion = """lex: auth
  403. lex: authentication
  404. vec: auth
  405. vec: authentication configuration
  406. hyde: Authentication is the process of verifying identity."""
  407. print("\n[MEDIUM EXAMPLE - Echoes Query]")
  408. print(f"Query: auth")
  409. print(f"Output:\n{echo_expansion}")
  410. scores = score_expansion("auth", echo_expansion)
  411. print_score_breakdown(scores)
  412. print("\n" + "=" * 70)
  413. def main():
  414. import argparse
  415. parser = argparse.ArgumentParser()
  416. parser.add_argument("--model", default="tobil/qmd-query-expansion-0.6B",
  417. help="Model to evaluate")
  418. parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
  419. help="Base model")
  420. parser.add_argument("--output", default="evaluation_results.json",
  421. help="Output file for results")
  422. parser.add_argument("--queries", type=str, help="Custom queries file (one per line)")
  423. parser.add_argument("--examples", action="store_true", help="Run scoring examples only")
  424. args = parser.parse_args()
  425. # Run examples if requested
  426. if args.examples:
  427. run_examples()
  428. return
  429. # Load custom queries if provided
  430. queries = TEST_QUERIES
  431. if args.queries:
  432. with open(args.queries) as f:
  433. queries = [l.strip() for l in f if l.strip()]
  434. # Load model
  435. model, tokenizer = load_model(args.model, args.base_model)
  436. # Run evaluation
  437. results = []
  438. print(f"\n{'='*70}")
  439. print("EVALUATION RESULTS")
  440. print(f"{'='*70}\n")
  441. for i, query in enumerate(queries, 1):
  442. print(f"[{i}/{len(queries)}] Query: {query}")
  443. print("-" * 50)
  444. expansion = generate_expansion(model, tokenizer, query)
  445. scores = score_expansion(query, expansion)
  446. print(expansion)
  447. print()
  448. print_score_breakdown(scores)
  449. print()
  450. results.append({
  451. "query": query,
  452. "expansion": expansion,
  453. "scores": {k: v for k, v in scores.items() if k != "parsed"},
  454. "parsed": scores["parsed"],
  455. })
  456. # Summary
  457. print(f"\n{'='*70}")
  458. print("SUMMARY")
  459. print(f"{'='*70}")
  460. avg_score = sum(r["scores"]["percentage"] for r in results) / len(results)
  461. excellent = sum(1 for r in results if r["scores"]["rating"] == "Excellent")
  462. good = sum(1 for r in results if r["scores"]["rating"] == "Good")
  463. acceptable = sum(1 for r in results if r["scores"]["rating"] == "Acceptable")
  464. poor = sum(1 for r in results if r["scores"]["rating"] == "Poor")
  465. failed = sum(1 for r in results if r["scores"]["rating"] == "Failed")
  466. print(f" Total queries: {len(results)}")
  467. print(f" Average score: {avg_score:.1f}%")
  468. print(f" Ratings:")
  469. print(f" Excellent: {excellent}")
  470. print(f" Good: {good}")
  471. print(f" Acceptable: {acceptable}")
  472. print(f" Poor: {poor}")
  473. print(f" Failed: {failed}")
  474. # Save results
  475. with open(args.output, "w") as f:
  476. json.dump(results, f, indent=2)
  477. print(f"\n Results saved to: {args.output}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()