evaluate_baseline.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "torch",
  6. # "huggingface_hub",
  7. # "accelerate",
  8. # ]
  9. # ///
  10. """
  11. Evaluate base model (untrained) for comparison.
  12. """
  13. import json
  14. import torch
  15. from transformers import AutoModelForCausalLM, AutoTokenizer
  16. # Test queries covering different QMD use cases
  17. TEST_QUERIES = [
  18. "how to configure authentication",
  19. "typescript async await",
  20. "docker compose networking",
  21. "git rebase vs merge",
  22. "react useEffect cleanup",
  23. "auth",
  24. "config",
  25. "setup",
  26. "api",
  27. "meeting notes project kickoff",
  28. "ideas for new feature",
  29. "todo list app architecture",
  30. "what is dependency injection",
  31. "difference between sql and nosql",
  32. "kubernetes vs docker swarm",
  33. "connection timeout error",
  34. "memory leak debugging",
  35. "cors error fix",
  36. "how to implement caching with redis in nodejs",
  37. "best practices for api rate limiting",
  38. "setting up ci cd pipeline with github actions",
  39. ]
  40. PROMPT_TEMPLATE = """Expand this search query:
  41. {query}"""
  42. def load_model(model_name: str):
  43. """Load the base model without adapter."""
  44. print(f"Loading tokenizer and model from {model_name}...")
  45. tokenizer = AutoTokenizer.from_pretrained(model_name)
  46. if tokenizer.pad_token is None:
  47. tokenizer.pad_token = tokenizer.eos_token
  48. model = AutoModelForCausalLM.from_pretrained(
  49. model_name,
  50. torch_dtype=torch.bfloat16,
  51. device_map="auto",
  52. )
  53. model.eval()
  54. return model, tokenizer
  55. def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200) -> str:
  56. """Generate query expansion."""
  57. prompt = PROMPT_TEMPLATE.format(query=query)
  58. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  59. with torch.no_grad():
  60. outputs = model.generate(
  61. **inputs,
  62. max_new_tokens=max_new_tokens,
  63. temperature=0.7,
  64. do_sample=True,
  65. pad_token_id=tokenizer.pad_token_id,
  66. eos_token_id=tokenizer.eos_token_id,
  67. )
  68. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  69. expansion = full_output[len(prompt):].strip()
  70. return expansion
  71. def evaluate_expansion(query: str, expansion: str) -> dict:
  72. """Basic automatic evaluation metrics."""
  73. lines = expansion.strip().split("\n")
  74. has_lex = any(l.strip().startswith("lex:") for l in lines)
  75. has_vec = any(l.strip().startswith("vec:") for l in lines)
  76. has_hyde = any(l.strip().startswith("hyde:") for l in lines)
  77. valid_lines = sum(1 for l in lines if l.strip().startswith(("lex:", "vec:", "hyde:")))
  78. contents = []
  79. for l in lines:
  80. if ":" in l:
  81. contents.append(l.split(":", 1)[1].strip().lower())
  82. unique_contents = len(set(contents))
  83. return {
  84. "has_lex": has_lex,
  85. "has_vec": has_vec,
  86. "has_hyde": has_hyde,
  87. "valid_lines": valid_lines,
  88. "total_lines": len(lines),
  89. "unique_contents": unique_contents,
  90. "format_score": (has_lex + has_vec + has_hyde) / 3,
  91. }
  92. def main():
  93. import argparse
  94. parser = argparse.ArgumentParser()
  95. parser.add_argument("--model", default="Qwen/Qwen3-0.6B",
  96. help="Base model to evaluate")
  97. parser.add_argument("--output", default="evaluation_baseline.json",
  98. help="Output file for results")
  99. parser.add_argument("--num-queries", type=int, default=5,
  100. help="Number of queries to test (for quick baseline)")
  101. args = parser.parse_args()
  102. model, tokenizer = load_model(args.model)
  103. queries = TEST_QUERIES[:args.num_queries]
  104. results = []
  105. print(f"\n{'='*70}")
  106. print("BASELINE EVALUATION RESULTS")
  107. print(f"{'='*70}\n")
  108. for i, query in enumerate(queries, 1):
  109. print(f"[{i}/{len(queries)}] Query: {query}")
  110. print("-" * 50)
  111. expansion = generate_expansion(model, tokenizer, query)
  112. metrics = evaluate_expansion(query, expansion)
  113. print(expansion[:500] + "..." if len(expansion) > 500 else expansion)
  114. print(f"\n Format: {'✓' if metrics['format_score'] == 1.0 else '⚠'} "
  115. f"(lex:{metrics['has_lex']}, vec:{metrics['has_vec']}, hyde:{metrics['has_hyde']})")
  116. print()
  117. results.append({
  118. "query": query,
  119. "expansion": expansion,
  120. "metrics": metrics,
  121. })
  122. print(f"\n{'='*70}")
  123. print("SUMMARY")
  124. print(f"{'='*70}")
  125. avg_format = sum(r["metrics"]["format_score"] for r in results) / len(results)
  126. full_format = sum(1 for r in results if r["metrics"]["format_score"] == 1.0)
  127. print(f" Total queries: {len(results)}")
  128. print(f" Average format score: {avg_format:.2%}")
  129. print(f" Full format compliance: {full_format}/{len(results)} ({full_format/len(results):.0%})")
  130. with open(args.output, "w") as f:
  131. json.dump(results, f, indent=2)
  132. print(f"\n Results saved to: {args.output}")
  133. if __name__ == "__main__":
  134. main()