eval_retrieval.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "accelerate",
  8. # ]
  9. # ///
  10. """
  11. QMD Retrieval-Based Evaluation with Precision & Recall
  12. Evaluates model outputs against golden data (training set).
  13. Measures how well the model reproduces the expected expansions.
  14. Metrics:
  15. - Precision: Of model-generated expansions, how many match golden?
  16. - Recall: Of golden expansions, how many did the model generate?
  17. - F1: Harmonic mean of precision and recall
  18. Matching is done via token overlap (Jaccard similarity) with a threshold.
  19. Usage:
  20. uv run eval_retrieval.py ./outputs/sft
  21. uv run eval_retrieval.py tobil/qmd-query-expansion-1.7B --golden data/qmd_expansion_v3_structured.jsonl
  22. uv run eval_retrieval.py ./outputs/sft --threshold 0.5 --sample 100
  23. """
  24. import argparse
  25. import json
  26. import random
  27. import re
  28. import sys
  29. from collections import defaultdict
  30. from pathlib import Path
  31. # =============================================================================
  32. # Matching Functions
  33. # =============================================================================
  34. def tokenize(text: str) -> set[str]:
  35. """Tokenize text into lowercase word set, removing stopwords."""
  36. stopwords = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and',
  37. 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
  38. 'how', 'what', 'do', 'does', 'can', 'you', 'your', 'i'}
  39. words = re.findall(r'\b\w+\b', text.lower())
  40. return {w for w in words if w not in stopwords and len(w) > 1}
  41. def jaccard_similarity(a: str, b: str) -> float:
  42. """Jaccard similarity between two strings based on token overlap."""
  43. tokens_a = tokenize(a)
  44. tokens_b = tokenize(b)
  45. if not tokens_a or not tokens_b:
  46. return 0.0
  47. intersection = len(tokens_a & tokens_b)
  48. union = len(tokens_a | tokens_b)
  49. return intersection / union if union > 0 else 0.0
  50. def find_best_match(pred: str, golden_list: list[str], threshold: float) -> tuple[str | None, float]:
  51. """Find best matching golden expansion for a prediction."""
  52. best_match = None
  53. best_score = 0.0
  54. for golden in golden_list:
  55. score = jaccard_similarity(pred, golden)
  56. if score > best_score:
  57. best_score = score
  58. best_match = golden
  59. if best_score >= threshold:
  60. return best_match, best_score
  61. return None, best_score
  62. # =============================================================================
  63. # Parsing
  64. # =============================================================================
  65. def parse_model_output(text: str) -> dict[str, list[str]]:
  66. """Parse model output into {lex: [...], vec: [...], hyde: [...]}."""
  67. # Clean thinking tags
  68. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
  69. text = text.replace('<|im_end|>', '').strip()
  70. result = {"lex": [], "vec": [], "hyde": []}
  71. for line in text.strip().split("\n"):
  72. line = line.strip()
  73. if not line:
  74. continue
  75. if line.startswith("lex:"):
  76. result["lex"].append(line[4:].strip())
  77. elif line.startswith("vec:"):
  78. result["vec"].append(line[4:].strip())
  79. elif line.startswith("hyde:"):
  80. result["hyde"].append(line[5:].strip())
  81. return result
  82. def parse_golden_data(searches: list[dict] | str) -> dict[str, list[str]]:
  83. """Parse golden data format into {lex: [...], vec: [...], hyde: [...]}."""
  84. # If it's a string (from messages format), parse it
  85. if isinstance(searches, str):
  86. return parse_model_output(searches)
  87. # Otherwise it's the structured format [{type, query}, ...]
  88. result = {"lex": [], "vec": [], "hyde": []}
  89. for item in searches:
  90. exp_type = item.get("type", "")
  91. value = item.get("query", "") or item.get("value", "")
  92. if exp_type in result:
  93. result[exp_type].append(value)
  94. return result
  95. def load_golden_data(filepath: Path) -> list[dict]:
  96. """Load golden data from JSONL, supporting both structured and messages formats."""
  97. data = []
  98. with open(filepath) as f:
  99. for line in f:
  100. if not line.strip():
  101. continue
  102. item = json.loads(line)
  103. # Structured format: {query, searches}
  104. if "query" in item and "searches" in item:
  105. data.append({
  106. "query": item["query"],
  107. "searches": item["searches"]
  108. })
  109. # Messages format: {messages: [{role, content}, ...]}
  110. elif "messages" in item:
  111. messages = item["messages"]
  112. query = None
  113. searches = None
  114. for msg in messages:
  115. if msg["role"] == "user":
  116. # Extract query from "/no_think Expand this search query: ..."
  117. content = msg["content"]
  118. if "Expand this search query:" in content:
  119. query = content.split("Expand this search query:")[-1].strip()
  120. else:
  121. query = content.strip()
  122. elif msg["role"] == "assistant":
  123. # The assistant content IS the expected output
  124. searches = msg["content"]
  125. if query and searches:
  126. data.append({
  127. "query": query,
  128. "searches": searches # Will be parsed as string
  129. })
  130. return data
  131. # =============================================================================
  132. # Metrics Calculation
  133. # =============================================================================
  134. # Different thresholds by type - lex needs strict matching, hyde is more flexible
  135. DEFAULT_THRESHOLDS = {
  136. "lex": 0.5, # Keywords should overlap well
  137. "vec": 0.35, # Semantic sentences have more variation
  138. "hyde": 0.25, # Passages have the most variation
  139. }
  140. def calculate_metrics(
  141. predictions: dict[str, list[str]],
  142. golden: dict[str, list[str]],
  143. threshold: float | dict[str, float] = 0.4,
  144. return_mismatches: bool = False
  145. ) -> dict:
  146. """Calculate precision, recall, F1 per type and overall.
  147. Args:
  148. threshold: Either a single float, or dict mapping type -> threshold
  149. return_mismatches: If True, include lists of unmatched predictions/golden
  150. """
  151. if isinstance(threshold, (int, float)):
  152. thresholds = {"lex": threshold, "vec": threshold, "hyde": threshold}
  153. else:
  154. thresholds = threshold
  155. metrics = {}
  156. mismatches = {}
  157. total_tp = 0
  158. total_pred = 0
  159. total_golden = 0
  160. for exp_type in ["lex", "vec", "hyde"]:
  161. preds = predictions.get(exp_type, [])
  162. golds = golden.get(exp_type, [])
  163. type_threshold = thresholds.get(exp_type, 0.4)
  164. if not preds and not golds:
  165. continue
  166. # Track which golden items were matched
  167. matched_golden = set()
  168. unmatched_preds = []
  169. tp = 0
  170. for pred in preds:
  171. match, score = find_best_match(pred, golds, type_threshold)
  172. if match is not None:
  173. tp += 1
  174. matched_golden.add(match)
  175. else:
  176. unmatched_preds.append((pred, score))
  177. unmatched_golden = [g for g in golds if g not in matched_golden]
  178. precision = tp / len(preds) if preds else 0.0
  179. recall = len(matched_golden) / len(golds) if golds else 0.0
  180. f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
  181. metrics[exp_type] = {
  182. "precision": precision,
  183. "recall": recall,
  184. "f1": f1,
  185. "pred_count": len(preds),
  186. "golden_count": len(golds),
  187. "matched": tp,
  188. }
  189. if return_mismatches:
  190. mismatches[exp_type] = {
  191. "unmatched_preds": unmatched_preds,
  192. "unmatched_golden": unmatched_golden,
  193. }
  194. total_tp += tp
  195. total_pred += len(preds)
  196. total_golden += len(golds)
  197. # Overall metrics (micro-averaged)
  198. overall_precision = total_tp / total_pred if total_pred > 0 else 0.0
  199. overall_recall = total_tp / total_golden if total_golden > 0 else 0.0
  200. overall_f1 = 2 * overall_precision * overall_recall / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0.0
  201. metrics["overall"] = {
  202. "precision": overall_precision,
  203. "recall": overall_recall,
  204. "f1": overall_f1,
  205. "pred_count": total_pred,
  206. "golden_count": total_golden,
  207. "matched": total_tp,
  208. }
  209. if return_mismatches:
  210. metrics["_mismatches"] = mismatches
  211. return metrics
  212. # =============================================================================
  213. # Model Loading and Generation
  214. # =============================================================================
  215. def load_model(model_path: str):
  216. """Load model (adapter or merged)."""
  217. import torch
  218. from peft import PeftModel
  219. from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
  220. model_path = Path(model_path)
  221. adapter_config = model_path / "adapter_config.json"
  222. # Get base model from adapter config or default
  223. base_model = "Qwen/Qwen3-1.7B"
  224. if adapter_config.exists():
  225. with open(adapter_config) as f:
  226. cfg = json.load(f)
  227. base_model = cfg.get("base_model_name_or_path", base_model)
  228. print(f"Loading base: {base_model}", file=sys.stderr)
  229. tokenizer = AutoTokenizer.from_pretrained(base_model)
  230. if tokenizer.pad_token is None:
  231. tokenizer.pad_token = tokenizer.eos_token
  232. tokenizer.padding_side = "left"
  233. config = AutoConfig.from_pretrained(base_model)
  234. config.tie_word_embeddings = False
  235. model = AutoModelForCausalLM.from_pretrained(
  236. base_model, dtype=torch.bfloat16, device_map={"": 0}, config=config
  237. )
  238. if model.generation_config is not None:
  239. model.generation_config.do_sample = False
  240. model.generation_config.temperature = None
  241. model.generation_config.top_p = None
  242. model.generation_config.top_k = None
  243. # Load adapter if present
  244. if adapter_config.exists():
  245. print(f"Loading adapter: {model_path}", file=sys.stderr)
  246. model = PeftModel.from_pretrained(model, str(model_path))
  247. model.eval()
  248. return model, tokenizer
  249. def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 400) -> str:
  250. """Generate expansion for a single query."""
  251. import torch
  252. prompt = tokenizer.apply_chat_template(
  253. [{"role": "user", "content": f"/no_think Expand this search query: {query}"}],
  254. tokenize=False,
  255. add_generation_prompt=True,
  256. )
  257. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  258. input_len = inputs["input_ids"].shape[1]
  259. with torch.inference_mode():
  260. out = model.generate(
  261. **inputs,
  262. max_new_tokens=max_new_tokens,
  263. do_sample=False,
  264. num_beams=1,
  265. pad_token_id=tokenizer.pad_token_id,
  266. eos_token_id=tokenizer.eos_token_id,
  267. use_cache=True,
  268. )
  269. gen_tokens = out[0][input_len:]
  270. return tokenizer.decode(gen_tokens, skip_special_tokens=True)
  271. # =============================================================================
  272. # Main Evaluation
  273. # =============================================================================
  274. def main():
  275. parser = argparse.ArgumentParser(description="QMD Retrieval-Based Evaluation")
  276. parser.add_argument("model", help="Model path (local or HF)")
  277. parser.add_argument("--golden", default="data/qmd_expansion_v3_structured.jsonl",
  278. help="Golden data JSONL file")
  279. parser.add_argument("--threshold", type=float, default=None,
  280. help="Jaccard similarity threshold for all types (overrides --type-thresholds)")
  281. parser.add_argument("--type-thresholds", action="store_true",
  282. help="Use type-specific thresholds (lex=0.5, vec=0.35, hyde=0.25)")
  283. parser.add_argument("--sample", type=int, default=0,
  284. help="Sample N queries (0 = all)")
  285. parser.add_argument("--seed", type=int, default=42,
  286. help="Random seed for sampling")
  287. parser.add_argument("--max-new-tokens", type=int, default=400,
  288. help="Max new tokens to generate")
  289. parser.add_argument("--verbose", "-v", action="store_true",
  290. help="Show per-query details")
  291. parser.add_argument("--show-mismatches", action="store_true",
  292. help="Show examples of mismatched predictions")
  293. args = parser.parse_args()
  294. # Determine thresholds
  295. if args.threshold is not None:
  296. thresholds = args.threshold
  297. elif args.type_thresholds:
  298. thresholds = DEFAULT_THRESHOLDS.copy()
  299. else:
  300. thresholds = 0.4 # Default single threshold
  301. # Load golden data
  302. golden_path = Path(args.golden)
  303. if not golden_path.exists():
  304. # Try relative to script directory
  305. golden_path = Path(__file__).parent / args.golden
  306. if not golden_path.exists():
  307. print(f"Error: Golden data file not found: {args.golden}", file=sys.stderr)
  308. sys.exit(1)
  309. print(f"Loading golden data from {golden_path}...", file=sys.stderr)
  310. golden_data = load_golden_data(golden_path)
  311. print(f"Loaded {len(golden_data)} golden examples", file=sys.stderr)
  312. # Sample if requested
  313. if args.sample > 0 and args.sample < len(golden_data):
  314. random.seed(args.seed)
  315. golden_data = random.sample(golden_data, args.sample)
  316. print(f"Sampled {len(golden_data)} examples", file=sys.stderr)
  317. # Load model
  318. model, tokenizer = load_model(args.model)
  319. # Evaluate
  320. all_metrics = []
  321. all_mismatches = []
  322. type_aggregates = defaultdict(lambda: {"precision": [], "recall": [], "f1": []})
  323. threshold_desc = thresholds if isinstance(thresholds, (int, float)) else f"lex={thresholds['lex']}, vec={thresholds['vec']}, hyde={thresholds['hyde']}"
  324. print(f"\nEvaluating {len(golden_data)} queries (thresholds: {threshold_desc})...\n")
  325. for i, item in enumerate(golden_data, 1):
  326. query = item["query"]
  327. golden_parsed = parse_golden_data(item["searches"])
  328. # Generate model output
  329. output = generate_expansion(model, tokenizer, query, args.max_new_tokens)
  330. pred_parsed = parse_model_output(output)
  331. # Calculate metrics
  332. metrics = calculate_metrics(pred_parsed, golden_parsed, thresholds, return_mismatches=args.show_mismatches)
  333. all_metrics.append({"query": query, "metrics": metrics, "pred": pred_parsed, "golden": golden_parsed})
  334. if args.show_mismatches and "_mismatches" in metrics:
  335. all_mismatches.append({"query": query, "mismatches": metrics.pop("_mismatches")})
  336. # Aggregate by type
  337. for exp_type in ["lex", "vec", "hyde", "overall"]:
  338. if exp_type in metrics:
  339. type_aggregates[exp_type]["precision"].append(metrics[exp_type]["precision"])
  340. type_aggregates[exp_type]["recall"].append(metrics[exp_type]["recall"])
  341. type_aggregates[exp_type]["f1"].append(metrics[exp_type]["f1"])
  342. # Progress
  343. overall = metrics.get("overall", {})
  344. p = overall.get("precision", 0) * 100
  345. r = overall.get("recall", 0) * 100
  346. f = overall.get("f1", 0) * 100
  347. if args.verbose:
  348. print(f"[{i:3d}/{len(golden_data)}] P={p:5.1f}% R={r:5.1f}% F1={f:5.1f}% {query[:50]}")
  349. elif i % 50 == 0 or i == len(golden_data):
  350. print(f" Processed {i}/{len(golden_data)}...", file=sys.stderr)
  351. # Summary
  352. print(f"\n{'='*60}")
  353. print(f"RESULTS: {args.model}")
  354. print(f"{'='*60}")
  355. print(f"Threshold: {args.threshold} | Samples: {len(golden_data)}")
  356. print()
  357. print(f"{'Type':<10} {'Precision':>10} {'Recall':>10} {'F1':>10}")
  358. print("-" * 42)
  359. for exp_type in ["lex", "vec", "hyde", "overall"]:
  360. if exp_type in type_aggregates:
  361. agg = type_aggregates[exp_type]
  362. avg_p = sum(agg["precision"]) / len(agg["precision"]) * 100 if agg["precision"] else 0
  363. avg_r = sum(agg["recall"]) / len(agg["recall"]) * 100 if agg["recall"] else 0
  364. avg_f = sum(agg["f1"]) / len(agg["f1"]) * 100 if agg["f1"] else 0
  365. label = exp_type.upper() if exp_type != "overall" else "OVERALL"
  366. print(f"{label:<10} {avg_p:>9.1f}% {avg_r:>9.1f}% {avg_f:>9.1f}%")
  367. print(f"{'='*60}")
  368. # Show worst examples
  369. print("\nBottom 5 by F1:")
  370. sorted_by_f1 = sorted(all_metrics, key=lambda x: x["metrics"].get("overall", {}).get("f1", 0))
  371. for item in sorted_by_f1[:5]:
  372. f1 = item["metrics"].get("overall", {}).get("f1", 0) * 100
  373. print(f" {f1:5.1f}% {item['query'][:60]}")
  374. # Show mismatches if requested
  375. if args.show_mismatches and all_mismatches:
  376. print(f"\n{'='*60}")
  377. print("MISMATCH EXAMPLES")
  378. print(f"{'='*60}")
  379. # Group by type and show up to 3 examples per type
  380. for exp_type in ["lex", "vec", "hyde"]:
  381. type_mismatches = []
  382. for item in all_mismatches:
  383. if exp_type in item["mismatches"]:
  384. mm = item["mismatches"][exp_type]
  385. if mm["unmatched_preds"] or mm["unmatched_golden"]:
  386. type_mismatches.append({
  387. "query": item["query"],
  388. **mm
  389. })
  390. if type_mismatches:
  391. print(f"\n--- {exp_type.upper()} mismatches ({len(type_mismatches)} queries) ---")
  392. for example in type_mismatches[:3]:
  393. print(f"\nQuery: {example['query'][:60]}")
  394. if example["unmatched_preds"]:
  395. print(f" Unmatched predictions:")
  396. for pred, score in example["unmatched_preds"][:2]:
  397. print(f" - [{score:.2f}] {pred[:80]}{'...' if len(pred) > 80 else ''}")
  398. if example["unmatched_golden"]:
  399. print(f" Missing golden:")
  400. for g in example["unmatched_golden"][:2]:
  401. print(f" - {g[:80]}{'...' if len(g) > 80 else ''}")
  402. return 0
  403. if __name__ == "__main__":
  404. sys.exit(main())