Browse Source

Add comprehensive scoring system for query expansion

New scoring criteria (0-100 points):
- Format (30): Must have lex: and vec: prefixes
- Diversity (30): Multiple types, no echoing query, diverse expansions
- Hyde (20): Optional, concise, no newlines, no word repetition
- Quality (20): Lex=keywords, vec=natural language

See SCORING.md for full documentation.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 4 tháng trước cách đây
mục cha
commit
c35dbd6cbd
2 tập tin đã thay đổi với 506 bổ sung50 xóa
  1. 176 0
      finetune/SCORING.md
  2. 330 50
      finetune/evaluate_model.py

+ 176 - 0
finetune/SCORING.md

@@ -0,0 +1,176 @@
+# QMD Query Expansion Scoring
+
+## Goal
+
+Transform a random typed query into a great set of retrieval-optimized expansions.
+
+**Input:** `"auth config"`
+**Output:**
+```
+lex: authentication configuration
+lex: auth settings setup
+vec: how to configure authentication settings
+vec: authentication configuration options
+hyde: Authentication can be configured by setting the AUTH_SECRET environment variable and enabling the auth middleware in your application's config file.
+```
+
+## Output Format
+
+| Prefix | Purpose | Required | Count |
+|--------|---------|----------|-------|
+| `lex:` | BM25 keyword variations (shorter, keyword-focused) | Yes | 1-3 |
+| `vec:` | Semantic reformulations (natural language) | Yes | 1-3 |
+| `hyde:` | Hypothetical document passage | Optional | 0-1 |
+
+## Scoring Criteria
+
+### 1. Format Compliance (0-30 points)
+
+| Criterion | Points | Deduction |
+|-----------|--------|-----------|
+| Has at least one `lex:` line | +10 | -10 if missing |
+| Has at least one `vec:` line | +10 | -10 if missing |
+| All lines have valid prefix (`lex:`, `vec:`, `hyde:`) | +10 | -5 per invalid line |
+| No garbage/prose outside of prefixed lines | - | -10 if present |
+
+### 2. Diversity & Coverage (0-30 points)
+
+| Criterion | Points | Deduction |
+|-----------|--------|-----------|
+| 2+ different types present (lex + vec) | +10 | -10 if only one type |
+| 2+ total expansions | +5 | -5 if only one |
+| Multiple lex: lines are diverse (edit distance > 3) | +5 | -2 per duplicate pair |
+| Multiple vec: lines are diverse (edit distance > 5) | +5 | -2 per duplicate pair |
+| lex/vec not identical to original query | +5 | -5 per line that equals query |
+
+### 3. Hyde Quality (0-20 points, optional bonus)
+
+| Criterion | Points | Deduction |
+|-----------|--------|-----------|
+| Hyde present and well-formed | +5 | - |
+| Hyde is concise (50-200 chars) | +5 | -3 if too short, -5 if too long |
+| Hyde has no newlines | +5 | -5 if contains newlines |
+| Hyde has no excessive repetition | +5 | -3 if word repeats 3+ times |
+
+### 4. Content Quality (0-20 points)
+
+| Criterion | Points | Deduction |
+|-----------|--------|-----------|
+| Expansions are relevant to query | +10 | Subjective |
+| Lex lines are keyword-focused (shorter) | +5 | -2 if lex is longer than vec |
+| Vec lines are natural language (complete phrases) | +5 | -2 if vec is just keywords |
+
+## Score Calculation
+
+```
+Total Score = Format + Diversity + Hyde + Quality
+Max Score = 100 (80 without hyde)
+```
+
+**Rating:**
+- 80-100: Excellent
+- 60-79: Good
+- 40-59: Acceptable
+- 20-39: Poor
+- 0-19: Failed
+
+## Examples
+
+### Good Example (Score: 85/100)
+
+**Query:** `"react hooks"`
+
+**Output:**
+```
+lex: react hooks tutorial
+lex: usestate useeffect
+vec: how to use react hooks in functional components
+vec: react hooks best practices guide
+hyde: React Hooks allow you to use state and lifecycle features in functional components without writing a class.
+```
+
+**Scoring:**
+- Format: 30/30 (all prefixes valid)
+- Diversity: 28/30 (2 types, 4 expansions, good diversity)
+- Hyde: 17/20 (present, good length, no repetition)
+- Quality: 10/20 (relevant, good structure)
+
+### Bad Example (Score: 15/100)
+
+**Query:** `"auth"`
+
+**Output:**
+```
+auth is an important concept that relates to authentication.
+The answer should be in Chinese.
+The answer should be in Chinese.
+```
+
+**Scoring:**
+- Format: 0/30 (no valid prefixes)
+- Diversity: 0/30 (no valid expansions)
+- Hyde: 0/20 (N/A)
+- Quality: 15/20 (some relevance but wrong format)
+
+### Medium Example (Score: 55/100)
+
+**Query:** `"docker networking"`
+
+**Output:**
+```
+lex: docker networking
+vec: docker networking
+hyde: Docker networking is an important concept. Docker networking is used for container communication. Docker networking configuration is essential.
+```
+
+**Scoring:**
+- Format: 30/30 (valid prefixes)
+- Diversity: 10/30 (lex=vec=query, no diversity)
+- Hyde: 5/20 (too repetitive - "docker networking" 3x)
+- Quality: 10/20 (relevant but low effort)
+
+## Heuristics
+
+### Repetition Detection
+
+```python
+def word_repetition_score(text):
+    words = text.lower().split()
+    counts = Counter(words)
+    # Deduct for words appearing 3+ times (excluding stopwords)
+    stopwords = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or'}
+    repeated = sum(1 for w, c in counts.items() if c >= 3 and w not in stopwords)
+    return max(0, 5 - repeated * 2)
+```
+
+### Diversity Check (Simple)
+
+```python
+def is_diverse(a, b, min_distance=3):
+    """Check if two strings are sufficiently different."""
+    a, b = a.lower().strip(), b.lower().strip()
+    if a == b:
+        return False
+    # Simple: check if one is not a substring of the other
+    if a in b or b in a:
+        return False
+    # Check edit distance (simplified)
+    return len(set(a.split()) ^ set(b.split())) >= min_distance
+```
+
+### Query Echo Detection
+
+```python
+def echoes_query(expansion, query):
+    """Check if expansion is just echoing the query."""
+    exp = expansion.lower().strip()
+    q = query.lower().strip()
+    return exp == q or exp in q or q in exp
+```
+
+## Training Data Requirements
+
+1. **EOM tokens**: Ensure training examples end with proper end-of-message tokens
+2. **Diverse examples**: Include varied query types (short, long, technical, casual)
+3. **Quality hyde**: Hyde passages should be informative, not template-y
+4. **No repetition**: Avoid "This is important. This is very important." patterns

