eval_common.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. """
  2. Common evaluation and reward scoring for QMD query expansion models.
  3. Shared by sft.py and grpo.py for post-training evaluation.
  4. """
  5. import csv
  6. import io
  7. import re
  8. from collections import Counter
  9. import torch
  10. from huggingface_hub import HfApi
  11. # =============================================================================
  12. # Reward function (single source of truth)
  13. # =============================================================================
  14. STOPWORDS = frozenset({
  15. 'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
  16. 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
  17. })
  18. KEY_TERM_STOPWORDS = frozenset({
  19. 'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  20. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
  21. 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
  22. })
  23. GENERIC_LEX_PHRASES = frozenset({
  24. 'find information about', 'search for', 'look up', 'get information',
  25. 'learn about', 'information on', 'details about', 'find out about',
  26. 'what is', 'how to', 'guide to', 'help with',
  27. })
  28. CHAT_TEMPLATE_TOKENS = frozenset({
  29. '<|im_start|>', '<|im_end|>', '<|endoftext|>',
  30. '\nassistant\n', '\nuser\n',
  31. })
  32. def parse_expansion(text):
  33. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  34. for line in text.strip().split("\n"):
  35. line = line.strip()
  36. if not line:
  37. continue
  38. if line.startswith("lex:"):
  39. result["lex"].append(line[4:].strip())
  40. elif line.startswith("vec:"):
  41. result["vec"].append(line[4:].strip())
  42. elif line.startswith("hyde:"):
  43. result["hyde"].append(line[5:].strip())
  44. else:
  45. result["invalid"].append(line)
  46. return result
  47. def clean_model_output(text):
  48. text = text.replace('<|im_end|>', '').strip()
  49. used_thinking = '<think>' in text and '</think>' in text
  50. if used_thinking:
  51. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
  52. return text, used_thinking
  53. def extract_named_entities(query):
  54. entities = set()
  55. words = query.split()
  56. prev_was_entity = False
  57. for i, word in enumerate(words):
  58. clean = word.strip('.,!?:;()[]"\'')
  59. if not clean:
  60. prev_was_entity = False
  61. continue
  62. is_entity = False
  63. if clean.isupper() and len(clean) >= 2:
  64. entities.add(clean.lower()); is_entity = True
  65. elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
  66. entities.add(clean.lower()); is_entity = True
  67. elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
  68. entities.add(clean.lower()); is_entity = True
  69. elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
  70. entities.add(clean.lower()); is_entity = True
  71. elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
  72. entities.add(clean.lower()); is_entity = True
  73. prev_was_entity = is_entity
  74. return entities
  75. def get_key_terms(query):
  76. return set(query.lower().split()) - KEY_TERM_STOPWORDS
  77. def lex_preserves_key_terms(lex_line, query):
  78. key_terms = get_key_terms(query)
  79. return not key_terms or bool(key_terms & set(lex_line.lower().split()))
  80. def lex_preserves_entities(line, entities):
  81. if not entities:
  82. return True
  83. return any(e in line.lower() for e in entities)
  84. def lex_is_generic(lex_line):
  85. lower = lex_line.lower().strip()
  86. for phrase in GENERIC_LEX_PHRASES:
  87. if phrase in lower or lower.startswith(phrase.split()[0]):
  88. remaining = lower
  89. for word in phrase.split():
  90. remaining = remaining.replace(word, '', 1).strip()
  91. if len(remaining) < 3:
  92. return True
  93. return False
  94. def word_set_distance(a, b):
  95. return len(set(a.lower().split()) ^ set(b.lower().split()))
  96. def is_diverse(a, b, min_distance=2):
  97. a, b = a.lower().strip(), b.lower().strip()
  98. if a == b or a in b or b in a:
  99. return False
  100. return word_set_distance(a, b) >= min_distance
  101. def echoes_query(expansion, query):
  102. exp, q = expansion.lower().strip(), query.lower().strip()
  103. return exp == q or (q in exp and len(exp) < len(q) + 10)
  104. def word_repetition_penalty(text):
  105. counts = Counter(re.findall(r'\b\w+\b', text.lower()))
  106. return sum((c - 2) * 2 for w, c in counts.items()
  107. if c >= 3 and w not in STOPWORDS and len(w) > 2)
  108. def score_expansion(query, expansion):
  109. """Score expansion as float in [0.0, 1.0] for RL reward."""
  110. text, used_thinking = clean_model_output(expansion.strip())
  111. if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
  112. return 0.0
  113. for line in text.split("\n"):
  114. line = line.strip()
  115. if line and not line.startswith(("lex:", "vec:", "hyde:")):
  116. return 0.0
  117. parsed = parse_expansion(text)
  118. format_score = 10
  119. if parsed["lex"]: format_score += 10
  120. if parsed["vec"]: format_score += 10
  121. diversity_score = 0
  122. if sum(1 for t in ("lex", "vec") if parsed[t]) >= 2: diversity_score += 10
  123. if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5
  124. lex_div = 5
  125. for i, a in enumerate(parsed["lex"]):
  126. for b in parsed["lex"][i+1:]:
  127. if not is_diverse(a, b, 2): lex_div -= 2
  128. diversity_score += max(0, lex_div)
  129. vec_div = 5
  130. for i, a in enumerate(parsed["vec"]):
  131. for b in parsed["vec"][i+1:]:
  132. if not is_diverse(a, b, 3): vec_div -= 2
  133. diversity_score += max(0, vec_div)
  134. echo = 5
  135. for exp in parsed["lex"] + parsed["vec"]:
  136. if echoes_query(exp, query): echo -= 3
  137. diversity_score += max(0, echo)
  138. hyde_score = 0
  139. if parsed["hyde"]:
  140. hyde_text = parsed["hyde"][0]
  141. hyde_score += 5
  142. if 50 <= len(hyde_text) <= 200: hyde_score += 5
  143. elif len(hyde_text) < 50: hyde_score += 2
  144. if "\n" not in hyde_text: hyde_score += 5
  145. hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
  146. quality_score = 5
  147. if parsed["lex"] and parsed["vec"]:
  148. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  149. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  150. if avg_lex <= avg_vec: quality_score += 5
  151. if parsed["vec"]:
  152. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  153. quality_score += 5 if natural == len(parsed["vec"]) else 2
  154. if parsed["lex"]:
  155. with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  156. if with_terms == len(parsed["lex"]): quality_score += 5
  157. elif with_terms > 0: quality_score += 2
  158. entity_score = 0
  159. entities = extract_named_entities(query)
  160. if entities and parsed["lex"]:
  161. with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
  162. if with_entities == len(parsed["lex"]): entity_score += 15
  163. elif with_entities > 0: entity_score += 5
  164. else: entity_score -= 30
  165. generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
  166. if generic_count: entity_score -= generic_count * 15
  167. if parsed["vec"]:
  168. vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
  169. if vec_with > 0: entity_score += 5
  170. elif not entities:
  171. entity_score = 10
  172. think_bonus = 0 if used_thinking else 20
  173. total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
  174. max_possible = 140 if parsed["hyde"] else 120
  175. return max(0.0, min(1.0, total / max_possible))
  176. def extract_query_from_prompt(prompt):
  177. """Extract the search query from a formatted prompt string."""
  178. if "Expand this search query:" in prompt:
  179. query = prompt.split("Expand this search query:")[-1].strip()
  180. if "<|im_end|>" in query:
  181. query = query.split("<|im_end|>")[0].strip()
  182. return query
  183. return prompt.strip()
  184. class QMDRewardFunction:
  185. """Reward function wrapper for TRL's GRPOTrainer."""
  186. __name__ = "qmd_scoring_reward"
  187. def __call__(self, completions, prompts=None, **kwargs):
  188. rewards = []
  189. for i, completion in enumerate(completions):
  190. query = ""
  191. if prompts and i < len(prompts):
  192. query = extract_query_from_prompt(prompts[i])
  193. rewards.append(score_expansion(query, completion))
  194. return rewards
  195. # =============================================================================
  196. # Evaluation
  197. # =============================================================================
  198. EVAL_QUERIES = [
  199. # Technical documentation
  200. "how to configure authentication",
  201. "typescript async await",
  202. "docker compose networking",
  203. "git rebase vs merge",
  204. "react useEffect cleanup",
  205. # Short/ambiguous
  206. "auth", "config", "setup", "api",
  207. # Named entities
  208. "who is TDS motorsports",
  209. "React hooks tutorial",
  210. "Docker container networking",
  211. "Kubernetes pod deployment",
  212. "AWS Lambda functions",
  213. # Personal notes / journals
  214. "meeting notes project kickoff",
  215. "ideas for new feature",
  216. "todo list app architecture",
  217. # Research / learning
  218. "what is dependency injection",
  219. "difference between sql and nosql",
  220. "kubernetes vs docker swarm",
  221. # Error/debugging
  222. "connection timeout error",
  223. "memory leak debugging",
  224. "cors error fix",
  225. # Temporal / recency
  226. "recent news about Shopify",
  227. "latest AI developments",
  228. "best laptops right now",
  229. "what changed in kubernetes latest version",
  230. # Complex
  231. "how to implement caching with redis in nodejs",
  232. "best practices for api rate limiting",
  233. "setting up ci cd pipeline with github actions",
  234. ]
  235. def generate_expansion(model, tokenizer, query, max_new_tokens=200):
  236. """Generate a query expansion using the model."""
  237. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  238. prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  239. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  240. with torch.no_grad():
  241. outputs = model.generate(
  242. **inputs, max_new_tokens=max_new_tokens,
  243. temperature=0.7, do_sample=True,
  244. pad_token_id=tokenizer.pad_token_id,
  245. eos_token_id=tokenizer.eos_token_id,
  246. )
  247. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  248. if "\nassistant\n" in full_output:
  249. return full_output.split("\nassistant\n")[-1].strip()
  250. elif "assistant\n" in full_output:
  251. return full_output.split("assistant\n")[-1].strip()
  252. return full_output[len(prompt):].strip()
  253. def run_eval(model, tokenizer, label, upload_repo="tobil/qmd-query-expansion-evals"):
  254. """Evaluate model on EVAL_QUERIES, print results, upload CSV."""
  255. api = HfApi()
  256. api.create_repo(repo_id=upload_repo, repo_type="model", exist_ok=True)
  257. print(f"\n{'='*70}")
  258. print(f" EVALUATING: {label}")
  259. print(f"{'='*70}")
  260. results = []
  261. for i, query in enumerate(EVAL_QUERIES, 1):
  262. expansion = generate_expansion(model, tokenizer, query)
  263. score = score_expansion(query, expansion)
  264. pct = round(score * 100, 1)
  265. rating = ("Excellent" if pct >= 80 else "Good" if pct >= 60
  266. else "Acceptable" if pct >= 40 else "Poor" if pct >= 20 else "Failed")
  267. marker = "+" if pct >= 80 else "-" if pct < 60 else "~"
  268. print(f" [{marker}] {i:2d}/{len(EVAL_QUERIES)} {pct:5.1f}% {rating:10s} {query}")
  269. results.append({"query": query, "expansion": expansion, "score": pct, "rating": rating})
  270. avg = sum(r["score"] for r in results) / len(results)
  271. ratings = Counter(r["rating"] for r in results)
  272. print(f"\n {'─'*50}")
  273. print(f" Average score: {avg:.1f}%")
  274. for r in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
  275. c = ratings.get(r, 0)
  276. if c:
  277. print(f" {r:10s}: {c:2d} {'█' * c}")
  278. worst = sorted(results, key=lambda r: r["score"])[:5]
  279. print(f"\n Bottom 5:")
  280. for r in worst:
  281. print(f" {r['score']:5.1f}% {r['query']}")
  282. buf = io.StringIO()
  283. writer = csv.writer(buf)
  284. writer.writerow(["model", "query", "expansion", "score_pct", "rating"])
  285. for r in results:
  286. writer.writerow([label, r["query"], r["expansion"], r["score"], r["rating"]])
  287. filename = f"eval_{label}.csv"
  288. print(f"\n Uploading {filename} to {upload_repo}...")
  289. api.upload_file(
  290. path_or_fileobj=buf.getvalue().encode("utf-8"),
  291. path_in_repo=filename,
  292. repo_id=upload_repo,
  293. repo_type="model",
  294. )
  295. print(f" Done: https://huggingface.co/{upload_repo}/blob/main/{filename}")