score.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
# /// script
# requires-python = ">=3.10"
# dependencies = []
# ///
"""
Score query expansion results from a JSONL file.

Usage:
    uv run evals/score.py evals/results_model.jsonl
    uv run evals/score.py evals/results_model.jsonl --output scores.json
    uv run evals/score.py evals/results_model.jsonl --verbose
    uv run evals/score.py evals/results_model.jsonl --summary-only
"""
  12. import argparse
  13. import json
  14. import re
  15. import sys
  16. from collections import Counter
  17. from pathlib import Path
  18. # ============== SCORING CONSTANTS ==============
  19. STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  20. KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  21. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
  22. 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}
  23. GENERIC_LEX_PHRASES = {
  24. 'find information about', 'search for', 'look up', 'get information',
  25. 'learn about', 'information on', 'details about', 'find out about',
  26. 'what is', 'how to', 'guide to', 'help with'
  27. }
  28. # ============== HELPER FUNCTIONS ==============
  29. def extract_named_entities(query: str) -> set:
  30. """Extract named entities from query using simple heuristics."""
  31. entities = set()
  32. words = query.split()
  33. prev_was_entity = False
  34. for i, word in enumerate(words):
  35. clean = word.strip('.,!?:;()[]"\'')
  36. if not clean:
  37. prev_was_entity = False
  38. continue
  39. is_entity = False
  40. # All-caps words (acronyms): TDS, API, GPU
  41. if clean.isupper() and len(clean) >= 2:
  42. entities.add(clean.lower())
  43. is_entity = True
  44. # Capitalized words (not first word)
  45. elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
  46. entities.add(clean.lower())
  47. is_entity = True
  48. # Technical terms: node.js, C++
  49. elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
  50. entities.add(clean.lower())
  51. is_entity = True
  52. # CamelCase: JavaScript
  53. elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
  54. entities.add(clean.lower())
  55. is_entity = True
  56. # Word following an entity (compound names)
  57. elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
  58. entities.add(clean.lower())
  59. is_entity = True
  60. prev_was_entity = is_entity
  61. return entities
  62. def get_key_terms(query: str) -> set:
  63. """Get key terms (non-stopwords) from query."""
  64. words = set(query.lower().split())
  65. return words - KEY_TERM_STOPWORDS
  66. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  67. """Check if lex line preserves key terms from query."""
  68. key_terms = get_key_terms(query)
  69. if not key_terms:
  70. return True
  71. lex_words = set(lex_line.lower().split())
  72. return bool(key_terms & lex_words)
  73. def lex_preserves_entities(lex_line: str, entities: set) -> bool:
  74. """Check if lex line contains at least one named entity."""
  75. if not entities:
  76. return True
  77. lex_lower = lex_line.lower()
  78. return any(entity in lex_lower for entity in entities)
  79. def lex_is_generic(lex_line: str) -> bool:
  80. """Check if lex line is a generic filler phrase."""
  81. lex_lower = lex_line.lower().strip()
  82. for phrase in GENERIC_LEX_PHRASES:
  83. if phrase in lex_lower or lex_lower.startswith(phrase.split()[0]):
  84. remaining = lex_lower
  85. for word in phrase.split():
  86. remaining = remaining.replace(word, '', 1).strip()
  87. if len(remaining) < 3:
  88. return True
  89. return False
  90. def parse_expansion(text: str) -> dict:
  91. """Parse expansion into structured format."""
  92. lines = text.strip().split("\n")
  93. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  94. for line in lines:
  95. line = line.strip()
  96. if not line:
  97. continue
  98. if line.startswith("lex:"):
  99. result["lex"].append(line[4:].strip())
  100. elif line.startswith("vec:"):
  101. result["vec"].append(line[4:].strip())
  102. elif line.startswith("hyde:"):
  103. result["hyde"].append(line[5:].strip())
  104. else:
  105. result["invalid"].append(line)
  106. return result
  107. def edit_distance_simple(a: str, b: str) -> int:
  108. """Simple word-level edit distance."""
  109. words_a = set(a.lower().split())
  110. words_b = set(b.lower().split())
  111. return len(words_a ^ words_b)
  112. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  113. """Check if two strings are sufficiently different."""
  114. a, b = a.lower().strip(), b.lower().strip()
  115. if a == b:
  116. return False
  117. if a in b or b in a:
  118. return False
  119. return edit_distance_simple(a, b) >= min_distance
  120. def echoes_query(expansion: str, query: str) -> bool:
  121. """Check if expansion is just echoing the query."""
  122. exp = expansion.lower().strip()
  123. q = query.lower().strip()
  124. if exp == q:
  125. return True
  126. if q in exp and len(exp) < len(q) + 10:
  127. return True
  128. return False
  129. def word_repetition_penalty(text: str) -> int:
  130. """Count penalty for repeated words."""
  131. words = re.findall(r'\b\w+\b', text.lower())
  132. counts = Counter(words)
  133. penalty = 0
  134. for word, count in counts.items():
  135. if count >= 3 and word not in STOPWORDS and len(word) > 2:
  136. penalty += (count - 2) * 2
  137. return penalty
  138. # ============== MAIN SCORING FUNCTION ==============
  139. def score_expansion(query: str, expansion: str) -> dict:
  140. """Score an expansion. Returns detailed breakdown."""
  141. text = expansion.strip()
  142. deductions = []
  143. # HARD FAIL: Chat template artifacts
  144. if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
  145. '\nassistant\n', '\nuser\n', '<|endoftext|>']):
  146. return {
  147. "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
  148. "total": 0, "max_possible": 100, "percentage": 0, "rating": "Failed",
  149. "deductions": ["CHAT TEMPLATE LEAKAGE"],
  150. "parsed": {"lex": [], "vec": [], "hyde": [], "invalid": [text[:100]]},
  151. "entities_detected": [],
  152. }
  153. # HARD FAIL: Every line must start with lex:, vec:, or hyde:
  154. for line in text.split("\n"):
  155. line = line.strip()
  156. if not line:
  157. continue
  158. if not line.startswith(("lex:", "vec:", "hyde:")):
  159. return {
  160. "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
  161. "total": 0, "max_possible": 100, "percentage": 0, "rating": "Failed",
  162. "deductions": [f"INVALID LINE: {line[:50]}"],
  163. "parsed": parse_expansion(expansion),
  164. "entities_detected": [],
  165. }
  166. parsed = parse_expansion(expansion)
  167. # FORMAT (0-30)
  168. format_score = 0
  169. if parsed["lex"]:
  170. format_score += 10
  171. else:
  172. deductions.append("missing lex:")
  173. if parsed["vec"]:
  174. format_score += 10
  175. else:
  176. deductions.append("missing vec:")
  177. format_score += 10 # No invalid lines (guaranteed by hard fail)
  178. # DIVERSITY (0-30)
  179. diversity_score = 0
  180. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  181. if types_present >= 2:
  182. diversity_score += 10
  183. else:
  184. deductions.append("only one type")
  185. total_expansions = len(parsed["lex"]) + len(parsed["vec"])
  186. if total_expansions >= 2:
  187. diversity_score += 5
  188. lex_score = 5
  189. for i, a in enumerate(parsed["lex"]):
  190. for b in parsed["lex"][i+1:]:
  191. if not is_diverse(a, b, 2):
  192. lex_score -= 2
  193. deductions.append(f"lex duplicate: {a[:20]}...")
  194. diversity_score += max(0, lex_score)
  195. vec_score = 5
  196. for i, a in enumerate(parsed["vec"]):
  197. for b in parsed["vec"][i+1:]:
  198. if not is_diverse(a, b, 3):
  199. vec_score -= 2
  200. deductions.append(f"vec duplicate: {a[:20]}...")
  201. diversity_score += max(0, vec_score)
  202. echo_score = 5
  203. for exp in parsed["lex"] + parsed["vec"]:
  204. if echoes_query(exp, query):
  205. echo_score -= 3
  206. deductions.append(f"echoes query: {exp[:20]}...")
  207. diversity_score += max(0, echo_score)
  208. # HYDE (0-20)
  209. hyde_score = 0
  210. if parsed["hyde"]:
  211. hyde_text = parsed["hyde"][0]
  212. hyde_score += 5
  213. hyde_len = len(hyde_text)
  214. if 50 <= hyde_len <= 200:
  215. hyde_score += 5
  216. elif hyde_len < 50:
  217. hyde_score += 2
  218. deductions.append(f"hyde too short ({hyde_len})")
  219. else:
  220. deductions.append(f"hyde too long ({hyde_len})")
  221. if "\n" not in hyde_text:
  222. hyde_score += 5
  223. rep_penalty = word_repetition_penalty(hyde_text)
  224. hyde_score += max(0, 5 - rep_penalty)
  225. # QUALITY (0-20)
  226. quality_score = 5
  227. if parsed["lex"] and parsed["vec"]:
  228. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  229. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  230. if avg_lex <= avg_vec:
  231. quality_score += 5
  232. else:
  233. deductions.append("lex longer than vec")
  234. if parsed["vec"]:
  235. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  236. if natural == len(parsed["vec"]):
  237. quality_score += 5
  238. else:
  239. quality_score += 2
  240. if parsed["lex"]:
  241. lex_with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  242. if lex_with_terms == len(parsed["lex"]):
  243. quality_score += 5
  244. elif lex_with_terms > 0:
  245. quality_score += 2
  246. else:
  247. deductions.append("lex missing key terms")
  248. # NAMED ENTITY PRESERVATION (0-20, can go negative)
  249. entity_score = 0
  250. entities = extract_named_entities(query)
  251. if entities and parsed["lex"]:
  252. lex_with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
  253. if lex_with_entities == len(parsed["lex"]):
  254. entity_score += 15
  255. elif lex_with_entities > 0:
  256. entity_score += 5
  257. else:
  258. entity_score -= 30
  259. deductions.append(f"lex missing entities: {entities}")
  260. generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
  261. if generic_count > 0:
  262. entity_score -= generic_count * 15
  263. deductions.append(f"{generic_count} generic lex phrases")
  264. if parsed["vec"]:
  265. vec_with_entities = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
  266. if vec_with_entities > 0:
  267. entity_score += 5
  268. elif not entities:
  269. entity_score = 10
  270. # TOTAL
  271. total = format_score + diversity_score + hyde_score + quality_score + entity_score
  272. max_possible = 120 if parsed["hyde"] else 100
  273. percentage = max(0.0, min(100.0, total / max_possible * 100))
  274. # Rating
  275. if percentage >= 80:
  276. rating = "Excellent"
  277. elif percentage >= 60:
  278. rating = "Good"
  279. elif percentage >= 40:
  280. rating = "Acceptable"
  281. elif percentage >= 20:
  282. rating = "Poor"
  283. else:
  284. rating = "Failed"
  285. return {
  286. "format": format_score,
  287. "diversity": diversity_score,
  288. "hyde": hyde_score,
  289. "quality": quality_score,
  290. "entity": max(0, entity_score),
  291. "total": max(0, total),
  292. "max_possible": max_possible,
  293. "percentage": round(percentage, 1),
  294. "rating": rating,
  295. "deductions": deductions,
  296. "parsed": parsed,
  297. "entities_detected": list(entities) if entities else [],
  298. }
  299. def print_result(query: str, expansion: str, scores: dict, verbose: bool = False):
  300. """Print a single result."""
  301. print(f"\n{'='*60}")
  302. print(f"Query: {query}")
  303. print(f"{'─'*60}")
  304. print(expansion)
  305. print(f"{'─'*60}")
  306. print(f"Score: {scores['percentage']:.0f}% ({scores['rating']})")
  307. print(f" Format: {scores['format']}/30 Diversity: {scores['diversity']}/30 "
  308. f"Hyde: {scores['hyde']}/20 Quality: {scores['quality']}/20 Entity: {scores['entity']}/20")
  309. if verbose and scores["deductions"]:
  310. print(f" Deductions: {', '.join(scores['deductions'][:5])}")
  311. if verbose and scores["entities_detected"]:
  312. print(f" Entities: {scores['entities_detected']}")
  313. def main():
  314. parser = argparse.ArgumentParser(description="Score query expansion results")
  315. parser.add_argument("input", help="Input JSONL file from run.py")
  316. parser.add_argument("--output", help="Output JSON file with scores")
  317. parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed output")
  318. parser.add_argument("--summary-only", action="store_true", help="Only show summary")
  319. args = parser.parse_args()
  320. # Load results
  321. results = []
  322. metadata = None
  323. with open(args.input) as f:
  324. for line in f:
  325. data = json.loads(line)
  326. if data.get("_meta"):
  327. metadata = data
  328. else:
  329. results.append(data)
  330. print(f"Scoring {len(results)} results from {args.input}", file=sys.stderr)
  331. if metadata:
  332. print(f"Model: {metadata.get('model', 'unknown')}", file=sys.stderr)
  333. # Score each result
  334. scored_results = []
  335. for result in results:
  336. query = result["query"]
  337. expansion = result["expansion"]
  338. scores = score_expansion(query, expansion)
  339. if not args.summary_only:
  340. print_result(query, expansion, scores, args.verbose)
  341. scored_results.append({
  342. "query": query,
  343. "expansion": expansion,
  344. "scores": {k: v for k, v in scores.items() if k not in ["parsed", "deductions", "entities_detected"]},
  345. "deductions": scores["deductions"],
  346. "entities_detected": scores["entities_detected"],
  347. })
  348. # Summary
  349. print(f"\n{'='*60}")
  350. print("SUMMARY")
  351. print(f"{'='*60}")
  352. avg_score = sum(r["scores"]["percentage"] for r in scored_results) / len(scored_results)
  353. ratings = Counter(r["scores"]["rating"] for r in scored_results)
  354. print(f" Total queries: {len(scored_results)}")
  355. print(f" Average score: {avg_score:.1f}%")
  356. print(f" Ratings:")
  357. for rating in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
  358. count = ratings.get(rating, 0)
  359. bar = "█" * count
  360. print(f" {rating:10s}: {count:2d} {bar}")
  361. # Save output
  362. if args.output:
  363. output_data = {
  364. "metadata": metadata,
  365. "summary": {
  366. "total": len(scored_results),
  367. "average_score": round(avg_score, 1),
  368. "ratings": dict(ratings),
  369. },
  370. "results": scored_results,
  371. }
  372. with open(args.output, "w") as f:
  373. json.dump(output_data, f, indent=2)
  374. print(f"\nScores saved to: {args.output}")
  375. if __name__ == "__main__":
  376. main()