evaluate_model.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "huggingface_hub",
  8. # ]
  9. # ///
  10. """
  11. Evaluate QMD query expansion model quality.
  12. Generates expansions for test queries and outputs results for review.
  13. """
  14. import json
  15. import torch
  16. from transformers import AutoModelForCausalLM, AutoTokenizer
  17. from peft import PeftModel
  18. # Test queries covering different QMD use cases
  19. TEST_QUERIES = [
  20. # Technical documentation
  21. "how to configure authentication",
  22. "typescript async await",
  23. "docker compose networking",
  24. "git rebase vs merge",
  25. "react useEffect cleanup",
  26. # Short/ambiguous queries
  27. "auth",
  28. "config",
  29. "setup",
  30. "api",
  31. # Personal notes / journals style
  32. "meeting notes project kickoff",
  33. "ideas for new feature",
  34. "todo list app architecture",
  35. # Research / learning
  36. "what is dependency injection",
  37. "difference between sql and nosql",
  38. "kubernetes vs docker swarm",
  39. # Error/debugging
  40. "connection timeout error",
  41. "memory leak debugging",
  42. "cors error fix",
  43. # Complex queries
  44. "how to implement caching with redis in nodejs",
  45. "best practices for api rate limiting",
  46. "setting up ci cd pipeline with github actions",
  47. ]
  48. PROMPT_TEMPLATE = """Expand this search query:
  49. {query}"""
  50. def load_model(model_name: str, base_model: str = "Qwen/Qwen3-0.6B"):
  51. """Load the finetuned model."""
  52. print(f"Loading tokenizer from {base_model}...")
  53. tokenizer = AutoTokenizer.from_pretrained(base_model)
  54. if tokenizer.pad_token is None:
  55. tokenizer.pad_token = tokenizer.eos_token
  56. print(f"Loading base model...")
  57. base = AutoModelForCausalLM.from_pretrained(
  58. base_model,
  59. torch_dtype=torch.bfloat16,
  60. device_map="auto",
  61. )
  62. print(f"Loading adapter from {model_name}...")
  63. model = PeftModel.from_pretrained(base, model_name)
  64. model.eval()
  65. return model, tokenizer
  66. def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200) -> str:
  67. """Generate query expansion."""
  68. prompt = PROMPT_TEMPLATE.format(query=query)
  69. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  70. with torch.no_grad():
  71. outputs = model.generate(
  72. **inputs,
  73. max_new_tokens=max_new_tokens,
  74. temperature=0.7,
  75. do_sample=True,
  76. pad_token_id=tokenizer.pad_token_id,
  77. eos_token_id=tokenizer.eos_token_id,
  78. )
  79. # Decode and extract just the generated part
  80. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  81. # Remove the prompt to get just the expansion
  82. if "Output:" in full_output:
  83. expansion = full_output.split("Output:")[-1].strip()
  84. else:
  85. expansion = full_output[len(prompt):].strip()
  86. return expansion
  87. def evaluate_expansion(query: str, expansion: str) -> dict:
  88. """Basic automatic evaluation metrics."""
  89. lines = expansion.strip().split("\n")
  90. has_lex = any(l.strip().startswith("lex:") for l in lines)
  91. has_vec = any(l.strip().startswith("vec:") for l in lines)
  92. has_hyde = any(l.strip().startswith("hyde:") for l in lines)
  93. # Count valid lines
  94. valid_lines = sum(1 for l in lines if l.strip().startswith(("lex:", "vec:", "hyde:")))
  95. # Check for repetition
  96. contents = []
  97. for l in lines:
  98. if ":" in l:
  99. contents.append(l.split(":", 1)[1].strip().lower())
  100. unique_contents = len(set(contents))
  101. return {
  102. "has_lex": has_lex,
  103. "has_vec": has_vec,
  104. "has_hyde": has_hyde,
  105. "valid_lines": valid_lines,
  106. "total_lines": len(lines),
  107. "unique_contents": unique_contents,
  108. "format_score": (has_lex + has_vec + has_hyde) / 3,
  109. }
  110. def main():
  111. import argparse
  112. parser = argparse.ArgumentParser()
  113. parser.add_argument("--model", default="tobil/qmd-query-expansion-0.6B",
  114. help="Model to evaluate")
  115. parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
  116. help="Base model")
  117. parser.add_argument("--output", default="evaluation_results.json",
  118. help="Output file for results")
  119. parser.add_argument("--queries", type=str, help="Custom queries file (one per line)")
  120. args = parser.parse_args()
  121. # Load custom queries if provided
  122. queries = TEST_QUERIES
  123. if args.queries:
  124. with open(args.queries) as f:
  125. queries = [l.strip() for l in f if l.strip()]
  126. # Load model
  127. model, tokenizer = load_model(args.model, args.base_model)
  128. # Run evaluation
  129. results = []
  130. print(f"\n{'='*70}")
  131. print("EVALUATION RESULTS")
  132. print(f"{'='*70}\n")
  133. for i, query in enumerate(queries, 1):
  134. print(f"[{i}/{len(queries)}] Query: {query}")
  135. print("-" * 50)
  136. expansion = generate_expansion(model, tokenizer, query)
  137. metrics = evaluate_expansion(query, expansion)
  138. print(expansion)
  139. print(f"\n Format: {'✓' if metrics['format_score'] == 1.0 else '⚠'} "
  140. f"(lex:{metrics['has_lex']}, vec:{metrics['has_vec']}, hyde:{metrics['has_hyde']})")
  141. print(f" Lines: {metrics['valid_lines']}/{metrics['total_lines']} valid, "
  142. f"{metrics['unique_contents']} unique")
  143. print()
  144. results.append({
  145. "query": query,
  146. "expansion": expansion,
  147. "metrics": metrics,
  148. })
  149. # Summary
  150. print(f"\n{'='*70}")
  151. print("SUMMARY")
  152. print(f"{'='*70}")
  153. avg_format = sum(r["metrics"]["format_score"] for r in results) / len(results)
  154. full_format = sum(1 for r in results if r["metrics"]["format_score"] == 1.0)
  155. print(f" Total queries: {len(results)}")
  156. print(f" Average format score: {avg_format:.2%}")
  157. print(f" Full format compliance: {full_format}/{len(results)} ({full_format/len(results):.0%})")
  158. # Save results
  159. with open(args.output, "w") as f:
  160. json.dump(results, f, indent=2)
  161. print(f"\n Results saved to: {args.output}")
  162. if __name__ == "__main__":
  163. main()