eval.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "huggingface_hub>=0.20.0",
  8. # "accelerate",
  9. # ]
  10. # ///
  11. """
  12. Evaluate QMD query expansion models on HuggingFace Jobs.
  13. Self-contained script — inlines the reward function and test queries.
  14. hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py
  15. hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py -- --sft-only
  16. """
  17. import argparse
  18. import csv
  19. import io
  20. import json
  21. import os
  22. import re
  23. import sys
  24. from collections import Counter
  25. import torch
  26. from huggingface_hub import HfApi, login
  27. from peft import PeftModel
  28. from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Config ---
# Hugging Face Hub repo ids used by main(): load_model() loads BASE_MODEL,
# merges the SFT adapter into it, and for the GRPO run additionally loads the
# GRPO adapter on top of the merged SFT weights.
BASE_MODEL = "Qwen/Qwen3-1.7B"
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
# --- Test queries (inlined from evals/queries.txt) ---
# evaluate_model() runs these in list order, and main() compares the SFT and
# GRPO results pairwise by position (zip), so keep the order stable between
# runs that are meant to be compared.
QUERIES = [
    # Technical documentation
    "how to configure authentication",
    "typescript async await",
    "docker compose networking",
    "git rebase vs merge",
    "react useEffect cleanup",
    # Short/ambiguous
    "auth",
    "config",
    "setup",
    "api",
    # Named entities
    "who is TDS motorsports",
    "React hooks tutorial",
    "Docker container networking",
    "Kubernetes pod deployment",
    "AWS Lambda functions",
    # Personal notes / journals
    "meeting notes project kickoff",
    "ideas for new feature",
    "todo list app architecture",
    # Research / learning
    "what is dependency injection",
    "difference between sql and nosql",
    "kubernetes vs docker swarm",
    # Error/debugging
    "connection timeout error",
    "memory leak debugging",
    "cors error fix",
    # Temporal / recency
    "recent news about Shopify",
    "latest AI developments",
    "best laptops right now",
    "what changed in kubernetes latest version",
    # Complex
    "how to implement caching with redis in nodejs",
    "best practices for api rate limiting",
    "setting up ci cd pipeline with github actions",
]
  74. # =============================================================================
  75. # Reward function (inlined from reward.py)
  76. # =============================================================================
  77. STOPWORDS = frozenset({
  78. 'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
  79. 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
  80. })
  81. KEY_TERM_STOPWORDS = frozenset({
  82. 'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  83. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
  84. 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
  85. })
  86. GENERIC_LEX_PHRASES = frozenset({
  87. 'find information about', 'search for', 'look up', 'get information',
  88. 'learn about', 'information on', 'details about', 'find out about',
  89. 'what is', 'how to', 'guide to', 'help with',
  90. })
  91. CHAT_TEMPLATE_TOKENS = frozenset({
  92. '<|im_start|>', '<|im_end|>', '<|endoftext|>',
  93. '\nassistant\n', '\nuser\n',
  94. })
  95. def parse_expansion(text):
  96. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  97. for line in text.strip().split("\n"):
  98. line = line.strip()
  99. if not line:
  100. continue
  101. if line.startswith("lex:"):
  102. result["lex"].append(line[4:].strip())
  103. elif line.startswith("vec:"):
  104. result["vec"].append(line[4:].strip())
  105. elif line.startswith("hyde:"):
  106. result["hyde"].append(line[5:].strip())
  107. else:
  108. result["invalid"].append(line)
  109. return result
  110. def clean_model_output(text):
  111. text = text.replace('<|im_end|>', '').strip()
  112. used_thinking = '<think>' in text and '</think>' in text
  113. if used_thinking:
  114. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
  115. return text, used_thinking
  116. def extract_named_entities(query):
  117. entities = set()
  118. words = query.split()
  119. prev_was_entity = False
  120. for i, word in enumerate(words):
  121. clean = word.strip('.,!?:;()[]"\'')
  122. if not clean:
  123. prev_was_entity = False
  124. continue
  125. is_entity = False
  126. if clean.isupper() and len(clean) >= 2:
  127. entities.add(clean.lower()); is_entity = True
  128. elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
  129. entities.add(clean.lower()); is_entity = True
  130. elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
  131. entities.add(clean.lower()); is_entity = True
  132. elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
  133. entities.add(clean.lower()); is_entity = True
  134. elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
  135. entities.add(clean.lower()); is_entity = True
  136. prev_was_entity = is_entity
  137. return entities
  138. def get_key_terms(query):
  139. return set(query.lower().split()) - KEY_TERM_STOPWORDS
  140. def lex_preserves_key_terms(lex_line, query):
  141. key_terms = get_key_terms(query)
  142. return not key_terms or bool(key_terms & set(lex_line.lower().split()))
  143. def lex_preserves_entities(line, entities):
  144. if not entities: return True
  145. return any(e in line.lower() for e in entities)
  146. def lex_is_generic(lex_line):
  147. lower = lex_line.lower().strip()
  148. for phrase in GENERIC_LEX_PHRASES:
  149. if phrase in lower or lower.startswith(phrase.split()[0]):
  150. remaining = lower
  151. for word in phrase.split():
  152. remaining = remaining.replace(word, '', 1).strip()
  153. if len(remaining) < 3:
  154. return True
  155. return False
  156. def word_set_distance(a, b):
  157. return len(set(a.lower().split()) ^ set(b.lower().split()))
  158. def is_diverse(a, b, min_distance=2):
  159. a, b = a.lower().strip(), b.lower().strip()
  160. if a == b or a in b or b in a: return False
  161. return word_set_distance(a, b) >= min_distance
  162. def echoes_query(expansion, query):
  163. exp, q = expansion.lower().strip(), query.lower().strip()
  164. return exp == q or (q in exp and len(exp) < len(q) + 10)
  165. def word_repetition_penalty(text):
  166. counts = Counter(re.findall(r'\b\w+\b', text.lower()))
  167. return sum((c - 2) * 2 for w, c in counts.items()
  168. if c >= 3 and w not in STOPWORDS and len(w) > 2)
  169. def score_expansion_detailed(query, expansion):
  170. text, used_thinking = clean_model_output(expansion.strip())
  171. deductions = []
  172. def _fail(reason):
  173. return {
  174. "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
  175. "think_bonus": 0, "total": 0, "max_possible": 100,
  176. "percentage": 0.0, "rating": "Failed", "deductions": [reason],
  177. }
  178. if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
  179. return _fail("CHAT TEMPLATE LEAKAGE")
  180. for line in text.split("\n"):
  181. line = line.strip()
  182. if line and not line.startswith(("lex:", "vec:", "hyde:")):
  183. return _fail(f"INVALID LINE: {line[:50]}")
  184. parsed = parse_expansion(text)
  185. format_score = 10
  186. if parsed["lex"]: format_score += 10
  187. else: deductions.append("missing lex:")
  188. if parsed["vec"]: format_score += 10
  189. else: deductions.append("missing vec:")
  190. diversity_score = 0
  191. types_present = sum(1 for t in ("lex", "vec") if parsed[t])
  192. if types_present >= 2: diversity_score += 10
  193. if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5
  194. lex_div = 5
  195. for i, a in enumerate(parsed["lex"]):
  196. for b in parsed["lex"][i+1:]:
  197. if not is_diverse(a, b, 2): lex_div -= 2
  198. diversity_score += max(0, lex_div)
  199. vec_div = 5
  200. for i, a in enumerate(parsed["vec"]):
  201. for b in parsed["vec"][i+1:]:
  202. if not is_diverse(a, b, 3): vec_div -= 2
  203. diversity_score += max(0, vec_div)
  204. echo = 5
  205. for exp in parsed["lex"] + parsed["vec"]:
  206. if echoes_query(exp, query): echo -= 3
  207. diversity_score += max(0, echo)
  208. hyde_score = 0
  209. if parsed["hyde"]:
  210. hyde_text = parsed["hyde"][0]
  211. hyde_score += 5
  212. hyde_len = len(hyde_text)
  213. if 50 <= hyde_len <= 200: hyde_score += 5
  214. elif hyde_len < 50: hyde_score += 2
  215. if "\n" not in hyde_text: hyde_score += 5
  216. hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
  217. quality_score = 5
  218. if parsed["lex"] and parsed["vec"]:
  219. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  220. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  221. if avg_lex <= avg_vec: quality_score += 5
  222. if parsed["vec"]:
  223. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  224. quality_score += 5 if natural == len(parsed["vec"]) else 2
  225. if parsed["lex"]:
  226. with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  227. if with_terms == len(parsed["lex"]): quality_score += 5
  228. elif with_terms > 0: quality_score += 2
  229. entity_score = 0
  230. entities = extract_named_entities(query)
  231. if entities and parsed["lex"]:
  232. with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
  233. if with_entities == len(parsed["lex"]): entity_score += 15
  234. elif with_entities > 0: entity_score += 5
  235. else: entity_score -= 30
  236. generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
  237. if generic_count: entity_score -= generic_count * 15
  238. if parsed["vec"]:
  239. vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
  240. if vec_with > 0: entity_score += 5
  241. elif not entities:
  242. entity_score = 10
  243. think_bonus = 0 if used_thinking else 20
  244. total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
  245. max_possible = 140 if parsed["hyde"] else 120
  246. percentage = max(0.0, min(100.0, total / max_possible * 100))
  247. if percentage >= 80: rating = "Excellent"
  248. elif percentage >= 60: rating = "Good"
  249. elif percentage >= 40: rating = "Acceptable"
  250. elif percentage >= 20: rating = "Poor"
  251. else: rating = "Failed"
  252. return {
  253. "format": format_score, "diversity": diversity_score, "hyde": hyde_score,
  254. "quality": quality_score, "entity": max(0, entity_score),
  255. "think_bonus": think_bonus, "total": max(0, total),
  256. "max_possible": max_possible, "percentage": round(percentage, 1),
  257. "rating": rating, "deductions": deductions,
  258. "entities_detected": list(entities) if entities else [],
  259. }
  260. # =============================================================================
  261. # Model loading and generation
  262. # =============================================================================
  263. def load_model(base, sft=None, grpo=None):
  264. print(f"Loading tokenizer from {base}...")
  265. tokenizer = AutoTokenizer.from_pretrained(base)
  266. if tokenizer.pad_token is None:
  267. tokenizer.pad_token = tokenizer.eos_token
  268. print(f"Loading base model {base}...")
  269. model = AutoModelForCausalLM.from_pretrained(
  270. base, torch_dtype=torch.bfloat16, device_map="auto",
  271. )
  272. if sft:
  273. print(f"Loading and merging SFT adapter {sft}...")
  274. model = PeftModel.from_pretrained(model, sft)
  275. model = model.merge_and_unload()
  276. if grpo:
  277. print(f"Loading GRPO adapter {grpo}...")
  278. model = PeftModel.from_pretrained(model, grpo)
  279. model.eval()
  280. return model, tokenizer
  281. def generate_expansion(model, tokenizer, query, max_new_tokens=200):
  282. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  283. prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  284. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  285. with torch.no_grad():
  286. outputs = model.generate(
  287. **inputs, max_new_tokens=max_new_tokens,
  288. temperature=0.7, do_sample=True,
  289. pad_token_id=tokenizer.pad_token_id,
  290. eos_token_id=tokenizer.eos_token_id,
  291. )
  292. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  293. if "\nassistant\n" in full_output:
  294. expansion = full_output.split("\nassistant\n")[-1].strip()
  295. elif "assistant\n" in full_output:
  296. expansion = full_output.split("assistant\n")[-1].strip()
  297. else:
  298. expansion = full_output[len(prompt):].strip()
  299. if "<think>" in expansion:
  300. expansion = re.sub(r'<think>.*?</think>', '', expansion, flags=re.DOTALL).strip()
  301. return expansion
  302. # =============================================================================
  303. # Main
  304. # =============================================================================
  305. def results_to_csv(results, label):
  306. """Convert eval results to CSV string."""
  307. buf = io.StringIO()
  308. writer = csv.writer(buf)
  309. writer.writerow([
  310. "model", "query", "expansion", "score_pct", "rating",
  311. "format", "diversity", "hyde", "quality", "entity", "think_bonus",
  312. "total", "max_possible", "deductions",
  313. ])
  314. for r in results:
  315. s = r["scores"]
  316. writer.writerow([
  317. label, r["query"], r["expansion"], s["percentage"], s["rating"],
  318. s["format"], s["diversity"], s["hyde"], s["quality"], s["entity"],
  319. s["think_bonus"], s["total"], s["max_possible"],
  320. "; ".join(s.get("deductions", [])),
  321. ])
  322. return buf.getvalue()
  323. def upload_csv(results, label, repo_id, api):
  324. """Upload eval results CSV to HuggingFace Hub."""
  325. csv_data = results_to_csv(results, label)
  326. tag = label.split("/")[-1].replace(" ", "_").lower()
  327. filename = f"eval_{tag}.csv"
  328. print(f" Uploading {filename} to {repo_id}...")
  329. api.upload_file(
  330. path_or_fileobj=csv_data.encode("utf-8"),
  331. path_in_repo=filename,
  332. repo_id=repo_id,
  333. repo_type="model",
  334. )
  335. print(f" Uploaded: https://huggingface.co/{repo_id}/blob/main/{filename}")
  336. def evaluate_model(model, tokenizer, label):
  337. print(f"\n{'='*70}")
  338. print(f" EVALUATING: {label}")
  339. print(f"{'='*70}")
  340. results = []
  341. for i, query in enumerate(QUERIES, 1):
  342. expansion = generate_expansion(model, tokenizer, query)
  343. scores = score_expansion_detailed(query, expansion)
  344. results.append({"query": query, "expansion": expansion, "scores": scores})
  345. marker = "+" if scores["percentage"] >= 80 else "-" if scores["percentage"] < 60 else "~"
  346. print(f" [{marker}] {i:2d}/{len(QUERIES)} {scores['percentage']:5.1f}% {scores['rating']:10s} {query}")
  347. avg = sum(r["scores"]["percentage"] for r in results) / len(results)
  348. ratings = Counter(r["scores"]["rating"] for r in results)
  349. print(f"\n {'─'*50}")
  350. print(f" Average score: {avg:.1f}%")
  351. print(f" Ratings:")
  352. for rating in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
  353. count = ratings.get(rating, 0)
  354. if count > 0:
  355. print(f" {rating:10s}: {count:2d} {'█' * count}")
  356. # Show worst queries
  357. worst = sorted(results, key=lambda r: r["scores"]["percentage"])[:5]
  358. print(f"\n Bottom 5:")
  359. for r in worst:
  360. print(f" {r['scores']['percentage']:5.1f}% {r['query']}")
  361. if r["scores"]["deductions"]:
  362. print(f" {', '.join(r['scores']['deductions'][:3])}")
  363. return results, avg
  364. def main():
  365. parser = argparse.ArgumentParser()
  366. parser.add_argument("--sft-only", action="store_true", help="Only evaluate SFT model")
  367. parser.add_argument("--upload-repo", default="tobil/qmd-query-expansion-evals",
  368. help="HF repo to upload CSV results")
  369. args = parser.parse_args()
  370. hf_token = os.environ.get("HF_TOKEN")
  371. if hf_token:
  372. login(token=hf_token)
  373. api = HfApi()
  374. api.create_repo(repo_id=args.upload_repo, repo_type="model", exist_ok=True)
  375. # Evaluate SFT
  376. model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL)
  377. sft_results, sft_avg = evaluate_model(model, tokenizer, f"SFT: {SFT_MODEL}")
  378. upload_csv(sft_results, "sft", args.upload_repo, api)
  379. if not args.sft_only:
  380. # For GRPO: reload base, merge SFT, then load GRPO adapter
  381. del model
  382. torch.cuda.empty_cache()
  383. model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)
  384. grpo_results, grpo_avg = evaluate_model(model, tokenizer, f"GRPO: {GRPO_MODEL}")
  385. upload_csv(grpo_results, "grpo", args.upload_repo, api)
  386. # Upload combined comparison CSV
  387. combined = results_to_csv(sft_results, "sft") + results_to_csv(grpo_results, "grpo").split("\n", 1)[1]
  388. api.upload_file(
  389. path_or_fileobj=combined.encode("utf-8"),
  390. path_in_repo="eval_comparison.csv",
  391. repo_id=args.upload_repo,
  392. repo_type="model",
  393. )
  394. print(f" Uploaded: eval_comparison.csv")
  395. # Comparison
  396. print(f"\n{'='*70}")
  397. print(f" COMPARISON")
  398. print(f"{'='*70}")
  399. print(f" SFT average: {sft_avg:.1f}%")
  400. print(f" GRPO average: {grpo_avg:.1f}%")
  401. print(f" Delta: {grpo_avg - sft_avg:+.1f}%")
  402. improved = sum(1 for s, g in zip(sft_results, grpo_results)
  403. if g["scores"]["percentage"] > s["scores"]["percentage"])
  404. regressed = sum(1 for s, g in zip(sft_results, grpo_results)
  405. if g["scores"]["percentage"] < s["scores"]["percentage"])
  406. print(f" Improved: {improved}/{len(QUERIES)}, Regressed: {regressed}/{len(QUERIES)}")
  407. if __name__ == "__main__":
  408. main()