eval.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "accelerate",
  8. # ]
  9. # ///
  10. """
  11. Minimal QMD query expansion evaluator.
  12. Usage:
  13. uv run eval.py ./outputs/sft
  14. uv run eval.py ./outputs/sft --queries evals/queries.txt
  15. The query file defaults to evals/queries.txt; all queries are run unless --max-queries is set.
  16. """
  17. import argparse
  18. import json
  19. import re
  20. import sys
  21. from pathlib import Path
  # Put this script's own directory first on the import path so the sibling
  # ``reward`` module resolves regardless of the current working directory.
  sys.path.insert(0, f"{Path(__file__).parent}")
  from reward import score_expansion_detailed

  # Query file used when --queries is not supplied on the command line.
  DEFAULT_QUERY_FILE = Path(__file__).parent.joinpath("evals", "queries.txt")
  def load_model(model_path: str):
      """Load the model to evaluate: a base model plus adapter, or a full model.

      If ``model_path`` is a local directory containing ``adapter_config.json``,
      the base model recorded in that config (falling back to
      ``Qwen/Qwen3-1.7B``) is loaded and the adapter is applied on top of it.
      Otherwise ``model_path`` itself is loaded as a complete (e.g. merged)
      model, from a local directory or a Hugging Face Hub repo id.

      Args:
          model_path: Adapter directory, merged-model directory, or Hub repo id.

      Returns:
          ``(model, tokenizer)``: the model in eval mode with sampling defaults
          cleared from its generation config, and a tokenizer configured for
          left padding with a pad token guaranteed to be set.
      """
      import torch
      from peft import PeftModel
      from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

      adapter_config = Path(model_path) / "adapter_config.json"
      is_adapter = adapter_config.exists()

      if is_adapter:
          with open(adapter_config, encoding="utf-8") as f:
              cfg = json.load(f)
          # The adapter records which base it was trained on; fall back to the
          # project default when that field is missing, null, or empty.
          base_model = cfg.get("base_model_name_or_path") or "Qwen/Qwen3-1.7B"
          print(f"Loading base: {base_model}", file=sys.stderr)
      else:
          # No adapter config: the path (or Hub id) is itself a complete model,
          # so evaluate it directly instead of substituting a default base.
          base_model = str(model_path)
          print(f"Loading model: {base_model}", file=sys.stderr)

      tokenizer = AutoTokenizer.from_pretrained(base_model)
      if tokenizer.pad_token is None:
          tokenizer.pad_token = tokenizer.eos_token
      # Left padding right-aligns every prompt, so generated tokens start at the
      # same offset in each row (generate_batch slices on that offset).
      tokenizer.padding_side = "left"

      config = AutoConfig.from_pretrained(base_model)
      # NOTE(review): presumably untied to match how the adapter / merged model
      # was trained; for checkpoints that ship tied embeddings this may leave
      # lm_head freshly initialized -- confirm against the training script.
      config.tie_word_embeddings = False
      # Pin the whole model to GPU 0 when CUDA is available; otherwise load on
      # CPU rather than failing outright on a machine without a GPU.
      device_map = {"": 0} if torch.cuda.is_available() else None
      # NOTE(review): `dtype=` is the newer keyword; older transformers releases
      # expected `torch_dtype=` -- confirm against the installed version, since
      # the script header allows transformers>=4.45.0.
      model = AutoModelForCausalLM.from_pretrained(
          base_model, dtype=torch.bfloat16, device_map=device_map, config=config
      )
      # Clear any sampling defaults shipped with the checkpoint so generation
      # is greedy and deterministic.
      if model.generation_config is not None:
          model.generation_config.do_sample = False
          model.generation_config.temperature = None
          model.generation_config.top_p = None
          model.generation_config.top_k = None

      if is_adapter:
          print(f"Loading adapter: {model_path}", file=sys.stderr)
          model = PeftModel.from_pretrained(model, str(model_path))
      model.eval()
      return model, tokenizer
  60. def generate_batch(
  61. model, tokenizer, queries: list[str], max_new_tokens: int, max_time: float | None
  62. ) -> list[str]:
  63. """Generate expansions for a batch of queries."""
  64. import torch
  65. prompts = [
  66. tokenizer.apply_chat_template(
  67. [{"role": "user", "content": f"/no_think Expand this search query: {q}"}],
  68. tokenize=False,
  69. add_generation_prompt=True,
  70. )
  71. for q in queries
  72. ]
  73. inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
  74. input_len = inputs["input_ids"].shape[1]
  75. generate_kwargs = {
  76. "max_new_tokens": max_new_tokens,
  77. "do_sample": False,
  78. "num_beams": 1,
  79. "pad_token_id": tokenizer.pad_token_id,
  80. "eos_token_id": tokenizer.eos_token_id,
  81. "use_cache": True,
  82. }
  83. if max_time and max_time > 0:
  84. generate_kwargs["max_time"] = max_time
  85. with torch.inference_mode():
  86. out = model.generate(**inputs, **generate_kwargs)
  87. outputs = []
  88. for i in range(len(queries)):
  89. gen_tokens = out[i][input_len:]
  90. text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
  91. text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
  92. outputs.append(text.strip())
  93. return outputs
  94. def main():
  95. parser = argparse.ArgumentParser(description="Evaluate QMD model")
  96. parser.add_argument("model", help="Model path (local or HF)")
  97. parser.add_argument(
  98. "--queries",
  99. default=str(DEFAULT_QUERY_FILE),
  100. help="Queries file (one per line) [default: evals/queries.txt]",
  101. )
  102. parser.add_argument(
  103. "--max-new-tokens",
  104. type=int,
  105. default=400,
  106. help="Maximum new tokens to generate (default: 400)",
  107. )
  108. parser.add_argument(
  109. "--max-time",
  110. type=float,
  111. default=0,
  112. help="Max seconds per batch generation (0 disables)",
  113. )
  114. parser.add_argument(
  115. "--batch-size",
  116. type=int,
  117. default=2,
  118. help="Batch size for generation (default: 2)",
  119. )
  120. parser.add_argument(
  121. "--max-queries",
  122. type=int,
  123. default=0,
  124. help="Limit number of queries (0 disables)",
  125. )
  126. args = parser.parse_args()
  127. # Load queries (default to full evals/queries.txt)
  128. query_file = Path(args.queries)
  129. if not query_file.exists():
  130. raise FileNotFoundError(f"Queries file not found: {query_file}")
  131. with query_file.open(encoding="utf-8") as f:
  132. queries = [
  133. l.strip() for l in f if l.strip() and not l.strip().startswith("#")
  134. ]
  135. if args.max_queries and args.max_queries > 0:
  136. queries = queries[: args.max_queries]
  137. # Load model
  138. model, tokenizer = load_model(args.model)
  139. # Run eval
  140. scores = []
  141. batch_size = max(1, args.batch_size)
  142. total = len(queries)
  143. for start in range(0, total, batch_size):
  144. batch = queries[start : start + batch_size]
  145. batch_outputs = generate_batch(
  146. model, tokenizer, batch, args.max_new_tokens, args.max_time
  147. )
  148. for i, (query, expansion) in enumerate(zip(batch, batch_outputs), start + 1):
  149. print(f"\n[{i}/{total}] {query}")
  150. print("-" * 50)
  151. result = score_expansion_detailed(query, expansion)
  152. print(expansion[:300] + ("..." if len(expansion) > 300 else ""))
  153. print(f"Score: {result['percentage']:.0f}% ({result['rating']})")
  154. scores.append(result["percentage"])
  155. # Summary
  156. avg = sum(scores) / len(scores)
  157. print(f"\n{'=' * 50}")
  158. print(f"Average: {avg:.1f}% | Model: {args.model}")
  159. print(f"{'=' * 50}")
  160. return 0 if avg >= 50 else 1
  if __name__ == "__main__":
      # Hand main()'s integer result to the interpreter as the exit status.
      raise SystemExit(main())