eval.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "huggingface_hub",
  8. # "accelerate",
  9. # ]
  10. # ///
  11. """
  12. Evaluate a QMD query expansion model.
  13. Generates expansions from a model and scores them using the reward function.
  14. Replaces the old two-step evals/run.py + evals/score.py workflow.
  15. Usage:
  16. uv run eval.py --model tobil/qmd-query-expansion-1.7B-sft
  17. uv run eval.py --model tobil/qmd-query-expansion-1.7B-grpo \
  18. --sft-model tobil/qmd-query-expansion-1.7B-sft \
  19. --base-model Qwen/Qwen3-1.7B
  20. uv run eval.py --model ./local-checkpoint --verbose
  21. uv run eval.py --score-only results.jsonl
  22. """
  23. import argparse
  24. import json
  25. import os
  26. import sys
  27. from collections import Counter
  28. from datetime import datetime
  29. from pathlib import Path
  30. # Import reward module
  31. sys.path.insert(0, os.path.dirname(__file__))
  32. from reward import score_expansion_detailed
  33. def load_queries(path: str) -> list[str]:
  34. """Load queries from file, one per line, ignoring comments and blanks."""
  35. queries = []
  36. with open(path) as f:
  37. for line in f:
  38. line = line.strip()
  39. if line and not line.startswith("#"):
  40. queries.append(line)
  41. return queries
  42. def load_model(model_path: str, base_model: str = None, sft_model: str = None):
  43. """Load model with optional SFT stacking for GRPO models."""
  44. import torch
  45. from peft import PeftModel
  46. from transformers import AutoModelForCausalLM, AutoTokenizer
  47. is_local = Path(model_path).exists()
  48. # Auto-detect adapter vs full model
  49. is_adapter = True
  50. if is_local:
  51. is_adapter = (Path(model_path) / "adapter_config.json").exists()
  52. if is_adapter and not base_model:
  53. with open(Path(model_path) / "adapter_config.json") as f:
  54. config = json.load(f)
  55. base_model = config.get("base_model_name_or_path", "Qwen/Qwen3-1.7B")
  56. if not base_model:
  57. base_model = "Qwen/Qwen3-1.7B"
  58. print(f"Loading tokenizer from {base_model}...", file=sys.stderr)
  59. tokenizer = AutoTokenizer.from_pretrained(base_model)
  60. if tokenizer.pad_token is None:
  61. tokenizer.pad_token = tokenizer.eos_token
  62. print(f"Loading base model {base_model}...", file=sys.stderr)
  63. model = AutoModelForCausalLM.from_pretrained(
  64. base_model,
  65. torch_dtype=torch.float32,
  66. device_map="auto",
  67. )
  68. if sft_model:
  69. print(f"Loading and merging SFT adapter {sft_model}...", file=sys.stderr)
  70. model = PeftModel.from_pretrained(model, sft_model)
  71. model = model.merge_and_unload()
  72. if is_adapter:
  73. print(f"Loading adapter from {model_path}...", file=sys.stderr)
  74. model = PeftModel.from_pretrained(model, model_path)
  75. model.eval()
  76. return model, tokenizer
  77. def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200) -> str:
  78. """Generate a query expansion using Qwen3 chat template with /no_think."""
  79. import torch
  80. messages = [
  81. {"role": "user", "content": f"/no_think Expand this search query: {query}"}
  82. ]
  83. prompt = tokenizer.apply_chat_template(
  84. messages, tokenize=False, add_generation_prompt=True
  85. )
  86. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  87. with torch.no_grad():
  88. outputs = model.generate(
  89. **inputs,
  90. max_new_tokens=max_new_tokens,
  91. temperature=0.1,
  92. do_sample=True,
  93. top_p=0.9,
  94. pad_token_id=tokenizer.pad_token_id,
  95. eos_token_id=tokenizer.eos_token_id,
  96. )
  97. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  98. # Extract assistant response
  99. if "\nassistant\n" in full_output:
  100. expansion = full_output.split("\nassistant\n")[-1].strip()
  101. elif "assistant\n" in full_output:
  102. expansion = full_output.split("assistant\n")[-1].strip()
  103. else:
  104. expansion = full_output[len(prompt) :].strip()
  105. # Strip leftover <think> blocks
  106. import re
  107. if "<think>" in expansion:
  108. expansion = re.sub(
  109. r"<think>.*?</think>", "", expansion, flags=re.DOTALL
  110. ).strip()
  111. return expansion
  112. def print_result(query: str, expansion: str, scores: dict, verbose: bool = False):
  113. """Print a single scored result."""
  114. print(f"\n{'=' * 60}")
  115. print(f"Query: {query}")
  116. print(f"{'~' * 60}")
  117. print(expansion)
  118. print(f"{'~' * 60}")
  119. print(f"Score: {scores['percentage']:.0f}% ({scores['rating']})")
  120. print(
  121. f" Format: {scores['format']}/30 Diversity: {scores['diversity']}/30 "
  122. f"Hyde: {scores['hyde']}/20 Quality: {scores['quality']}/20 "
  123. f"Entity: {scores['entity']}/20 Think: {scores['think_bonus']}/20"
  124. )
  125. if verbose and scores["deductions"]:
  126. print(f" Issues: {', '.join(scores['deductions'][:5])}")
  127. if verbose and scores["entities_detected"]:
  128. print(f" Entities: {scores['entities_detected']}")
  129. def print_summary(scored_results: list):
  130. """Print aggregate summary."""
  131. print(f"\n{'=' * 60}")
  132. print("SUMMARY")
  133. print(f"{'=' * 60}")
  134. avg_score = sum(r["scores"]["percentage"] for r in scored_results) / len(
  135. scored_results
  136. )
  137. ratings = Counter(r["scores"]["rating"] for r in scored_results)
  138. print(f" Total queries: {len(scored_results)}")
  139. print(f" Average score: {avg_score:.1f}%")
  140. print(f" Ratings:")
  141. for rating in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
  142. count = ratings.get(rating, 0)
  143. print(f" {rating:10s}: {count:2d} {'#' * count}")
  144. def cmd_generate_and_score(args):
  145. """Generate expansions from a model and score them."""
  146. queries = load_queries(args.queries)
  147. print(f"Loaded {len(queries)} queries from {args.queries}", file=sys.stderr)
  148. model, tokenizer = load_model(args.model, args.base_model, args.sft_model)
  149. scored_results = []
  150. for i, query in enumerate(queries, 1):
  151. print(f"[{i}/{len(queries)}] {query}", file=sys.stderr)
  152. expansion = generate_expansion(model, tokenizer, query, args.max_tokens)
  153. scores = score_expansion_detailed(query, expansion)
  154. if not args.summary_only:
  155. print_result(query, expansion, scores, args.verbose)
  156. scored_results.append(
  157. {
  158. "query": query,
  159. "expansion": expansion,
  160. "scores": {
  161. k: v
  162. for k, v in scores.items()
  163. if k not in ("parsed", "deductions", "entities_detected")
  164. },
  165. "deductions": scores["deductions"],
  166. "entities_detected": scores["entities_detected"],
  167. }
  168. )
  169. print_summary(scored_results)
  170. if args.output:
  171. output_data = {
  172. "metadata": {"model": args.model, "timestamp": datetime.now().isoformat()},
  173. "summary": {
  174. "total": len(scored_results),
  175. "average_score": round(
  176. sum(r["scores"]["percentage"] for r in scored_results)
  177. / len(scored_results),
  178. 1,
  179. ),
  180. },
  181. "results": scored_results,
  182. }
  183. with open(args.output, "w") as f:
  184. json.dump(output_data, f, indent=2)
  185. print(f"\nScores saved to: {args.output}")
  186. def cmd_score_only(args):
  187. """Score an existing JSONL file (from the old run.py format)."""
  188. results = []
  189. with open(args.input) as f:
  190. for line in f:
  191. data = json.loads(line)
  192. if not data.get("_meta"):
  193. results.append(data)
  194. scored_results = []
  195. for result in results:
  196. query = result["query"]
  197. expansion = result["expansion"]
  198. scores = score_expansion_detailed(query, expansion)
  199. if not args.summary_only:
  200. print_result(query, expansion, scores, args.verbose)
  201. scored_results.append(
  202. {
  203. "query": query,
  204. "expansion": expansion,
  205. "scores": {
  206. k: v
  207. for k, v in scores.items()
  208. if k not in ("parsed", "deductions", "entities_detected")
  209. },
  210. "deductions": scores["deductions"],
  211. "entities_detected": scores["entities_detected"],
  212. }
  213. )
  214. print_summary(scored_results)
  215. def main():
  216. parser = argparse.ArgumentParser(
  217. description="Evaluate QMD query expansion models",
  218. formatter_class=argparse.RawDescriptionHelpFormatter,
  219. epilog="""
  220. Examples:
  221. uv run eval.py --model tobil/qmd-query-expansion-1.7B-sft
  222. uv run eval.py --model tobil/qmd-query-expansion-1.7B-grpo \\
  223. --sft-model tobil/qmd-query-expansion-1.7B-sft
  224. uv run eval.py --score-only evals/results.jsonl
  225. """,
  226. )
  227. # Model evaluation mode
  228. parser.add_argument("--model", help="Model path (HF Hub or local)")
  229. parser.add_argument(
  230. "--base-model",
  231. default=None,
  232. help="Base model for tokenizer (default: Qwen/Qwen3-1.7B)",
  233. )
  234. parser.add_argument(
  235. "--sft-model", default=None, help="SFT adapter to merge first (for GRPO models)"
  236. )
  237. parser.add_argument("--queries", default="evals/queries.txt", help="Queries file")
  238. parser.add_argument(
  239. "--max-tokens", type=int, default=200, help="Max tokens per generation"
  240. )
  241. # Score-only mode
  242. parser.add_argument(
  243. "--score-only",
  244. metavar="JSONL",
  245. help="Score existing JSONL file instead of generating",
  246. )
  247. # Output options
  248. parser.add_argument("--output", "-o", help="Save detailed scores to JSON file")
  249. parser.add_argument("--verbose", "-v", action="store_true")
  250. parser.add_argument("--summary-only", action="store_true")
  251. args = parser.parse_args()
  252. if args.score_only:
  253. args.input = args.score_only
  254. cmd_score_only(args)
  255. elif args.model:
  256. cmd_generate_and_score(args)
  257. else:
  258. parser.error("Either --model or --score-only is required")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()