+ 330 - 50
finetune/evaluate_model.py

@@ -5,16 +5,19 @@
 #     "peft>=0.7.0",
 #     "torch",
 #     "huggingface_hub",
+#     "accelerate",
 # ]
 # ///
 """
 Evaluate QMD query expansion model quality.
 
-Generates expansions for test queries and outputs results for review.
+See SCORING.md for detailed scoring criteria.
 """
 
 import json
+import re
 import torch
+from collections import Counter
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from peft import PeftModel
 
@@ -26,28 +29,23 @@ TEST_QUERIES = [
     "docker compose networking",
     "git rebase vs merge",
     "react useEffect cleanup",
-
     # Short/ambiguous queries
     "auth",
     "config",
     "setup",
     "api",
-
     # Personal notes / journals style
     "meeting notes project kickoff",
     "ideas for new feature",
     "todo list app architecture",
-
     # Research / learning
     "what is dependency injection",
     "difference between sql and nosql",
     "kubernetes vs docker swarm",
-
     # Error/debugging
     "connection timeout error",
     "memory leak debugging",
     "cors error fix",
-
     # Complex queries
     "how to implement caching with redis in nodejs",
     "best practices for api rate limiting",
@@ -58,6 +56,237 @@ PROMPT_TEMPLATE = """Expand this search query:
 
 {query}"""
 
+STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
+
+
+def parse_expansion(text: str) -> dict:
+    """Parse expansion into structured format."""
+    lines = text.strip().split("\n")
+    result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
+
+    for line in lines:
+        line = line.strip()
+        if not line:
+            continue
+        if line.startswith("lex:"):
+            result["lex"].append(line[4:].strip())
+        elif line.startswith("vec:"):
+            result["vec"].append(line[4:].strip())
+        elif line.startswith("hyde:"):
+            result["hyde"].append(line[5:].strip())
+        else:
+            result["invalid"].append(line)
+
+    return result
+
+
+def edit_distance_simple(a: str, b: str) -> int:
+    """Simple word-level edit distance."""
+    words_a = set(a.lower().split())
+    words_b = set(b.lower().split())
+    return len(words_a ^ words_b)  # Symmetric difference
+
+
+def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
+    """Check if two strings are sufficiently different."""
+    a, b = a.lower().strip(), b.lower().strip()
+    if a == b:
+        return False
+    if a in b or b in a:
+        return False
+    return edit_distance_simple(a, b) >= min_distance
+
+
+def echoes_query(expansion: str, query: str) -> bool:
+    """Check if expansion is just echoing the query."""
+    exp = expansion.lower().strip()
+    q = query.lower().strip()
+    # Exact match or very close
+    if exp == q:
+        return True
+    # Query is contained in expansion with little else
+    if q in exp and len(exp) < len(q) + 10:
+        return True
+    return False
+
+
+def word_repetition_penalty(text: str) -> int:
+    """Count penalty for repeated words (excluding stopwords)."""
+    words = re.findall(r'\b\w+\b', text.lower())
+    counts = Counter(words)
+    penalty = 0
+    for word, count in counts.items():
+        if count >= 3 and word not in STOPWORDS and len(word) > 2:
+            penalty += (count - 2) * 2
+    return penalty
+
+
+def score_expansion(query: str, expansion: str) -> dict:
+    """
+    Score an expansion based on SCORING.md criteria.
+
+    Returns dict with score breakdown and total (0-100).
+    """
+    parsed = parse_expansion(expansion)
+    scores = {
+        "format": 0,
+        "diversity": 0,
+        "hyde": 0,
+        "quality": 0,
+        "deductions": [],
+    }
+
+    # === FORMAT (0-30) ===
+    format_score = 0
+
+    # Has at least one lex: line (+10)
+    if parsed["lex"]:
+        format_score += 10
+    else:
+        scores["deductions"].append("missing lex: (-10)")
+
+    # Has at least one vec: line (+10)
+    if parsed["vec"]:
+        format_score += 10
+    else:
+        scores["deductions"].append("missing vec: (-10)")
+
+    # All lines have valid prefix (+10, -5 per invalid)
+    if not parsed["invalid"]:
+        format_score += 10
+    else:
+        invalid_penalty = min(10, len(parsed["invalid"]) * 5)
+        format_score += (10 - invalid_penalty)
+        scores["deductions"].append(f"{len(parsed['invalid'])} invalid lines (-{invalid_penalty})")
+
+    scores["format"] = max(0, format_score)
+
+    # === DIVERSITY (0-30) ===
+    diversity_score = 0
+
+    # 2+ different types present (+10)
+    types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
+    if types_present >= 2:
+        diversity_score += 10
+    else:
+        scores["deductions"].append("only one type present (-10)")
+
+    # 2+ total expansions (+5)
+    total_expansions = len(parsed["lex"]) + len(parsed["vec"])
+    if total_expansions >= 2:
+        diversity_score += 5
+    else:
+        scores["deductions"].append("fewer than 2 expansions (-5)")
+
+    # Multiple lex: lines are diverse (+5, -2 per duplicate pair)
+    lex_diverse_score = 5
+    for i, a in enumerate(parsed["lex"]):
+        for b in parsed["lex"][i+1:]:
+            if not is_diverse(a, b, min_distance=2):
+                lex_diverse_score -= 2
+                scores["deductions"].append(f"lex duplicates: '{a[:20]}...' ~ '{b[:20]}...'")
+    diversity_score += max(0, lex_diverse_score)
+
+    # Multiple vec: lines are diverse (+5, -2 per duplicate pair)
+    vec_diverse_score = 5
+    for i, a in enumerate(parsed["vec"]):
+        for b in parsed["vec"][i+1:]:
+            if not is_diverse(a, b, min_distance=3):
+                vec_diverse_score -= 2
+                scores["deductions"].append(f"vec duplicates: '{a[:20]}...' ~ '{b[:20]}...'")
+    diversity_score += max(0, vec_diverse_score)
+
+    # lex/vec not identical to original query (+5, -5 per echo)
+    echo_score = 5
+    for exp in parsed["lex"] + parsed["vec"]:
+        if echoes_query(exp, query):
+            echo_score -= 5
+            scores["deductions"].append(f"echoes query: '{exp[:30]}...'")
+    diversity_score += max(0, echo_score)
+
+    scores["diversity"] = max(0, diversity_score)
+
+    # === HYDE QUALITY (0-20, optional bonus) ===
+    hyde_score = 0
+
+    if parsed["hyde"]:
+        hyde_text = parsed["hyde"][0]  # Only first hyde counts
+
+        # Hyde present and well-formed (+5)
+        hyde_score += 5
+
+        # Hyde is concise: 50-200 chars (+5)
+        hyde_len = len(hyde_text)
+        if 50 <= hyde_len <= 200:
+            hyde_score += 5
+        elif hyde_len < 50:
+            hyde_score += 2
+            scores["deductions"].append(f"hyde too short ({hyde_len} chars)")
+        else:
+            scores["deductions"].append(f"hyde too long ({hyde_len} chars)")
+
+        # Hyde has no newlines (+5)
+        if "\n" not in hyde_text:
+            hyde_score += 5
+        else:
+            scores["deductions"].append("hyde contains newlines")
+
+        # Hyde has no excessive repetition (+5)
+        rep_penalty = word_repetition_penalty(hyde_text)
+        if rep_penalty == 0:
+            hyde_score += 5
+        else:
+            hyde_score += max(0, 5 - rep_penalty)
+            scores["deductions"].append(f"hyde repetition penalty (-{min(5, rep_penalty)})")
+
+    scores["hyde"] = hyde_score
+
+    # === QUALITY (0-20) ===
+    quality_score = 10  # Base relevance (assume relevant unless obvious garbage)
+
+    # Lex lines should be keyword-focused (shorter than vec on average)
+    if parsed["lex"] and parsed["vec"]:
+        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
+        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
+        if avg_lex <= avg_vec:
+            quality_score += 5
+        else:
+            scores["deductions"].append("lex longer than vec (should be keywords)")
+    else:
+        quality_score += 2  # Partial credit
+
+    # Vec lines should be natural language (contain spaces, longer)
+    if parsed["vec"]:
+        vec_natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
+        if vec_natural == len(parsed["vec"]):
+            quality_score += 5
+        else:
+            quality_score += 2
+            scores["deductions"].append("some vec lines too short/keyword-like")
+
+    scores["quality"] = quality_score
+
+    # === TOTAL ===
+    scores["total"] = scores["format"] + scores["diversity"] + scores["hyde"] + scores["quality"]
+    scores["max_possible"] = 100 if parsed["hyde"] else 80
+    scores["percentage"] = scores["total"] / scores["max_possible"] * 100
+
+    # Rating
+    pct = scores["percentage"]
+    if pct >= 80:
+        scores["rating"] = "Excellent"
+    elif pct >= 60:
+        scores["rating"] = "Good"
+    elif pct >= 40:
+        scores["rating"] = "Acceptable"
+    elif pct >= 20:
+        scores["rating"] = "Poor"
+    else:
+        scores["rating"] = "Failed"
+
+    scores["parsed"] = parsed
+    return scores
+
 
 def load_model(model_name: str, base_model: str = "Qwen/Qwen3-0.6B"):
     """Load the finetuned model."""
@@ -96,45 +325,82 @@ def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200)
             eos_token_id=tokenizer.eos_token_id,
         )
 
