|
|
@@ -5,16 +5,19 @@
|
|
|
# "peft>=0.7.0",
|
|
|
# "torch",
|
|
|
# "huggingface_hub",
|
|
|
+# "accelerate",
|
|
|
# ]
|
|
|
# ///
|
|
|
"""
|
|
|
Evaluate QMD query expansion model quality.
|
|
|
|
|
|
-Generates expansions for test queries and outputs results for review.
|
|
|
+See SCORING.md for detailed scoring criteria.
|
|
|
"""
|
|
|
|
|
|
import json
|
|
|
+import re
|
|
|
import torch
|
|
|
+from collections import Counter
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
from peft import PeftModel
|
|
|
|
|
|
@@ -26,28 +29,23 @@ TEST_QUERIES = [
|
|
|
"docker compose networking",
|
|
|
"git rebase vs merge",
|
|
|
"react useEffect cleanup",
|
|
|
-
|
|
|
# Short/ambiguous queries
|
|
|
"auth",
|
|
|
"config",
|
|
|
"setup",
|
|
|
"api",
|
|
|
-
|
|
|
# Personal notes / journals style
|
|
|
"meeting notes project kickoff",
|
|
|
"ideas for new feature",
|
|
|
"todo list app architecture",
|
|
|
-
|
|
|
# Research / learning
|
|
|
"what is dependency injection",
|
|
|
"difference between sql and nosql",
|
|
|
"kubernetes vs docker swarm",
|
|
|
-
|
|
|
# Error/debugging
|
|
|
"connection timeout error",
|
|
|
"memory leak debugging",
|
|
|
"cors error fix",
|
|
|
-
|
|
|
# Complex queries
|
|
|
"how to implement caching with redis in nodejs",
|
|
|
"best practices for api rate limiting",
|
|
|
@@ -58,6 +56,237 @@ PROMPT_TEMPLATE = """Expand this search query:
|
|
|
|
|
|
{query}"""
|
|
|
|
|
|
+STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
|
|
|
+
|
|
|
+
|
|
|
+def parse_expansion(text: str) -> dict:
|
|
|
+ """Parse expansion into structured format."""
|
|
|
+ lines = text.strip().split("\n")
|
|
|
+ result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
|
|
|
+
|
|
|
+ for line in lines:
|
|
|
+ line = line.strip()
|
|
|
+ if not line:
|
|
|
+ continue
|
|
|
+ if line.startswith("lex:"):
|
|
|
+ result["lex"].append(line[4:].strip())
|
|
|
+ elif line.startswith("vec:"):
|
|
|
+ result["vec"].append(line[4:].strip())
|
|
|
+ elif line.startswith("hyde:"):
|
|
|
+ result["hyde"].append(line[5:].strip())
|
|
|
+ else:
|
|
|
+ result["invalid"].append(line)
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
+
|
|
|
+def edit_distance_simple(a: str, b: str) -> int:
|
|
|
+ """Simple word-level edit distance."""
|
|
|
+ words_a = set(a.lower().split())
|
|
|
+ words_b = set(b.lower().split())
|
|
|
+ return len(words_a ^ words_b) # Symmetric difference
|
|
|
+
|
|
|
+
|
|
|
+def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
|
|
|
+ """Check if two strings are sufficiently different."""
|
|
|
+ a, b = a.lower().strip(), b.lower().strip()
|
|
|
+ if a == b:
|
|
|
+ return False
|
|
|
+ if a in b or b in a:
|
|
|
+ return False
|
|
|
+ return edit_distance_simple(a, b) >= min_distance
|
|
|
+
|
|
|
+
|
|
|
+def echoes_query(expansion: str, query: str) -> bool:
|
|
|
+ """Check if expansion is just echoing the query."""
|
|
|
+ exp = expansion.lower().strip()
|
|
|
+ q = query.lower().strip()
|
|
|
+ # Exact match or very close
|
|
|
+ if exp == q:
|
|
|
+ return True
|
|
|
+ # Query is contained in expansion with little else
|
|
|
+ if q in exp and len(exp) < len(q) + 10:
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
+def word_repetition_penalty(text: str) -> int:
|
|
|
+ """Count penalty for repeated words (excluding stopwords)."""
|
|
|
+ words = re.findall(r'\b\w+\b', text.lower())
|
|
|
+ counts = Counter(words)
|
|
|
+ penalty = 0
|
|
|
+ for word, count in counts.items():
|
|
|
+ if count >= 3 and word not in STOPWORDS and len(word) > 2:
|
|
|
+ penalty += (count - 2) * 2
|
|
|
+ return penalty
|
|
|
+
|
|
|
+
|
|
|
+def score_expansion(query: str, expansion: str) -> dict:
|
|
|
+ """
|
|
|
+ Score an expansion based on SCORING.md criteria.
|
|
|
+
|
|
|
+ Returns dict with score breakdown and total (0-100).
|
|
|
+ """
|
|
|
+ parsed = parse_expansion(expansion)
|
|
|
+ scores = {
|
|
|
+ "format": 0,
|
|
|
+ "diversity": 0,
|
|
|
+ "hyde": 0,
|
|
|
+ "quality": 0,
|
|
|
+ "deductions": [],
|
|
|
+ }
|
|
|
+
|
|
|
+ # === FORMAT (0-30) ===
|
|
|
+ format_score = 0
|
|
|
+
|
|
|
+ # Has at least one lex: line (+10)
|
|
|
+ if parsed["lex"]:
|
|
|
+ format_score += 10
|
|
|
+ else:
|
|
|
+ scores["deductions"].append("missing lex: (-10)")
|
|
|
+
|
|
|
+ # Has at least one vec: line (+10)
|
|
|
+ if parsed["vec"]:
|
|
|
+ format_score += 10
|
|
|
+ else:
|
|
|
+ scores["deductions"].append("missing vec: (-10)")
|
|
|
+
|
|
|
+ # All lines have valid prefix (+10, -5 per invalid)
|
|
|
+ if not parsed["invalid"]:
|
|
|
+ format_score += 10
|
|
|
+ else:
|
|
|
+ invalid_penalty = min(10, len(parsed["invalid"]) * 5)
|
|
|
+ format_score += (10 - invalid_penalty)
|
|
|
+ scores["deductions"].append(f"{len(parsed['invalid'])} invalid lines (-{invalid_penalty})")
|
|
|
+
|
|
|
+ scores["format"] = max(0, format_score)
|
|
|
+
|
|
|
+ # === DIVERSITY (0-30) ===
|
|
|
+ diversity_score = 0
|
|
|
+
|
|
|
+ # 2+ different types present (+10)
|
|
|
+ types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
|
|
|
+ if types_present >= 2:
|
|
|
+ diversity_score += 10
|
|
|
+ else:
|
|
|
+ scores["deductions"].append("only one type present (-10)")
|
|
|
+
|
|
|
+ # 2+ total expansions (+5)
|
|
|
+ total_expansions = len(parsed["lex"]) + len(parsed["vec"])
|
|
|
+ if total_expansions >= 2:
|
|
|
+ diversity_score += 5
|
|
|
+ else:
|
|
|
+ scores["deductions"].append("fewer than 2 expansions (-5)")
|
|
|
+
|
|
|
+ # Multiple lex: lines are diverse (+5, -2 per duplicate pair)
|
|
|
+ lex_diverse_score = 5
|
|
|
+ for i, a in enumerate(parsed["lex"]):
|
|
|
+ for b in parsed["lex"][i+1:]:
|
|
|
+ if not is_diverse(a, b, min_distance=2):
|
|
|
+ lex_diverse_score -= 2
|
|
|
+ scores["deductions"].append(f"lex duplicates: '{a[:20]}...' ~ '{b[:20]}...'")
|
|
|
+ diversity_score += max(0, lex_diverse_score)
|
|
|
+
|
|
|
+ # Multiple vec: lines are diverse (+5, -2 per duplicate pair)
|
|
|
+ vec_diverse_score = 5
|
|
|
+ for i, a in enumerate(parsed["vec"]):
|
|
|
+ for b in parsed["vec"][i+1:]:
|
|
|
+ if not is_diverse(a, b, min_distance=3):
|
|
|
+ vec_diverse_score -= 2
|
|
|
+ scores["deductions"].append(f"vec duplicates: '{a[:20]}...' ~ '{b[:20]}...'")
|
|
|
+ diversity_score += max(0, vec_diverse_score)
|
|
|
+
|
|
|
+ # lex/vec not identical to original query (+5, -5 per echo)
|
|
|
+ echo_score = 5
|
|
|
+ for exp in parsed["lex"] + parsed["vec"]:
|
|
|
+ if echoes_query(exp, query):
|
|
|
+ echo_score -= 5
|
|
|
+ scores["deductions"].append(f"echoes query: '{exp[:30]}...'")
|
|
|
+ diversity_score += max(0, echo_score)
|
|
|
+
|
|
|
+ scores["diversity"] = max(0, diversity_score)
|
|
|
+
|
|
|
+ # === HYDE QUALITY (0-20, optional bonus) ===
|
|
|
+ hyde_score = 0
|
|
|
+
|
|
|
+ if parsed["hyde"]:
|
|
|
+ hyde_text = parsed["hyde"][0] # Only first hyde counts
|
|
|
+
|
|
|
+ # Hyde present and well-formed (+5)
|
|
|
+ hyde_score += 5
|
|
|
+
|
|
|
+ # Hyde is concise: 50-200 chars (+5)
|
|
|
+ hyde_len = len(hyde_text)
|
|
|
+ if 50 <= hyde_len <= 200:
|
|
|
+ hyde_score += 5
|
|
|
+ elif hyde_len < 50:
|
|
|
+ hyde_score += 2
|
|
|
+ scores["deductions"].append(f"hyde too short ({hyde_len} chars)")
|
|
|
+ else:
|
|
|
+ scores["deductions"].append(f"hyde too long ({hyde_len} chars)")
|
|
|
+
|
|
|
+ # Hyde has no newlines (+5)
|
|
|
+ if "\n" not in hyde_text:
|
|
|
+ hyde_score += 5
|
|
|
+ else:
|
|
|
+ scores["deductions"].append("hyde contains newlines")
|
|
|
+
|
|
|
+ # Hyde has no excessive repetition (+5)
|
|
|
+ rep_penalty = word_repetition_penalty(hyde_text)
|
|
|
+ if rep_penalty == 0:
|
|
|
+ hyde_score += 5
|
|
|
+ else:
|
|
|
+ hyde_score += max(0, 5 - rep_penalty)
|
|
|
+ scores["deductions"].append(f"hyde repetition penalty (-{min(5, rep_penalty)})")
|
|
|
+
|
|
|
+ scores["hyde"] = hyde_score
|
|
|
+
|
|
|
+ # === QUALITY (0-20) ===
|
|
|
+ quality_score = 10 # Base relevance (assume relevant unless obvious garbage)
|
|
|
+
|
|
|
+ # Lex lines should be keyword-focused (shorter than vec on average)
|
|
|
+ if parsed["lex"] and parsed["vec"]:
|
|
|
+ avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
|
|
|
+ avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
|
|
|
+ if avg_lex <= avg_vec:
|
|
|
+ quality_score += 5
|
|
|
+ else:
|
|
|
+ scores["deductions"].append("lex longer than vec (should be keywords)")
|
|
|
+ else:
|
|
|
+ quality_score += 2 # Partial credit
|
|
|
+
|
|
|
+ # Vec lines should be natural language (contain spaces, longer)
|
|
|
+ if parsed["vec"]:
|
|
|
+ vec_natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
|
|
|
+ if vec_natural == len(parsed["vec"]):
|
|
|
+ quality_score += 5
|
|
|
+ else:
|
|
|
+ quality_score += 2
|
|
|
+ scores["deductions"].append("some vec lines too short/keyword-like")
|
|
|
+
|
|
|
+ scores["quality"] = quality_score
|
|
|
+
|
|
|
+ # === TOTAL ===
|
|
|
+ scores["total"] = scores["format"] + scores["diversity"] + scores["hyde"] + scores["quality"]
|
|
|
+ scores["max_possible"] = 100 if parsed["hyde"] else 80
|
|
|
+ scores["percentage"] = scores["total"] / scores["max_possible"] * 100
|
|
|
+
|
|
|
+ # Rating
|
|
|
+ pct = scores["percentage"]
|
|
|
+ if pct >= 80:
|
|
|
+ scores["rating"] = "Excellent"
|
|
|
+ elif pct >= 60:
|
|
|
+ scores["rating"] = "Good"
|
|
|
+ elif pct >= 40:
|
|
|
+ scores["rating"] = "Acceptable"
|
|
|
+ elif pct >= 20:
|
|
|
+ scores["rating"] = "Poor"
|
|
|
+ else:
|
|
|
+ scores["rating"] = "Failed"
|
|
|
+
|
|
|
+ scores["parsed"] = parsed
|
|
|
+ return scores
|
|
|
+
|
|
|
|
|
|
def load_model(model_name: str, base_model: str = "Qwen/Qwen3-0.6B"):
|
|
|
"""Load the finetuned model."""
|
|
|
@@ -96,45 +325,82 @@ def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200)
|
|
|
eos_token_id=tokenizer.eos_token_id,
|
|
|
)
|
|
|
|
|
|
- # Decode and extract just the generated part
|
|
|
full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
-
|
|
|
- # Remove the prompt to get just the expansion
|
|
|
- if "Output:" in full_output:
|
|
|
- expansion = full_output.split("Output:")[-1].strip()
|
|
|
- else:
|
|
|
- expansion = full_output[len(prompt):].strip()
|
|
|
+ expansion = full_output[len(prompt):].strip()
|
|
|
|
|
|
return expansion
|
|
|
|
|
|
|
|
|
-def evaluate_expansion(query: str, expansion: str) -> dict:
|
|
|
- """Basic automatic evaluation metrics."""
|
|
|
- lines = expansion.strip().split("\n")
|
|
|
-
|
|
|
- has_lex = any(l.strip().startswith("lex:") for l in lines)
|
|
|
- has_vec = any(l.strip().startswith("vec:") for l in lines)
|
|
|
- has_hyde = any(l.strip().startswith("hyde:") for l in lines)
|
|
|
-
|
|
|
- # Count valid lines
|
|
|
- valid_lines = sum(1 for l in lines if l.strip().startswith(("lex:", "vec:", "hyde:")))
|
|
|
-
|
|
|
- # Check for repetition
|
|
|
- contents = []
|
|
|
- for l in lines:
|
|
|
- if ":" in l:
|
|
|
- contents.append(l.split(":", 1)[1].strip().lower())
|
|
|
- unique_contents = len(set(contents))
|
|
|
-
|
|
|
- return {
|
|
|
- "has_lex": has_lex,
|
|
|
- "has_vec": has_vec,
|
|
|
- "has_hyde": has_hyde,
|
|
|
- "valid_lines": valid_lines,
|
|
|
- "total_lines": len(lines),
|
|
|
- "unique_contents": unique_contents,
|
|
|
- "format_score": (has_lex + has_vec + has_hyde) / 3,
|
|
|
- }
|
|
|
+def print_score_breakdown(scores: dict):
|
|
|
+ """Pretty print score breakdown."""
|
|
|
+ print(f" Score: {scores['total']}/{scores['max_possible']} ({scores['percentage']:.0f}%) - {scores['rating']}")
|
|
|
+ print(f" Format: {scores['format']}/30")
|
|
|
+ print(f" Diversity: {scores['diversity']}/30")
|
|
|
+ print(f" Hyde: {scores['hyde']}/20")
|
|
|
+ print(f" Quality: {scores['quality']}/20")
|
|
|
+ if scores["deductions"]:
|
|
|
+ print(f" Deductions:")
|
|
|
+ for d in scores["deductions"][:5]: # Show top 5
|
|
|
+ print(f" - {d}")
|
|
|
+ if len(scores["deductions"]) > 5:
|
|
|
+ print(f" ... and {len(scores['deductions']) - 5} more")
|
|
|
+
|
|
|
+
|
|
|
+def run_examples():
|
|
|
+ """Run good and bad examples to demonstrate scoring."""
|
|
|
+ print("=" * 70)
|
|
|
+ print("SCORING EXAMPLES")
|
|
|
+ print("=" * 70)
|
|
|
+
|
|
|
+ # Good example
|
|
|
+ good_expansion = """lex: react hooks tutorial
|
|
|
+lex: usestate useeffect
|
|
|
+vec: how to use react hooks in functional components
|
|
|
+vec: react hooks best practices guide
|
|
|
+hyde: React Hooks allow you to use state and lifecycle features in functional components without writing a class."""
|
|
|
+
|
|
|
+ print("\n[GOOD EXAMPLE]")
|
|
|
+ print(f"Query: react hooks")
|
|
|
+ print(f"Output:\n{good_expansion}")
|
|
|
+ scores = score_expansion("react hooks", good_expansion)
|
|
|
+ print_score_breakdown(scores)
|
|
|
+
|
|
|
+ # Bad example
|
|
|
+ bad_expansion = """auth is an important concept that relates to authentication.
|
|
|
+The answer should be in Chinese.
|
|
|
+The answer should be in Chinese."""
|
|
|
+
|
|
|
+ print("\n[BAD EXAMPLE]")
|
|
|
+ print(f"Query: auth")
|
|
|
+ print(f"Output:\n{bad_expansion}")
|
|
|
+ scores = score_expansion("auth", bad_expansion)
|
|
|
+ print_score_breakdown(scores)
|
|
|
+
|
|
|
+ # Medium example - repetitive hyde
|
|
|
+ medium_expansion = """lex: docker networking
|
|
|
+vec: docker networking
|
|
|
+hyde: Docker networking is an important concept. Docker networking is used for container communication. Docker networking configuration is essential."""
|
|
|
+
|
|
|
+ print("\n[MEDIUM EXAMPLE - Repetitive]")
|
|
|
+ print(f"Query: docker networking")
|
|
|
+ print(f"Output:\n{medium_expansion}")
|
|
|
+ scores = score_expansion("docker networking", medium_expansion)
|
|
|
+ print_score_breakdown(scores)
|
|
|
+
|
|
|
+ # Medium example - echoes query
|
|
|
+ echo_expansion = """lex: auth
|
|
|
+lex: authentication
|
|
|
+vec: auth
|
|
|
+vec: authentication configuration
|
|
|
+hyde: Authentication is the process of verifying identity."""
|
|
|
+
|
|
|
+ print("\n[MEDIUM EXAMPLE - Echoes Query]")
|
|
|
+ print(f"Query: auth")
|
|
|
+ print(f"Output:\n{echo_expansion}")
|
|
|
+ scores = score_expansion("auth", echo_expansion)
|
|
|
+ print_score_breakdown(scores)
|
|
|
+
|
|
|
+ print("\n" + "=" * 70)
|
|
|
|
|
|
|
|
|
def main():
|
|
|
@@ -147,8 +413,14 @@ def main():
|
|
|
parser.add_argument("--output", default="evaluation_results.json",
|
|
|
help="Output file for results")
|
|
|
parser.add_argument("--queries", type=str, help="Custom queries file (one per line)")
|
|
|
+ parser.add_argument("--examples", action="store_true", help="Run scoring examples only")
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
+ # Run examples if requested
|
|
|
+ if args.examples:
|
|
|
+ run_examples()
|
|
|
+ return
|
|
|
+
|
|
|
# Load custom queries if provided
|
|
|
queries = TEST_QUERIES
|
|
|
if args.queries:
|
|
|
@@ -169,19 +441,18 @@ def main():
|
|
|
print("-" * 50)
|
|
|
|
|
|
expansion = generate_expansion(model, tokenizer, query)
|
|
|
- metrics = evaluate_expansion(query, expansion)
|
|
|
+ scores = score_expansion(query, expansion)
|
|
|
|
|
|
print(expansion)
|
|
|
- print(f"\n Format: {'✓' if metrics['format_score'] == 1.0 else '⚠'} "
|
|
|
- f"(lex:{metrics['has_lex']}, vec:{metrics['has_vec']}, hyde:{metrics['has_hyde']})")
|
|
|
- print(f" Lines: {metrics['valid_lines']}/{metrics['total_lines']} valid, "
|
|
|
- f"{metrics['unique_contents']} unique")
|
|
|
+ print()
|
|
|
+ print_score_breakdown(scores)
|
|
|
print()
|
|
|
|
|
|
results.append({
|
|
|
"query": query,
|
|
|
"expansion": expansion,
|
|
|
- "metrics": metrics,
|
|
|
+ "scores": {k: v for k, v in scores.items() if k != "parsed"},
|
|
|
+ "parsed": scores["parsed"],
|
|
|
})
|
|
|
|
|
|
# Summary
|
|
|
@@ -189,12 +460,21 @@ def main():
|
|
|
print("SUMMARY")
|
|
|
print(f"{'='*70}")
|
|
|
|
|
|
- avg_format = sum(r["metrics"]["format_score"] for r in results) / len(results)
|
|
|
- full_format = sum(1 for r in results if r["metrics"]["format_score"] == 1.0)
|
|
|
+ avg_score = sum(r["scores"]["percentage"] for r in results) / len(results)
|
|
|
+ excellent = sum(1 for r in results if r["scores"]["rating"] == "Excellent")
|
|
|
+ good = sum(1 for r in results if r["scores"]["rating"] == "Good")
|
|
|
+ acceptable = sum(1 for r in results if r["scores"]["rating"] == "Acceptable")
|
|
|
+ poor = sum(1 for r in results if r["scores"]["rating"] == "Poor")
|
|
|
+ failed = sum(1 for r in results if r["scores"]["rating"] == "Failed")
|
|
|
|
|
|
print(f" Total queries: {len(results)}")
|
|
|
- print(f" Average format score: {avg_format:.2%}")
|
|
|
- print(f" Full format compliance: {full_format}/{len(results)} ({full_format/len(results):.0%})")
|
|
|
+ print(f" Average score: {avg_score:.1f}%")
|
|
|
+ print(f" Ratings:")
|
|
|
+ print(f" Excellent: {excellent}")
|
|
|
+ print(f" Good: {good}")
|
|
|
+ print(f" Acceptable: {acceptable}")
|
|
|
+ print(f" Poor: {poor}")
|
|
|
+ print(f" Failed: {failed}")
|
|
|
|
|
|
# Save results
|
|
|
with open(args.output, "w") as f:
|