# evaluate_model.py
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "huggingface_hub",
  8. # "accelerate",
  9. # ]
  10. # ///
  11. """
  12. Evaluate QMD query expansion model quality.
  13. See SCORING.md for detailed scoring criteria.
  14. """
  15. import json
  16. import re
  17. import torch
  18. from collections import Counter
  19. from transformers import AutoModelForCausalLM, AutoTokenizer
  20. from peft import PeftModel
  21. # Test queries covering different QMD use cases
  22. TEST_QUERIES = [
  23. # Technical documentation
  24. "how to configure authentication",
  25. "typescript async await",
  26. "docker compose networking",
  27. "git rebase vs merge",
  28. "react useEffect cleanup",
  29. # Short/ambiguous queries
  30. "auth",
  31. "config",
  32. "setup",
  33. "api",
  34. # Personal notes / journals style
  35. "meeting notes project kickoff",
  36. "ideas for new feature",
  37. "todo list app architecture",
  38. # Research / learning
  39. "what is dependency injection",
  40. "difference between sql and nosql",
  41. "kubernetes vs docker swarm",
  42. # Error/debugging
  43. "connection timeout error",
  44. "memory leak debugging",
  45. "cors error fix",
  46. # Complex queries
  47. "how to implement caching with redis in nodejs",
  48. "best practices for api rate limiting",
  49. "setting up ci cd pipeline with github actions",
  50. ]
  51. PROMPT_TEMPLATE = """Expand this search query:
  52. {query}"""
  53. STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  54. def parse_expansion(text: str) -> dict:
  55. """Parse expansion into structured format."""
  56. lines = text.strip().split("\n")
  57. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  58. for line in lines:
  59. line = line.strip()
  60. if not line:
  61. continue
  62. if line.startswith("lex:"):
  63. result["lex"].append(line[4:].strip())
  64. elif line.startswith("vec:"):
  65. result["vec"].append(line[4:].strip())
  66. elif line.startswith("hyde:"):
  67. result["hyde"].append(line[5:].strip())
  68. else:
  69. result["invalid"].append(line)
  70. return result
  71. def edit_distance_simple(a: str, b: str) -> int:
  72. """Simple word-level edit distance."""
  73. words_a = set(a.lower().split())
  74. words_b = set(b.lower().split())
  75. return len(words_a ^ words_b) # Symmetric difference
  76. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  77. """Check if two strings are sufficiently different."""
  78. a, b = a.lower().strip(), b.lower().strip()
  79. if a == b:
  80. return False
  81. if a in b or b in a:
  82. return False
  83. return edit_distance_simple(a, b) >= min_distance
  84. def echoes_query(expansion: str, query: str) -> bool:
  85. """Check if expansion is just echoing the query."""
  86. exp = expansion.lower().strip()
  87. q = query.lower().strip()
  88. # Exact match or very close
  89. if exp == q:
  90. return True
  91. # Query is contained in expansion with little else
  92. if q in exp and len(exp) < len(q) + 10:
  93. return True
  94. return False
  95. def word_repetition_penalty(text: str) -> int:
  96. """Count penalty for repeated words (excluding stopwords)."""
  97. words = re.findall(r'\b\w+\b', text.lower())
  98. counts = Counter(words)
  99. penalty = 0
  100. for word, count in counts.items():
  101. if count >= 3 and word not in STOPWORDS and len(word) > 2:
  102. penalty += (count - 2) * 2
  103. return penalty
  104. def score_expansion(query: str, expansion: str) -> dict:
  105. """
  106. Score an expansion based on SCORING.md criteria.
  107. Returns dict with score breakdown and total (0-100).
  108. """
  109. parsed = parse_expansion(expansion)
  110. scores = {
  111. "format": 0,
  112. "diversity": 0,
  113. "hyde": 0,
  114. "quality": 0,
  115. "deductions": [],
  116. }
  117. # === FORMAT (0-30) ===
  118. format_score = 0
  119. # Has at least one lex: line (+10)
  120. if parsed["lex"]:
  121. format_score += 10
  122. else:
  123. scores["deductions"].append("missing lex: (-10)")
  124. # Has at least one vec: line (+10)
  125. if parsed["vec"]:
  126. format_score += 10
  127. else:
  128. scores["deductions"].append("missing vec: (-10)")
  129. # All lines have valid prefix (+10, -5 per invalid)
  130. if not parsed["invalid"]:
  131. format_score += 10
  132. else:
  133. invalid_penalty = min(10, len(parsed["invalid"]) * 5)
  134. format_score += (10 - invalid_penalty)
  135. scores["deductions"].append(f"{len(parsed['invalid'])} invalid lines (-{invalid_penalty})")
  136. scores["format"] = max(0, format_score)
  137. # === DIVERSITY (0-30) ===
  138. diversity_score = 0
  139. # 2+ different types present (+10)
  140. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  141. if types_present >= 2:
  142. diversity_score += 10
  143. else:
  144. scores["deductions"].append("only one type present (-10)")
  145. # 2+ total expansions (+5)
  146. total_expansions = len(parsed["lex"]) + len(parsed["vec"])
  147. if total_expansions >= 2:
  148. diversity_score += 5
  149. else:
  150. scores["deductions"].append("fewer than 2 expansions (-5)")
  151. # Multiple lex: lines are diverse (+5, -2 per duplicate pair)
  152. lex_diverse_score = 5
  153. for i, a in enumerate(parsed["lex"]):
  154. for b in parsed["lex"][i+1:]:
  155. if not is_diverse(a, b, min_distance=2):
  156. lex_diverse_score -= 2
  157. scores["deductions"].append(f"lex duplicates: '{a[:20]}...' ~ '{b[:20]}...'")
  158. diversity_score += max(0, lex_diverse_score)
  159. # Multiple vec: lines are diverse (+5, -2 per duplicate pair)
  160. vec_diverse_score = 5
  161. for i, a in enumerate(parsed["vec"]):
  162. for b in parsed["vec"][i+1:]:
  163. if not is_diverse(a, b, min_distance=3):
  164. vec_diverse_score -= 2
  165. scores["deductions"].append(f"vec duplicates: '{a[:20]}...' ~ '{b[:20]}...'")
  166. diversity_score += max(0, vec_diverse_score)
  167. # lex/vec not identical to original query (+5, -5 per echo)
  168. echo_score = 5
  169. for exp in parsed["lex"] + parsed["vec"]:
  170. if echoes_query(exp, query):
  171. echo_score -= 5
  172. scores["deductions"].append(f"echoes query: '{exp[:30]}...'")
  173. diversity_score += max(0, echo_score)
  174. scores["diversity"] = max(0, diversity_score)
  175. # === HYDE QUALITY (0-20, optional bonus) ===
  176. hyde_score = 0
  177. if parsed["hyde"]:
  178. hyde_text = parsed["hyde"][0] # Only first hyde counts
  179. # Hyde present and well-formed (+5)
  180. hyde_score += 5
  181. # Hyde is concise: 50-200 chars (+5)
  182. hyde_len = len(hyde_text)
  183. if 50 <= hyde_len <= 200:
  184. hyde_score += 5
  185. elif hyde_len < 50:
  186. hyde_score += 2
  187. scores["deductions"].append(f"hyde too short ({hyde_len} chars)")
  188. else:
  189. scores["deductions"].append(f"hyde too long ({hyde_len} chars)")
  190. # Hyde has no newlines (+5)
  191. if "\n" not in hyde_text:
  192. hyde_score += 5
  193. else:
  194. scores["deductions"].append("hyde contains newlines")
  195. # Hyde has no excessive repetition (+5)
  196. rep_penalty = word_repetition_penalty(hyde_text)
  197. if rep_penalty == 0:
  198. hyde_score += 5
  199. else:
  200. hyde_score += max(0, 5 - rep_penalty)
  201. scores["deductions"].append(f"hyde repetition penalty (-{min(5, rep_penalty)})")
  202. scores["hyde"] = hyde_score
  203. # === QUALITY (0-20) ===
  204. quality_score = 10 # Base relevance (assume relevant unless obvious garbage)
  205. # Lex lines should be keyword-focused (shorter than vec on average)
  206. if parsed["lex"] and parsed["vec"]:
  207. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  208. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  209. if avg_lex <= avg_vec:
  210. quality_score += 5
  211. else:
  212. scores["deductions"].append("lex longer than vec (should be keywords)")
  213. else:
  214. quality_score += 2 # Partial credit
  215. # Vec lines should be natural language (contain spaces, longer)
  216. if parsed["vec"]:
  217. vec_natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  218. if vec_natural == len(parsed["vec"]):
  219. quality_score += 5
  220. else:
  221. quality_score += 2
  222. scores["deductions"].append("some vec lines too short/keyword-like")
  223. scores["quality"] = quality_score
  224. # === TOTAL ===
  225. scores["total"] = scores["format"] + scores["diversity"] + scores["hyde"] + scores["quality"]
  226. scores["max_possible"] = 100 if parsed["hyde"] else 80
  227. scores["percentage"] = scores["total"] / scores["max_possible"] * 100
  228. # Rating
  229. pct = scores["percentage"]
  230. if pct >= 80:
  231. scores["rating"] = "Excellent"
  232. elif pct >= 60:
  233. scores["rating"] = "Good"
  234. elif pct >= 40:
  235. scores["rating"] = "Acceptable"
  236. elif pct >= 20:
  237. scores["rating"] = "Poor"
  238. else:
  239. scores["rating"] = "Failed"
  240. scores["parsed"] = parsed
  241. return scores
  242. def load_model(model_name: str, base_model: str = "Qwen/Qwen3-0.6B"):
  243. """Load the finetuned model."""
  244. print(f"Loading tokenizer from {base_model}...")
  245. tokenizer = AutoTokenizer.from_pretrained(base_model)
  246. if tokenizer.pad_token is None:
  247. tokenizer.pad_token = tokenizer.eos_token
  248. print(f"Loading base model...")
  249. base = AutoModelForCausalLM.from_pretrained(
  250. base_model,
  251. torch_dtype=torch.bfloat16,
  252. device_map="auto",
  253. )
  254. print(f"Loading adapter from {model_name}...")
  255. model = PeftModel.from_pretrained(base, model_name)
  256. model.eval()
  257. return model, tokenizer
  258. def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200) -> str:
  259. """Generate query expansion."""
  260. prompt = PROMPT_TEMPLATE.format(query=query)
  261. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  262. with torch.no_grad():
  263. outputs = model.generate(
  264. **inputs,
  265. max_new_tokens=max_new_tokens,
  266. temperature=0.7,
  267. do_sample=True,
  268. pad_token_id=tokenizer.pad_token_id,
  269. eos_token_id=tokenizer.eos_token_id,
  270. )
  271. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  272. expansion = full_output[len(prompt):].strip()
  273. return expansion
  274. def print_score_breakdown(scores: dict):
  275. """Pretty print score breakdown."""
  276. print(f" Score: {scores['total']}/{scores['max_possible']} ({scores['percentage']:.0f}%) - {scores['rating']}")
  277. print(f" Format: {scores['format']}/30")
  278. print(f" Diversity: {scores['diversity']}/30")
  279. print(f" Hyde: {scores['hyde']}/20")
  280. print(f" Quality: {scores['quality']}/20")
  281. if scores["deductions"]:
  282. print(f" Deductions:")
  283. for d in scores["deductions"][:5]: # Show top 5
  284. print(f" - {d}")
  285. if len(scores["deductions"]) > 5:
  286. print(f" ... and {len(scores['deductions']) - 5} more")
  287. def run_examples():
  288. """Run good and bad examples to demonstrate scoring."""
  289. print("=" * 70)
  290. print("SCORING EXAMPLES")
  291. print("=" * 70)
  292. # Good example
  293. good_expansion = """lex: react hooks tutorial
  294. lex: usestate useeffect
  295. vec: how to use react hooks in functional components
  296. vec: react hooks best practices guide
  297. hyde: React Hooks allow you to use state and lifecycle features in functional components without writing a class."""
  298. print("\n[GOOD EXAMPLE]")
  299. print(f"Query: react hooks")
  300. print(f"Output:\n{good_expansion}")
  301. scores = score_expansion("react hooks", good_expansion)
  302. print_score_breakdown(scores)
  303. # Bad example
  304. bad_expansion = """auth is an important concept that relates to authentication.
  305. The answer should be in Chinese.
  306. The answer should be in Chinese."""
  307. print("\n[BAD EXAMPLE]")
  308. print(f"Query: auth")
  309. print(f"Output:\n{bad_expansion}")
  310. scores = score_expansion("auth", bad_expansion)
  311. print_score_breakdown(scores)
  312. # Medium example - repetitive hyde
  313. medium_expansion = """lex: docker networking
  314. vec: docker networking
  315. hyde: Docker networking is an important concept. Docker networking is used for container communication. Docker networking configuration is essential."""
  316. print("\n[MEDIUM EXAMPLE - Repetitive]")
  317. print(f"Query: docker networking")
  318. print(f"Output:\n{medium_expansion}")
  319. scores = score_expansion("docker networking", medium_expansion)
  320. print_score_breakdown(scores)
  321. # Medium example - echoes query
  322. echo_expansion = """lex: auth
  323. lex: authentication
  324. vec: auth
  325. vec: authentication configuration
  326. hyde: Authentication is the process of verifying identity."""
  327. print("\n[MEDIUM EXAMPLE - Echoes Query]")
  328. print(f"Query: auth")
  329. print(f"Output:\n{echo_expansion}")
  330. scores = score_expansion("auth", echo_expansion)
  331. print_score_breakdown(scores)
  332. print("\n" + "=" * 70)
  333. def main():
  334. import argparse
  335. parser = argparse.ArgumentParser()
  336. parser.add_argument("--model", default="tobil/qmd-query-expansion-0.6B",
  337. help="Model to evaluate")
  338. parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
  339. help="Base model")
  340. parser.add_argument("--output", default="evaluation_results.json",
  341. help="Output file for results")
  342. parser.add_argument("--queries", type=str, help="Custom queries file (one per line)")
  343. parser.add_argument("--examples", action="store_true", help="Run scoring examples only")
  344. args = parser.parse_args()
  345. # Run examples if requested
  346. if args.examples:
  347. run_examples()
  348. return
  349. # Load custom queries if provided
  350. queries = TEST_QUERIES
  351. if args.queries:
  352. with open(args.queries) as f:
  353. queries = [l.strip() for l in f if l.strip()]
  354. # Load model
  355. model, tokenizer = load_model(args.model, args.base_model)
  356. # Run evaluation
  357. results = []
  358. print(f"\n{'='*70}")
  359. print("EVALUATION RESULTS")
  360. print(f"{'='*70}\n")
  361. for i, query in enumerate(queries, 1):
  362. print(f"[{i}/{len(queries)}] Query: {query}")
  363. print("-" * 50)
  364. expansion = generate_expansion(model, tokenizer, query)
  365. scores = score_expansion(query, expansion)
  366. print(expansion)
  367. print()
  368. print_score_breakdown(scores)
  369. print()
  370. results.append({
  371. "query": query,
  372. "expansion": expansion,
  373. "scores": {k: v for k, v in scores.items() if k != "parsed"},
  374. "parsed": scores["parsed"],
  375. })
  376. # Summary
  377. print(f"\n{'='*70}")
  378. print("SUMMARY")
  379. print(f"{'='*70}")
  380. avg_score = sum(r["scores"]["percentage"] for r in results) / len(results)
  381. excellent = sum(1 for r in results if r["scores"]["rating"] == "Excellent")
  382. good = sum(1 for r in results if r["scores"]["rating"] == "Good")
  383. acceptable = sum(1 for r in results if r["scores"]["rating"] == "Acceptable")
  384. poor = sum(1 for r in results if r["scores"]["rating"] == "Poor")
  385. failed = sum(1 for r in results if r["scores"]["rating"] == "Failed")
  386. print(f" Total queries: {len(results)}")
  387. print(f" Average score: {avg_score:.1f}%")
  388. print(f" Ratings:")
  389. print(f" Excellent: {excellent}")
  390. print(f" Good: {good}")
  391. print(f" Acceptable: {acceptable}")
  392. print(f" Poor: {poor}")
  393. print(f" Failed: {failed}")
  394. # Save results
  395. with open(args.output, "w") as f:
  396. json.dump(results, f, indent=2)
  397. print(f"\n Results saved to: {args.output}")
  398. if __name__ == "__main__":
  399. main()