-    # Decode and extract just the generated part
     full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Remove the prompt to get just the expansion
-    if "Output:" in full_output:
-        expansion = full_output.split("Output:")[-1].strip()
-    else:
-        expansion = full_output[len(prompt):].strip()
+    expansion = full_output[len(prompt):].strip()
 
     return expansion
 
 
-def evaluate_expansion(query: str, expansion: str) -> dict:
-    """Basic automatic evaluation metrics."""
-    lines = expansion.strip().split("\n")
-
-    has_lex = any(l.strip().startswith("lex:") for l in lines)
-    has_vec = any(l.strip().startswith("vec:") for l in lines)
-    has_hyde = any(l.strip().startswith("hyde:") for l in lines)
-
-    # Count valid lines
-    valid_lines = sum(1 for l in lines if l.strip().startswith(("lex:", "vec:", "hyde:")))
-
-    # Check for repetition
-    contents = []
-    for l in lines:
-        if ":" in l:
-            contents.append(l.split(":", 1)[1].strip().lower())
-    unique_contents = len(set(contents))
-
-    return {
-        "has_lex": has_lex,
-        "has_vec": has_vec,
-        "has_hyde": has_hyde,
-        "valid_lines": valid_lines,
-        "total_lines": len(lines),
-        "unique_contents": unique_contents,
-        "format_score": (has_lex + has_vec + has_hyde) / 3,
-    }
+def print_score_breakdown(scores: dict):
+    """Pretty print score breakdown."""
+    print(f"  Score: {scores['total']}/{scores['max_possible']} ({scores['percentage']:.0f}%) - {scores['rating']}")
+    print(f"    Format:    {scores['format']}/30")
+    print(f"    Diversity: {scores['diversity']}/30")
+    print(f"    Hyde:      {scores['hyde']}/20")
+    print(f"    Quality:   {scores['quality']}/20")
+    if scores["deductions"]:
+        print(f"  Deductions:")
+        for d in scores["deductions"][:5]:  # Show top 5
+            print(f"    - {d}")
+        if len(scores["deductions"]) > 5:
+            print(f"    ... and {len(scores['deductions']) - 5} more")
+
+
+def run_examples():
+    """Run good and bad examples to demonstrate scoring."""
+    print("=" * 70)
+    print("SCORING EXAMPLES")
+    print("=" * 70)
+
+    # Good example
+    good_expansion = """lex: react hooks tutorial
+lex: usestate useeffect
+vec: how to use react hooks in functional components
+vec: react hooks best practices guide
+hyde: React Hooks allow you to use state and lifecycle features in functional components without writing a class."""
+
+    print("\n[GOOD EXAMPLE]")
+    print(f"Query: react hooks")
+    print(f"Output:\n{good_expansion}")
+    scores = score_expansion("react hooks", good_expansion)
+    print_score_breakdown(scores)
+
+    # Bad example
+    bad_expansion = """auth is an important concept that relates to authentication.
+The answer should be in Chinese.
+The answer should be in Chinese."""
+
+    print("\n[BAD EXAMPLE]")
+    print(f"Query: auth")
+    print(f"Output:\n{bad_expansion}")
+    scores = score_expansion("auth", bad_expansion)
+    print_score_breakdown(scores)
+
+    # Medium example - repetitive hyde
+    medium_expansion = """lex: docker networking
+vec: docker networking
+hyde: Docker networking is an important concept. Docker networking is used for container communication. Docker networking configuration is essential."""
+
+    print("\n[MEDIUM EXAMPLE - Repetitive]")
+    print(f"Query: docker networking")
+    print(f"Output:\n{medium_expansion}")
+    scores = score_expansion("docker networking", medium_expansion)
+    print_score_breakdown(scores)
+
+    # Medium example - echoes query
+    echo_expansion = """lex: auth
+lex: authentication
+vec: auth
+vec: authentication configuration
+hyde: Authentication is the process of verifying identity."""
+
+    print("\n[MEDIUM EXAMPLE - Echoes Query]")
+    print(f"Query: auth")
+    print(f"Output:\n{echo_expansion}")
+    scores = score_expansion("auth", echo_expansion)
+    print_score_breakdown(scores)
+
+    print("\n" + "=" * 70)
 
 
 def main():
