Add named entity extraction to GRPO reward function

Key changes:
- Extract named entities (acronyms, proper nouns, technical terms)
- Heavy penalty (-30) when lex queries miss named entities
- Penalty (-15) for generic filler phrases like "find information about"
- Compound entity detection (TDS motorsports -> both words)
- Update GRPO config with KL regularization (beta=0.04)
- Lower learning rate (5e-7) and add max_steps (200)

Test results:
- "who is TDS motorsports" good: 1.00, bad: 0.30 (was 0.75)
- "how to use React hooks" good: 0.87, bad: 0.45 (was 0.75)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke, 4 months ago
parent
commit
6062dc769f
4 changed files with 372 additions and 101 deletions
  1. 100 75
      finetune/README.md
  2. 116 12
      finetune/SCORING.md
  3. 12 8
      finetune/configs/grpo_v4.yaml
  4. 144 6
      finetune/rl.py
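
Note on the entity heuristics named in the commit message: the sketch below is a condensed restatement of the `extract_named_entities` function added in this commit (the `elif` chain is folded into one boolean expression, which should behave the same because every branch performs the same action). It is meant only to show what the heuristics pick up on the two test queries; the third query is an illustrative input, not one taken from the commit.

```python
KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
                      'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
                      'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}


def extract_named_entities(query: str) -> set:
    """Condensed form of the heuristics in the diff (same rules, one expression)."""
    entities = set()
    prev_was_entity = False
    for i, word in enumerate(query.split()):
        # Strip surrounding punctuation but keep internal characters (node.js, C++)
        clean = word.strip('.,!?:;()[]"\'')
        if not clean:
            prev_was_entity = False
            continue
        lower = clean.lower()
        not_stopword = lower not in KEY_TERM_STOPWORDS
        is_entity = (
            (clean.isupper() and len(clean) >= 2)                    # acronym: TDS
            or (i > 0 and clean[0].isupper() and not_stopword)       # proper noun, not first word
            or (any(c in clean for c in '.+-#@') and len(clean) >= 2)  # node.js, C++
            or (len(clean) > 1 and clean[0].isupper()
                and any(c.isupper() for c in clean[1:]))             # CamelCase
            or (prev_was_entity and not_stopword)                    # compound: TDS motorsports
        )
        if is_entity:
            entities.add(lower)
        prev_was_entity = is_entity
    return entities


print(extract_named_entities("who is TDS motorsports"))
print(extract_named_entities("how to use React hooks"))
print(extract_named_entities("React hooks"))
```

The first two calls yield `tds`/`motorsports` and `react`/`hooks`. The third returns an empty set, because a capitalized first word is deliberately skipped and `React` matches none of the other rules; such a query then falls into the "no entities" branch of the scorer.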

+ 100 - 75
finetune/README.md

@@ -20,92 +20,117 @@ hyde: To configure authentication, set the AUTH_SECRET environment variable and
 
 | Type | Purpose | Count |
 |------|---------|-------|
-| `lex` | BM25 keyword variations | 1-3 |
-| `vec` | Semantic reformulations | 1-3 |
-| `hyde` | Hypothetical document passage | 0-1 |
+| `lex:` | BM25 keyword variations (short, keyword-focused) | 1-3 |
+| `vec:` | Semantic reformulations (natural language) | 1-3 |
+| `hyde:` | Hypothetical document passage (50-150 chars) | 0-1 |
 
 ## Trained Models
 
-| Model | HuggingFace | Format Compliance | Status |
-|-------|-------------|-------------------|--------|
-| **Qwen3-0.6B (SFT)** | [tobil/qmd-query-expansion-0.6B](https://huggingface.co/tobil/qmd-query-expansion-0.6B) | **95%** | Recommended |
-| Qwen3-1.7B v2 (SFT) | [tobil/qmd-query-expansion-1.7B-v2](https://huggingface.co/tobil/qmd-query-expansion-1.7B-v2) | TBD | Completed |
-| Qwen3-0.6B (GRPO) | [tobil/qmd-query-expansion-0.6B-grpo](https://huggingface.co/tobil/qmd-query-expansion-0.6B-grpo) | 0% | Failed - lost formatting |
-| Qwen3-1.7B v1 (SFT) | [tobil/qmd-query-expansion-1.7B](https://huggingface.co/tobil/qmd-query-expansion-1.7B) | 0% | Training issues |
-| Qwen3-0.6B (baseline) | - | 0% | Untrained |
+| Model | HuggingFace | Score | Status |
+|-------|-------------|-------|--------|
+| **Qwen3-0.6B v4 (SFT)** | [tobil/qmd-query-expansion-0.6B-v4](https://huggingface.co/tobil/qmd-query-expansion-0.6B-v4) | **98.8%** | Recommended |
+| Qwen3-0.6B v4 (GRPO) | [tobil/qmd-query-expansion-0.6B-v4-grpo](https://huggingface.co/tobil/qmd-query-expansion-0.6B-v4-grpo) | 0% | Failed - catastrophic drift |
 
-**Note:** GRPO (RL) training caused catastrophic forgetting - the model lost all learned formatting.
+## Prompt Format
 
-## Training Dataset
+The models use **Qwen3 chat template** with `/no_think` to disable thinking mode.
+
+### Inference (Python)
+
+```python
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
+
+# CRITICAL: Use /no_think to disable Qwen3's thinking mode
+messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
+
+prompt = tokenizer.apply_chat_template(
+    messages,
+    tokenize=False,
+    add_generation_prompt=True
+)
+
+# Generate and decode
+output = tokenizer.decode(tokens, skip_special_tokens=True)
+
+# Extract assistant response (skip_special_tokens converts to "user\n...\nassistant\n...")
+if "\nassistant\n" in output:
+    expansion = output.split("\nassistant\n")[-1].strip()
+```
+
+### Raw Format
+
+```
+<|im_start|>user
+/no_think Expand this search query: auth<|im_end|>
+<|im_start|>assistant
+lex: authentication configuration
+lex: auth settings
+vec: how to configure authentication
+vec: authentication setup guide
+hyde: To configure authentication, set AUTH_SECRET in your environment.<|im_end|>
+```
 
-- **Dataset**: [tobil/qmd-query-expansion-train](https://huggingface.co/datasets/tobil/qmd-query-expansion-train)
-- **Source**: Transformed from [s-emanuilov/query-expansion](https://huggingface.co/datasets/s-emanuilov/query-expansion) (CC BY 4.0)
-- **Size**: 5,157 examples (train: 4,641, eval: 516)
-- **Format**: Chat messages with user query and assistant response in lex/vec/hyde format
+See `PROMPT_FORMAT.md` for complete specification.
 
 ## Directory Structure
 
 ```
 finetune/
-├── README.md                 # This file
-├── DATASETS.md               # Dataset research findings
-├── TRAINING_JOBS.md          # HuggingFace Jobs tracking
-├── generate_data_offline.py  # Transform s-emanuilov dataset to QMD format
-├── prepare_data.py           # Upload to HuggingFace Hub
-├── train_0.6B.py             # Training script for 0.6B model
-├── train_1.7B.py             # Training script for 1.7B model
-├── train_grpo.py             # GRPO RL training (optional)
-├── evaluate_model.py         # Evaluate finetuned models
-├── evaluate_baseline.py      # Evaluate base models
-├── data/
-│   ├── qmd_expansion.jsonl   # Generated training data
-│   └── train/                # Prepared chat format
-└── evaluation_*.json         # Evaluation results
+├── train.py              # SFT training (uses YAML config)
+├── rl.py                 # GRPO/RL training (uses YAML config)
+├── evaluate_model.py     # Evaluate finetuned models
+├── tui.py                # Interactive testing interface
+├── configs/
+│   ├── sft_v4.yaml       # SFT training config
+│   └── grpo_v4.yaml      # GRPO training config
+├── dataset/
+│   ├── prepare_data.py   # Prepare training data
+│   ├── clean_data.py     # Data quality improvements
+│   └── generate_data*.py # Generate from source datasets
+├── PROMPT_FORMAT.md      # Prompt format specification
+├── SCORING.md            # Scoring criteria
+└── data/
+    └── train/            # Prepared training data
 ```
 
 ## Quick Start
 
-### 1. Generate Training Data
+### 1. Prepare Training Data
 
 ```bash
-# Transform s-emanuilov dataset to QMD format (no API needed)
-uv run generate_data_offline.py
+cd dataset
+uv run prepare_data.py --add-short 5
 ```
 
-### 2. Prepare and Upload Dataset
+### 2. Train with YAML Config
 
 ```bash
-# Convert to chat format and upload to HuggingFace Hub
-uv run prepare_data.py
-```
+# Local training
+uv run train.py --config configs/sft_v4.yaml
 
-### 3. Train on HuggingFace Jobs
-
-```bash
-# Train Qwen3-0.6B (recommended)
-hf jobs uv run --flavor a10g-large --timeout 3h --secrets HF_TOKEN \
-  "https://huggingface.co/tobil/qmd-training-scripts/resolve/main/train_0.6B.py"
+# Or on HuggingFace Jobs
+hf jobs uv run --flavor a10g-large --timeout 2h --secrets HF_TOKEN \
+  "https://huggingface.co/datasets/tobil/qmd-query-expansion-train-v2/resolve/main/train_sft_v4.py"
 ```
 
-### 4. Evaluate
+### 3. Evaluate
 
 ```bash
-# Evaluate finetuned model
-uv run evaluate_model.py --model tobil/qmd-query-expansion-0.6B --base-model Qwen/Qwen3-0.6B
-
-# Compare to baseline
-uv run evaluate_baseline.py --model Qwen/Qwen3-0.6B --num-queries 10
+uv run evaluate_model.py --model tobil/qmd-query-expansion-0.6B-v4
 ```
 
-### 5. Export to GGUF
+### 4. Interactive Testing
 
 ```bash
-# Convert to GGUF for node-llama-cpp (TODO)
-uv run export_gguf.py --model tobil/qmd-query-expansion-0.6B --quantization Q8_0
+uv run tui.py
 ```
 
 ## Training Configuration
 
+Default SFT config (`configs/sft_v4.yaml`):
+
 | Parameter | Value |
 |-----------|-------|
 | Method | LoRA (rank 16, alpha 32) |
@@ -115,37 +140,37 @@ uv run export_gguf.py --model tobil/qmd-query-expansion-0.6B --quantization Q8_0
 | Max Seq Length | 512 |
 | Target Modules | q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj |
 
-## Prompt Format
-
-The models are trained on this simple prompt format:
-
-```
-Expand this search query:
+## Training Dataset
 
-{query}
-```
+- **Dataset**: [tobil/qmd-query-expansion-train-v2](https://huggingface.co/datasets/tobil/qmd-query-expansion-train-v2)
+- **Size**: 6,180 examples (26.5% short queries)
+- **Format**: Qwen3 chat messages with `/no_think` directive
 
-The model responds with lex/vec/hyde lines directly.
+Key improvements in v2:
+- Short query examples with proper expansions
+- Hyde passages truncated to 150 chars
+- Key term preservation in lex lines
 
 ## Evaluation Results
 
-### 0.6B Finetuned Model (95% format compliance)
+### SFT v4 (98.8% average score)
+
+All 21 test queries rated "Excellent":
 
-Sample outputs:
+| Query | Score | Rating |
+|-------|-------|--------|
+| `how to configure authentication` | 99% | Excellent |
+| `auth` | 95% | Excellent |
+| `git rebase vs merge` | 100% | Excellent |
+| `react useEffect cleanup` | 100% | Excellent |
 
-| Query | Output |
-|-------|--------|
-| `how to configure authentication` | lex: steps for setting up authentication<br>vec: steps for setting up authentication in cloud services<br>hyde: The process of configure authentication... |
-| `kubernetes vs docker swarm` | lex: kubernetes and docker swarm<br>vec: kubernetes vs docker swarm<br>hyde: Kubernetes vs docker swarm is an important concept... |
-| `cors error fix` | lex: how to fix cors<br>vec: how to fix cors issues in web apps<br>hyde: The topic of cors error fix guide... |
+### GRPO v4 (0% - Failed)
 
-### Baseline Model (0% format compliance)
+The GRPO training caused catastrophic drift. The model now generates verbose explanations instead of structured `lex:/vec:/hyde:` format.
 
-The untrained model generates random prose, code blocks, or repetitive text with no understanding of the lex/vec/hyde format.
+**Root cause**: Reward function didn't enforce format strictly enough. The model learned that verbose explanations could score higher than concise structured output.
 
-## Future Work
+## Known Issues
 
-- [ ] Export to GGUF for local inference
-- [ ] Integrate into QMD as default query expansion model
-- [ ] GRPO training for improved diversity (optional)
-- [ ] Fix 1.7B training issues
+- **GRPO drift**: RL training causes the model to lose SFT-learned formatting. Needs stricter format enforcement in reward function.
+- **Key term preservation**: Some lex lines still too generic (missing query key terms)
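
Note on the output format described in the README diff: a minimal sketch of splitting a model response into `lex:`/`vec:`/`hyde:` lines and checking the line counts from the table (1-3, 1-3, 0-1). The function names `parse_expansion_sketch` and `counts_ok` are hypothetical; the repository has its own `parse_expansion` in `rl.py`, of which only the first lines appear in this diff, so this is not that implementation.

```python
def parse_expansion_sketch(text: str) -> dict:
    """Group response lines by prefix; anything else is collected as invalid."""
    result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    for raw in text.strip().split("\n"):
        line = raw.strip()
        if not line:
            continue
        for prefix in ("lex", "vec", "hyde"):
            if line.startswith(prefix + ":"):
                result[prefix].append(line[len(prefix) + 1:].strip())
                break
        else:
            # No known prefix matched (e.g. verbose prose after GRPO drift)
            result["invalid"].append(line)
    return result


def counts_ok(parsed: dict) -> bool:
    """Check the per-type counts from the README table."""
    return (
        1 <= len(parsed["lex"]) <= 3
        and 1 <= len(parsed["vec"]) <= 3
        and len(parsed["hyde"]) <= 1
        and not parsed["invalid"]
    )


sample = (
    "lex: authentication configuration\n"
    "lex: auth settings\n"
    "vec: how to configure authentication\n"
    "vec: authentication setup guide\n"
    "hyde: To configure authentication, set AUTH_SECRET in your environment."
)
parsed = parse_expansion_sketch(sample)
print(counts_ok(parsed))
print(counts_ok(parse_expansion_sketch("auth is important for security")))
```

The structured sample passes the count check, while a prose answer of the kind the failed GRPO model produced ends up entirely in `invalid`.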

+ 116 - 12
finetune/SCORING.md

@@ -56,29 +56,63 @@ hyde: Authentication can be configured by setting the AUTH_SECRET environment va
 
 | Criterion | Points | Deduction |
 |-----------|--------|-----------|
-| Base relevance | +10 | Subjective |
-| Lex lines preserve key terms from query | +5 | -5 if lex is generic (e.g., "features" without "shop.app") |
+| Base relevance | +5 | Subjective |
+| Lex lines preserve key terms from query | +5 | -5 if lex is generic |
 | Lex lines are keyword-focused (shorter) | +5 | -2 if lex is longer than vec |
 | Vec lines are natural language (complete phrases) | +5 | -2 if vec is just keywords |
 
-**Key Term Preservation Rule**: `lex:` lines MUST contain at least one significant word from the original query (excluding stopwords like "what", "is", "how", "the", etc.).
+### 5. Named Entity Preservation (0-20 points, CRITICAL)
 
-Bad: `what is shop.app` → `lex: features and benefits`
-Good: `what is shop.app` → `lex: shop.app features`
+Named entities are proper nouns, brand names, technical terms, and acronyms that MUST appear in lex queries. This prevents generic expansions that lose the specific topic.
+
+| Criterion | Points | Deduction |
+|-----------|--------|-----------|
+| All lex lines contain at least one entity | +15 | - |
+| Some lex lines contain entities | +5 | - |
+| NO lex lines contain entities | - | **-30 HEAVY PENALTY** |
+| Generic filler phrases in lex | - | -15 per phrase |
+| Entities also in vec lines | +5 | - |
+
+**Named Entity Detection:**
+- All-caps acronyms: `TDS`, `API`, `GPU`, `AWS`
+- Capitalized proper nouns: `React`, `Docker`, `Kubernetes`
+- Technical terms: `node.js`, `C++`, `.NET`
+- CamelCase: `JavaScript`, `TypeScript`
+- Compound names: `TDS motorsports` → both words are entities
+
+**Generic Filler Phrases (BANNED in lex):**
+- "find information about"
+- "search for", "look up"
+- "get information", "learn about"
+- "details about", "guide to"
+
+**Examples:**
+
+| Query | Bad Lex (Score: 0.30) | Good Lex (Score: 1.00) |
+|-------|----------------------|------------------------|
+| `who is TDS motorsports` | `lex: find information about` | `lex: TDS motorsports history` |
+| | `lex: company details` | `lex: TDS motorsports founders` |
+| `how to use React hooks` | `lex: programming tutorial` | `lex: React hooks tutorial` |
+| | `lex: how to code` | `lex: useEffect useState hooks` |
+
+**Key Rule**: If a query mentions a specific entity (brand, product, technology), EVERY lex line should include that entity or a direct variation of it.
 
 ## Score Calculation
 
 ```
-Total Score = Format + Diversity + Hyde + Quality
-Max Score = 100 (80 without hyde)
+Total Score = Format + Diversity + Hyde + Quality + Entity
+Max Score = 120 (100 without hyde)
+Normalized = Total / Max (0.0 - 1.0)
 ```
 
 **Rating:**
-- 80-100: Excellent
-- 60-79: Good
-- 40-59: Acceptable
-- 20-39: Poor
-- 0-19: Failed
+- 0.80-1.00: Excellent
+- 0.60-0.79: Good
+- 0.40-0.59: Acceptable
+- 0.20-0.39: Poor
+- 0.00-0.19: Failed
+
+**Note:** Entity score can go negative, heavily penalizing outputs that miss named entities.
 
 ## Examples
 
@@ -174,6 +208,76 @@ def echoes_query(expansion, query):
     return exp == q or exp in q or q in exp
 ```
 
+### Named Entity Extraction
+
+```python
+KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
+                      'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
+                      'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}
+
+def extract_named_entities(query: str) -> set:
+    """Extract named entities using simple heuristics."""
+    entities = set()
+    words = query.split()
+    prev_was_entity = False
+
+    for i, word in enumerate(words):
+        clean = word.strip('.,!?:;()[]"\'')
+        if not clean:
+            prev_was_entity = False
+            continue
+
+        is_entity = False
+
+        # All-caps acronyms: TDS, API, GPU
+        if clean.isupper() and len(clean) >= 2:
+            entities.add(clean.lower())
+            is_entity = True
+        # Capitalized proper nouns (not first word)
+        elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
+            entities.add(clean.lower())
+            is_entity = True
+        # Technical terms: node.js, C++
+        elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
+            entities.add(clean.lower())
+            is_entity = True
+        # CamelCase: JavaScript
+        elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
+            entities.add(clean.lower())
+            is_entity = True
+        # Word following an entity (compound names: TDS motorsports)
+        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
+            entities.add(clean.lower())
+            is_entity = True
+
+        prev_was_entity = is_entity
+
+    return entities
+```
+
+### Generic Phrase Detection
+
+```python
+GENERIC_LEX_PHRASES = {
+    'find information about', 'search for', 'look up', 'get information',
+    'learn about', 'information on', 'details about', 'find out about',
+    'what is', 'how to', 'guide to', 'help with'
+}
+
+def lex_is_generic(lex_line: str) -> bool:
+    """Check if lex line is a useless generic filler."""
+    lex_lower = lex_line.lower().strip()
+    for phrase in GENERIC_LEX_PHRASES:
+        if phrase in lex_lower:
+            # Check if there's specific content beyond the generic phrase
+            remaining = lex_lower
+            for word in phrase.split():
+                remaining = remaining.replace(word, '', 1).strip()
+            if len(remaining) < 3:  # Nothing specific left
+                return True
+    return False
+```
+
 ## Training Data Requirements
 
 1. **EOM tokens**: Ensure training examples end with proper end-of-message tokens
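
Note on the entity points in the SCORING.md diff: the arithmetic can be traced with a small sketch that takes counts as inputs instead of parsing text. `entity_component` and `normalize` are hypothetical names; the branch structure follows the `score_expansion` changes in `rl.py` from this commit, and the other score components (format, diversity, hyde, quality) are left out.

```python
def entity_component(has_entities: bool, n_lex: int, lex_with_entities: int,
                     generic_count: int, vec_with_entities: int) -> int:
    """Entity part of the score, following the branches in the rl.py diff."""
    if not has_entities:
        return 10          # query has no entities: base score
    if n_lex == 0:
        return 0           # entities present but no lex lines to check
    score = 0
    if lex_with_entities == n_lex:
        score += 15        # all lex lines keep an entity
    elif lex_with_entities > 0:
        score += 5         # some lex lines keep an entity
    else:
        score -= 30        # no lex line keeps an entity
    score -= generic_count * 15   # -15 per generic filler lex line
    if vec_with_entities > 0:
        score += 5         # bonus when a vec line also keeps an entity
    return score


def normalize(total: float, has_hyde: bool) -> float:
    """Divide by the maximum and clamp to [0.0, 1.0], as in the diff."""
    max_possible = 120 if has_hyde else 100
    return max(0.0, min(1.0, total / max_possible))


# "who is TDS motorsports", good output: both lex lines and the vec line keep the entity
print(entity_component(True, 2, 2, 0, 1))
# Bad output: no lex line keeps the entity, one lex line is pure filler
print(entity_component(True, 2, 0, 1, 0))
```

Under these inputs the good output earns +20 and the bad output -45 from this component alone. Because the total is clamped, a large negative entity score can push the normalized result to 0.0 but not below it.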

+ 12 - 8
finetune/configs/grpo_v4.yaml

@@ -1,5 +1,10 @@
-# GRPO Training Config for QMD Query Expansion v4
+# GRPO Training Config for QMD Query Expansion v5
 # Uses SFT model as starting point with RL fine-tuning
+#
+# Key learnings from v4 failure (0% score, catastrophic drift):
+# 1. beta=0.0 (default) allows unlimited drift from SFT checkpoint
+# 2. Format-based rewards need KL regularization unlike math reasoning
+# 3. Reward function must give 0 for wrong format (now implemented)
 
 model:
   sft: "tobil/qmd-query-expansion-0.6B-v4"  # Starting point
@@ -9,18 +14,21 @@ model:
 dataset:
   name: "tobil/qmd-query-expansion-train-v2"
   prompt_field: "messages"
-  max_samples: 2000
+  max_samples: 1000  # Reduced to prevent overfitting
 
 training:
   epochs: 1
   batch_size: 2
   gradient_accumulation_steps: 8
-  learning_rate: 1e-6  # Very low for RL stability
+  learning_rate: 5e-7  # Even lower for stability
   max_grad_norm: 0.5
+  max_steps: 200  # Limit total steps to prevent drift
 
 grpo:
   num_generations: 4
   max_completion_length: 200
+  beta: 0.04  # KL regularization - prevents drift from SFT checkpoint
+  temperature: 0.7  # Slightly lower for more focused generations
 
 lora:
   rank: 4  # Smaller for RL stability
@@ -32,8 +40,4 @@ lora:
 
 tracking:
   project: "qmd-query-expansion"
-  run_name: "grpo-v4-key-term-preservation"
-
-# Note: GRPO v4 failed due to reward function not enforcing format strictly enough.
-# The model drifted to verbose explanations instead of lex:/vec:/hyde: format.
-# Recommendation: Add hard penalty when output doesn't start with valid prefix.
+  run_name: "grpo-v5-kl-regularized"
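
Note on the config diff: a sketch of how the `training` and `grpo` sections might be mapped to trainer keyword arguments. The helper `to_trainer_kwargs` is hypothetical and the way `rl.py` actually consumes the YAML is not shown in this diff. The target names are assumed to match fields of TRL's `GRPOConfig` and should be checked against the installed TRL version. The values are copied from the diff; `learning_rate` is stored as a string on purpose, since some YAML loaders read `5e-7` (no decimal point) as a string rather than a float.

```python
# Values copied from the grpo_v4.yaml diff; learning_rate kept as a string
# to mimic a loader that does not resolve "5e-7" to a float.
cfg = {
    "training": {
        "epochs": 1,
        "batch_size": 2,
        "gradient_accumulation_steps": 8,
        "learning_rate": "5e-7",
        "max_grad_norm": 0.5,
        "max_steps": 200,
    },
    "grpo": {
        "num_generations": 4,
        "max_completion_length": 200,
        "beta": 0.04,
        "temperature": 0.7,
    },
}


def to_trainer_kwargs(cfg: dict) -> dict:
    """Map config sections to keyword names (assumed GRPOConfig fields)."""
    t, g = cfg["training"], cfg["grpo"]
    return {
        "num_train_epochs": t["epochs"],
        "per_device_train_batch_size": t["batch_size"],
        "gradient_accumulation_steps": t["gradient_accumulation_steps"],
        "learning_rate": float(t["learning_rate"]),  # coerce in case it was loaded as str
        "max_grad_norm": t["max_grad_norm"],
        "max_steps": t["max_steps"],
        "num_generations": g["num_generations"],
        "max_completion_length": g["max_completion_length"],
        "beta": g["beta"],                # KL regularization strength
        "temperature": g["temperature"],
    }


kwargs = to_trainer_kwargs(cfg)
print(kwargs["beta"], kwargs["learning_rate"], kwargs["max_steps"])
```

The resulting dictionary could then be unpacked into the config object together with an output directory; that step is left out here because it depends on the TRL version in use.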

+ 144 - 6
finetune/rl.py

@@ -38,15 +38,77 @@ from trl import GRPOTrainer, GRPOConfig
 
 STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
 KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
-                      'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we'}
+                      'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
+                      'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}
+
+# Generic filler phrases that should never be in lex queries
+GENERIC_LEX_PHRASES = {
+    'find information about', 'search for', 'look up', 'get information',
+    'learn about', 'information on', 'details about', 'find out about',
+    'what is', 'how to', 'guide to', 'help with'
+}
+
+
+def extract_named_entities(query: str) -> set:
+    """Extract named entities from query using simple heuristics.
+
+    Named entities are:
+    - Capitalized words (except first word which may just be sentence start)
+    - All-caps words/acronyms (TDS, API, GPU)
+    - Technical terms with special chars (node.js, C++, .NET)
+    - Words following acronyms/proper nouns (TDS motorsports -> both words)
+    """
+    entities = set()
+    words = query.split()
+    prev_was_entity = False
+
+    for i, word in enumerate(words):
+        # Clean punctuation but keep internal special chars
+        clean = word.strip('.,!?:;()[]"\'')
+        if not clean:
+            prev_was_entity = False
+            continue
+
+        is_entity = False
+
+        # All-caps words (acronyms): TDS, API, GPU, etc.
+        if clean.isupper() and len(clean) >= 2:
+            entities.add(clean.lower())
+            is_entity = True
+
+        # Capitalized words (not first word, not common words)
+        elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
+            entities.add(clean.lower())
+            is_entity = True
+
+        # Technical terms with special chars: node.js, C++, .NET
+        elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
+            entities.add(clean.lower())
+            is_entity = True
+
+        # CamelCase: JavaScript, TypeScript, etc.
+        elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
+            entities.add(clean.lower())
+            is_entity = True
+
+        # Word following an entity is likely part of compound name (TDS motorsports)
+        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
+            entities.add(clean.lower())
+            is_entity = True
+
+        prev_was_entity = is_entity
+
+    return entities
 
 
 def get_key_terms(query: str) -> set:
+    """Get key terms (non-stopwords) from query."""
     words = set(query.lower().split())
     return words - KEY_TERM_STOPWORDS
 
 
 def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
+    """Check if lex line preserves key terms from query."""
     key_terms = get_key_terms(query)
     if not key_terms:
         return True
@@ -54,6 +116,28 @@ def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
     return bool(key_terms & lex_words)
 
 
+def lex_preserves_entities(lex_line: str, entities: set) -> bool:
+    """Check if lex line contains at least one named entity."""
+    if not entities:
+        return True  # No entities to preserve
+    lex_lower = lex_line.lower()
+    return any(entity in lex_lower for entity in entities)
+
+
+def lex_is_generic(lex_line: str) -> bool:
+    """Check if lex line is a generic filler phrase."""
+    lex_lower = lex_line.lower().strip()
+    for phrase in GENERIC_LEX_PHRASES:
+        if phrase in lex_lower or lex_lower.startswith(phrase.split()[0]):
+            # Also check if it's ONLY the generic phrase with no specifics
+            remaining = lex_lower
+            for word in phrase.split():
+                remaining = remaining.replace(word, '', 1).strip()
+            if len(remaining) < 3:  # Nothing specific left
+                return True
+    return False
+
+
 def parse_expansion(text: str) -> dict:
     lines = text.strip().split("\n")
     result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
@@ -193,9 +277,37 @@ def score_expansion(query: str, expansion: str) -> float:
         elif lex_with_terms > 0:
             quality_score += 2
 
-    total = format_score + diversity_score + hyde_score + quality_score
-    max_possible = 100 if parsed["hyde"] else 80
-    return total / max_possible
+    # NAMED ENTITY PRESERVATION (critical for quality)
+    # This score can go heavily negative to punish missing entities
+    entity_score = 0
+    entities = extract_named_entities(query)
+    if entities and parsed["lex"]:
+        # Count lex lines that preserve at least one entity
+        lex_with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
+        if lex_with_entities == len(parsed["lex"]):
+            entity_score += 15  # All lex lines have entities - great!
+        elif lex_with_entities > 0:
+            entity_score += 5   # Some have entities
+        else:
+            entity_score -= 30  # NO lex lines have entities - HEAVY penalty!
+
+        # Penalize generic filler phrases in lex (these are useless for BM25)
+        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
+        entity_score -= generic_count * 15  # -15 per generic phrase
+
+        # Bonus for entities in vec too (less critical but nice)
+        if parsed["vec"]:
+            vec_with_entities = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
+            if vec_with_entities > 0:
+                entity_score += 5
+    elif not entities:
+        # No entities in query - give base score
+        entity_score = 10
+
+    # Entity score CAN go negative to heavily penalize missing entities
+    total = format_score + diversity_score + hyde_score + quality_score + entity_score
+    max_possible = 120 if parsed["hyde"] else 100
+    return max(0.0, min(1.0, total / max_possible))  # Clamp to 0.0-1.0
 
 
 def extract_query_from_prompt(prompt: str) -> str:
@@ -231,10 +343,36 @@ def main():
         print("GRPO Training Configuration:")
         print(yaml.dump(cfg, default_flow_style=False))
         print("\nTesting reward function...")
+
+        # Test 1: Basic query
         test_good = "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."
         test_bad = "auth is important for security"
-        print(f"  Good output score: {score_expansion('auth', test_good):.2f}")
-        print(f"  Bad output score: {score_expansion('auth', test_bad):.2f}")
+        print(f"\n  Query: 'auth'")
+        print(f"    Good output score: {score_expansion('auth', test_good):.2f}")
+        print(f"    Bad output score: {score_expansion('auth', test_bad):.2f}")
+
+        # Test 2: Named entity query (the critical case!)
+        query_entity = "who is TDS motorsports"
+        entities = extract_named_entities(query_entity)
+        print(f"\n  Query: '{query_entity}'")
+        print(f"    Extracted entities: {entities}")
+
+        good_entity = "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"
+        bad_entity = "lex: find information about\nlex: company details\nvec: who is this company"
+        print(f"    Good (preserves entity): {score_expansion(query_entity, good_entity):.2f}")
+        print(f"    Bad (generic phrases): {score_expansion(query_entity, bad_entity):.2f}")
+
+        # Test 3: Technical term
+        query_tech = "how to use React hooks"
+        entities_tech = extract_named_entities(query_tech)
+        print(f"\n  Query: '{query_tech}'")
+        print(f"    Extracted entities: {entities_tech}")
+
+        good_tech = "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"
+        bad_tech = "lex: programming tutorial\nlex: how to code\nvec: learn web development"
+        print(f"    Good (preserves React): {score_expansion(query_tech, good_tech):.2f}")
+        print(f"    Bad (generic): {score_expansion(query_tech, bad_tech):.2f}")
+
         return
 
     # Login