eval.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "accelerate",
  8. # ]
  9. # ///
  10. """
  11. Minimal QMD query expansion evaluator.
  12. Usage:
  13. uv run eval.py ./outputs/sft
  14. uv run eval.py tobil/qmd-query-expansion-1.7B --queries evals/queries.txt
  15. """
  16. import argparse
  17. import json
  18. import re
  19. import sys
  20. import os
  21. from pathlib import Path
  22. # Import reward scoring
  23. sys.path.insert(0, str(Path(__file__).parent))
  24. from reward import score_expansion_detailed
  25. QUERIES = [
  26. "how to configure authentication",
  27. "docker compose networking",
  28. "auth",
  29. "who is TDS motorsports",
  30. "React hooks tutorial",
  31. "recent news about Shopify",
  32. "how to implement caching with redis in nodejs",
  33. "auth /only:lex",
  34. "kubernetes pod deployment /only:vec",
  35. "AWS Lambda cold start /only:hyde",
  36. ]
  37. def load_model(model_path: str):
  38. """Load model (adapter or merged)."""
  39. import torch
  40. from peft import PeftModel
  41. from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
  42. model_path = Path(model_path)
  43. adapter_config = model_path / "adapter_config.json"
  44. # Get base model from adapter config or default
  45. base_model = "Qwen/Qwen3-1.7B"
  46. if adapter_config.exists():
  47. with open(adapter_config) as f:
  48. cfg = json.load(f)
  49. base_model = cfg.get("base_model_name_or_path", base_model)
  50. print(f"Loading base: {base_model}", file=sys.stderr)
  51. tokenizer = AutoTokenizer.from_pretrained(base_model)
  52. if tokenizer.pad_token is None:
  53. tokenizer.pad_token = tokenizer.eos_token
  54. tokenizer.padding_side = "left"
  55. config = AutoConfig.from_pretrained(base_model)
  56. config.tie_word_embeddings = False
  57. model = AutoModelForCausalLM.from_pretrained(
  58. base_model, dtype=torch.bfloat16, device_map={"": 0}, config=config
  59. )
  60. if model.generation_config is not None:
  61. model.generation_config.do_sample = False
  62. model.generation_config.temperature = None
  63. model.generation_config.top_p = None
  64. model.generation_config.top_k = None
  65. # Load adapter if present
  66. if adapter_config.exists():
  67. print(f"Loading adapter: {model_path}", file=sys.stderr)
  68. model = PeftModel.from_pretrained(model, str(model_path))
  69. model.eval()
  70. return model, tokenizer
  71. def generate_batch(
  72. model, tokenizer, queries: list[str], max_new_tokens: int, max_time: float | None
  73. ) -> list[str]:
  74. """Generate expansions for a batch of queries."""
  75. import torch
  76. prompts = [
  77. tokenizer.apply_chat_template(
  78. [{"role": "user", "content": f"/no_think Expand this search query: {q}"}],
  79. tokenize=False,
  80. add_generation_prompt=True,
  81. )
  82. for q in queries
  83. ]
  84. inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
  85. input_len = inputs["input_ids"].shape[1]
  86. generate_kwargs = {
  87. "max_new_tokens": max_new_tokens,
  88. "do_sample": False,
  89. "num_beams": 1,
  90. "pad_token_id": tokenizer.pad_token_id,
  91. "eos_token_id": tokenizer.eos_token_id,
  92. "use_cache": True,
  93. }
  94. if max_time and max_time > 0:
  95. generate_kwargs["max_time"] = max_time
  96. with torch.inference_mode():
  97. out = model.generate(**inputs, **generate_kwargs)
  98. outputs = []
  99. for i in range(len(queries)):
  100. gen_tokens = out[i][input_len:]
  101. text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
  102. text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
  103. outputs.append(text.strip())
  104. return outputs
  105. def main():
  106. parser = argparse.ArgumentParser(description="Evaluate QMD model")
  107. parser.add_argument("model", help="Model path (local or HF)")
  108. parser.add_argument("--queries", help="Queries file (one per line)")
  109. parser.add_argument(
  110. "--max-new-tokens",
  111. type=int,
  112. default=400,
  113. help="Maximum new tokens to generate (default: 400)",
  114. )
  115. parser.add_argument(
  116. "--max-time",
  117. type=float,
  118. default=0,
  119. help="Max seconds per batch generation (0 disables)",
  120. )
  121. parser.add_argument(
  122. "--batch-size",
  123. type=int,
  124. default=2,
  125. help="Batch size for generation (default: 2)",
  126. )
  127. parser.add_argument(
  128. "--max-queries",
  129. type=int,
  130. default=0,
  131. help="Limit number of queries (0 disables)",
  132. )
  133. args = parser.parse_args()
  134. # Load queries
  135. queries = QUERIES
  136. if args.queries:
  137. with open(args.queries) as f:
  138. queries = [l.strip() for l in f if l.strip() and not l.startswith("#")]
  139. if args.max_queries and args.max_queries > 0:
  140. queries = queries[: args.max_queries]
  141. # Load model
  142. model, tokenizer = load_model(args.model)
  143. # Run eval
  144. scores = []
  145. batch_size = max(1, args.batch_size)
  146. total = len(queries)
  147. for start in range(0, total, batch_size):
  148. batch = queries[start : start + batch_size]
  149. batch_outputs = generate_batch(
  150. model, tokenizer, batch, args.max_new_tokens, args.max_time
  151. )
  152. for i, (query, expansion) in enumerate(zip(batch, batch_outputs), start + 1):
  153. print(f"\n[{i}/{total}] {query}")
  154. print("-" * 50)
  155. result = score_expansion_detailed(query, expansion)
  156. print(expansion[:300] + ("..." if len(expansion) > 300 else ""))
  157. print(f"Score: {result['percentage']:.0f}% ({result['rating']})")
  158. scores.append(result["percentage"])
  159. # Summary
  160. avg = sum(scores) / len(scores)
  161. print(f"\n{'=' * 50}")
  162. print(f"Average: {avg:.1f}% | Model: {args.model}")
  163. print(f"{'=' * 50}")
  164. return 0 if avg >= 50 else 1
  165. if __name__ == "__main__":
  166. sys.exit(main())