@@ -147,8 +413,14 @@ def main():
     parser.add_argument("--output", default="evaluation_results.json",
                         help="Output file for results")
     parser.add_argument("--queries", type=str, help="Custom queries file (one per line)")
+    parser.add_argument("--examples", action="store_true", help="Run scoring examples only")
     args = parser.parse_args()
 
+    # Run examples if requested
+    if args.examples:
+        run_examples()
+        return
+
     # Load custom queries if provided
     queries = TEST_QUERIES
     if args.queries:
@@ -169,19 +441,18 @@ def main():
         print("-" * 50)
 
         expansion = generate_expansion(model, tokenizer, query)
-        metrics = evaluate_expansion(query, expansion)
+        scores = score_expansion(query, expansion)
 
         print(expansion)
-        print(f"\n  Format: {'✓' if metrics['format_score'] == 1.0 else '⚠'} "
-              f"(lex:{metrics['has_lex']}, vec:{metrics['has_vec']}, hyde:{metrics['has_hyde']})")
-        print(f"  Lines: {metrics['valid_lines']}/{metrics['total_lines']} valid, "
-              f"{metrics['unique_contents']} unique")
+        print()
+        print_score_breakdown(scores)
         print()
 
         results.append({
             "query": query,
             "expansion": expansion,
-            "metrics": metrics,
+            "scores": {k: v for k, v in scores.items() if k != "parsed"},
+            "parsed": scores["parsed"],
         })
 
     # Summary
