eval.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "huggingface_hub",
  8. # "accelerate",
  9. # ]
  10. # ///
  11. """
  12. Evaluate a QMD query expansion model.
  13. Generates expansions from a model and scores them using the reward function.
  14. Replaces the old two-step evals/run.py + evals/score.py workflow.
  15. Usage:
  16. uv run eval.py --model tobil/qmd-query-expansion-1.7B-sft
  17. uv run eval.py --model tobil/qmd-query-expansion-1.7B-grpo \
  18. --sft-model tobil/qmd-query-expansion-1.7B-sft \
  19. --base-model Qwen/Qwen3-1.7B
  20. uv run eval.py --model ./local-checkpoint --verbose
  21. uv run eval.py --score-only results.jsonl
  22. """
  23. import argparse
  24. import json
  25. import os
  26. import sys
  27. from collections import Counter
  28. from datetime import datetime
  29. from pathlib import Path
  30. # Import reward module
  31. sys.path.insert(0, os.path.dirname(__file__))
  32. from reward import score_expansion_detailed
  33. def load_queries(path: str) -> list[str]:
  34. """Load queries from file, one per line, ignoring comments and blanks."""
  35. queries = []
  36. with open(path) as f:
  37. for line in f:
  38. line = line.strip()
  39. if line and not line.startswith("#"):
  40. queries.append(line)
  41. return queries
  42. def load_model(model_path: str, base_model: str = None, sft_model: str = None):
  43. """Load model with optional SFT stacking for GRPO models."""
  44. import torch
  45. from peft import PeftModel
  46. from transformers import AutoModelForCausalLM, AutoTokenizer
  47. is_local = Path(model_path).exists()
  48. # Auto-detect adapter vs full model
  49. is_adapter = True
  50. if is_local:
  51. is_adapter = (Path(model_path) / "adapter_config.json").exists()
  52. if is_adapter and not base_model:
  53. with open(Path(model_path) / "adapter_config.json") as f:
  54. config = json.load(f)
  55. base_model = config.get("base_model_name_or_path", "Qwen/Qwen3-1.7B")
  56. if not base_model:
  57. base_model = "Qwen/Qwen3-1.7B"
  58. print(f"Loading tokenizer from {base_model}...", file=sys.stderr)
  59. tokenizer = AutoTokenizer.from_pretrained(base_model)
  60. if tokenizer.pad_token is None:
  61. tokenizer.pad_token = tokenizer.eos_token
  62. print(f"Loading base model {base_model}...", file=sys.stderr)
  63. model = AutoModelForCausalLM.from_pretrained(
  64. base_model, torch_dtype=torch.bfloat16, device_map="auto",
  65. )
  66. if sft_model:
  67. print(f"Loading and merging SFT adapter {sft_model}...", file=sys.stderr)
  68. model = PeftModel.from_pretrained(model, sft_model)
  69. model = model.merge_and_unload()
  70. if is_adapter:
  71. print(f"Loading adapter from {model_path}...", file=sys.stderr)
  72. model = PeftModel.from_pretrained(model, model_path)
  73. model.eval()
  74. return model, tokenizer
  75. def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200) -> str:
  76. """Generate a query expansion using Qwen3 chat template with /no_think."""
  77. import torch
  78. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  79. prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  80. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  81. with torch.no_grad():
  82. outputs = model.generate(
  83. **inputs,
  84. max_new_tokens=max_new_tokens,
  85. temperature=0.7,
  86. do_sample=True,
  87. pad_token_id=tokenizer.pad_token_id,
  88. eos_token_id=tokenizer.eos_token_id,
  89. )
  90. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  91. # Extract assistant response
  92. if "\nassistant\n" in full_output:
  93. expansion = full_output.split("\nassistant\n")[-1].strip()
  94. elif "assistant\n" in full_output:
  95. expansion = full_output.split("assistant\n")[-1].strip()
  96. else:
  97. expansion = full_output[len(prompt):].strip()
  98. # Strip leftover <think> blocks
  99. import re
  100. if "<think>" in expansion:
  101. expansion = re.sub(r'<think>.*?</think>', '', expansion, flags=re.DOTALL).strip()
  102. return expansion
  103. def print_result(query: str, expansion: str, scores: dict, verbose: bool = False):
  104. """Print a single scored result."""
  105. print(f"\n{'='*60}")
  106. print(f"Query: {query}")
  107. print(f"{'~'*60}")
  108. print(expansion)
  109. print(f"{'~'*60}")
  110. print(f"Score: {scores['percentage']:.0f}% ({scores['rating']})")
  111. print(f" Format: {scores['format']}/30 Diversity: {scores['diversity']}/30 "
  112. f"Hyde: {scores['hyde']}/20 Quality: {scores['quality']}/20 "
  113. f"Entity: {scores['entity']}/20 Think: {scores['think_bonus']}/20")
  114. if verbose and scores["deductions"]:
  115. print(f" Issues: {', '.join(scores['deductions'][:5])}")
  116. if verbose and scores["entities_detected"]:
  117. print(f" Entities: {scores['entities_detected']}")
  118. def print_summary(scored_results: list):
  119. """Print aggregate summary."""
  120. print(f"\n{'='*60}")
  121. print("SUMMARY")
  122. print(f"{'='*60}")
  123. avg_score = sum(r["scores"]["percentage"] for r in scored_results) / len(scored_results)
  124. ratings = Counter(r["scores"]["rating"] for r in scored_results)
  125. print(f" Total queries: {len(scored_results)}")
  126. print(f" Average score: {avg_score:.1f}%")
  127. print(f" Ratings:")
  128. for rating in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
  129. count = ratings.get(rating, 0)
  130. print(f" {rating:10s}: {count:2d} {'#' * count}")
  131. def cmd_generate_and_score(args):
  132. """Generate expansions from a model and score them."""
  133. queries = load_queries(args.queries)
  134. print(f"Loaded {len(queries)} queries from {args.queries}", file=sys.stderr)
  135. model, tokenizer = load_model(args.model, args.base_model, args.sft_model)
  136. scored_results = []
  137. for i, query in enumerate(queries, 1):
  138. print(f"[{i}/{len(queries)}] {query}", file=sys.stderr)
  139. expansion = generate_expansion(model, tokenizer, query, args.max_tokens)
  140. scores = score_expansion_detailed(query, expansion)
  141. if not args.summary_only:
  142. print_result(query, expansion, scores, args.verbose)
  143. scored_results.append({
  144. "query": query,
  145. "expansion": expansion,
  146. "scores": {k: v for k, v in scores.items() if k not in ("parsed", "deductions", "entities_detected")},
  147. "deductions": scores["deductions"],
  148. "entities_detected": scores["entities_detected"],
  149. })
  150. print_summary(scored_results)
  151. if args.output:
  152. output_data = {
  153. "metadata": {"model": args.model, "timestamp": datetime.now().isoformat()},
  154. "summary": {
  155. "total": len(scored_results),
  156. "average_score": round(sum(r["scores"]["percentage"] for r in scored_results) / len(scored_results), 1),
  157. },
  158. "results": scored_results,
  159. }
  160. with open(args.output, "w") as f:
  161. json.dump(output_data, f, indent=2)
  162. print(f"\nScores saved to: {args.output}")
  163. def cmd_score_only(args):
  164. """Score an existing JSONL file (from the old run.py format)."""
  165. results = []
  166. with open(args.input) as f:
  167. for line in f:
  168. data = json.loads(line)
  169. if not data.get("_meta"):
  170. results.append(data)
  171. scored_results = []
  172. for result in results:
  173. query = result["query"]
  174. expansion = result["expansion"]
  175. scores = score_expansion_detailed(query, expansion)
  176. if not args.summary_only:
  177. print_result(query, expansion, scores, args.verbose)
  178. scored_results.append({
  179. "query": query,
  180. "expansion": expansion,
  181. "scores": {k: v for k, v in scores.items() if k not in ("parsed", "deductions", "entities_detected")},
  182. "deductions": scores["deductions"],
  183. "entities_detected": scores["entities_detected"],
  184. })
  185. print_summary(scored_results)
  186. def main():
  187. parser = argparse.ArgumentParser(
  188. description="Evaluate QMD query expansion models",
  189. formatter_class=argparse.RawDescriptionHelpFormatter,
  190. epilog="""
  191. Examples:
  192. uv run eval.py --model tobil/qmd-query-expansion-1.7B-sft
  193. uv run eval.py --model tobil/qmd-query-expansion-1.7B-grpo \\
  194. --sft-model tobil/qmd-query-expansion-1.7B-sft
  195. uv run eval.py --score-only evals/results.jsonl
  196. """,
  197. )
  198. # Model evaluation mode
  199. parser.add_argument("--model", help="Model path (HF Hub or local)")
  200. parser.add_argument("--base-model", default=None, help="Base model for tokenizer (default: Qwen/Qwen3-1.7B)")
  201. parser.add_argument("--sft-model", default=None, help="SFT adapter to merge first (for GRPO models)")
  202. parser.add_argument("--queries", default="evals/queries.txt", help="Queries file")
  203. parser.add_argument("--max-tokens", type=int, default=200, help="Max tokens per generation")
  204. # Score-only mode
  205. parser.add_argument("--score-only", metavar="JSONL", help="Score existing JSONL file instead of generating")
  206. # Output options
  207. parser.add_argument("--output", "-o", help="Save detailed scores to JSON file")
  208. parser.add_argument("--verbose", "-v", action="store_true")
  209. parser.add_argument("--summary-only", action="store_true")
  210. args = parser.parse_args()
  211. if args.score_only:
  212. args.input = args.score_only
  213. cmd_score_only(args)
  214. elif args.model:
  215. cmd_generate_and_score(args)
  216. else:
  217. parser.error("Either --model or --score-only is required")
  218. if __name__ == "__main__":
  219. main()