@@ -189,12 +460,21 @@ def main():
     print("SUMMARY")
     print(f"{'='*70}")
 
-    avg_format = sum(r["metrics"]["format_score"] for r in results) / len(results)
-    full_format = sum(1 for r in results if r["metrics"]["format_score"] == 1.0)
+    avg_score = sum(r["scores"]["percentage"] for r in results) / len(results)
+    excellent = sum(1 for r in results if r["scores"]["rating"] == "Excellent")
+    good = sum(1 for r in results if r["scores"]["rating"] == "Good")
+    acceptable = sum(1 for r in results if r["scores"]["rating"] == "Acceptable")
+    poor = sum(1 for r in results if r["scores"]["rating"] == "Poor")
+    failed = sum(1 for r in results if r["scores"]["rating"] == "Failed")
 
     print(f"  Total queries: {len(results)}")
-    print(f"  Average format score: {avg_format:.2%}")
-    print(f"  Full format compliance: {full_format}/{len(results)} ({full_format/len(results):.0%})")
+    print(f"  Average score: {avg_score:.1f}%")
+    print(f"  Ratings:")
+    print(f"    Excellent: {excellent}")
+    print(f"    Good:      {good}")
+    print(f"    Acceptable: {acceptable}")
+    print(f"    Poor:      {poor}")
+    print(f"    Failed:    {failed}")
 
     # Save results
     with open(args.output, "w") as f: