Browse Source

fix(reward): tighten entity detection, add filler penalty, stricter diversity

- Compound entity chaining now stops one level deep. Previously "TDS
  motorsports team history" would inflate the expected entity set with
  "team" and "history", causing false-positive entity-preservation
  penalties during GRPO. Now only {tds, motorsports} are detected.

- Add INTERIOR_FILLER_WORDS penalty (-3/line): lex lines containing
  "overview" or "basics" absent from the original query are penalised.
  Targets template-generator noise, e.g. "ancient overview rome timeline".

- Raise is_diverse threshold 2→3: requires 3 unique words between lex
  lines before they count as diverse. Reduces reward for near-duplicate
  pairs like "auth setup" / "auth configuration".

- Broaden quoted-phrase bonus: was gated on named entities existing;
  now any multi-word query earns +3 for using quotes in lex lines.
  Better incentivises BM25-aware syntax like "memory leak" python.

Fixes scoring noise identified while working on issue #247.
rkbadhan 3 months ago
parent
commit
4511b9bd4d
1 changed files with 47 additions and 22 deletions
  1. 47 22
      finetune/reward.py

+ 47 - 22
finetune/reward.py

@@ -72,6 +72,10 @@ GENERIC_LEX_PHRASES = frozenset({
     'what is', 'how to', 'guide to', 'help with',
 })
 
+# Words commonly injected as filler/noise into lex lines by template generators
+# (e.g. "ancient overview rome timeline"). Penalized when absent from the query.
+INTERIOR_FILLER_WORDS = frozenset({'overview', 'basics'})
+
 # Chat template tokens that indicate a broken output
 CHAT_TEMPLATE_TOKENS = frozenset({
     '<|im_start|>', '<|im_end|>', '<|endoftext|>',
@@ -142,47 +146,49 @@ def extract_named_entities(query: str) -> set:
 
     Position-0 words are also detected as entities if they are capitalized and
     not common query-starting verbs (e.g. "Bob asked about deploy" -> "bob").
+
+    Compound chaining extends one level from a directly-detected entity:
+    "TDS motorsports" -> {tds, motorsports}; "TDS motorsports team" -> {tds, motorsports}.
     """
     entities = set()
     words = query.split()
-    prev_was_entity = False
+    prev_was_base_entity = False
 
     for i, word in enumerate(words):
         clean = word.strip('.,!?:;()[]"\'')
         if not clean:
-            prev_was_entity = False
+            prev_was_base_entity = False
             continue
 
-        is_entity = False
+        is_base_entity = False
 
         # ALL-CAPS acronyms: TDS, API, GPU, AWS
         if clean.isupper() and len(clean) >= 2:
             entities.add(clean.lower())
-            is_entity = True
+            is_base_entity = True
         # Capitalized proper nouns (any position, including first word)
         elif clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
             if i > 0:
                 # Non-first words: always treat as entity
                 entities.add(clean.lower())
-                is_entity = True
+                is_base_entity = True
             elif clean.lower() not in QUERY_VERB_STOPWORDS:
                 # First word: also entity if not a common query verb
                 entities.add(clean.lower())
-                is_entity = True
+                is_base_entity = True
         # Technical terms with special chars: node.js, C++, .NET
         elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
             entities.add(clean.lower())
-            is_entity = True
+            is_base_entity = True
         # CamelCase: JavaScript, TypeScript
         elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
             entities.add(clean.lower())
-            is_entity = True
-        # Compound names: word following an entity (TDS motorsports)
-        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
+            is_base_entity = True
+        # Compound names: word following a BASE entity only (one level deep).
+        elif prev_was_base_entity and clean.lower() not in KEY_TERM_STOPWORDS:
             entities.add(clean.lower())
-            is_entity = True
 
-        prev_was_entity = is_entity
+        prev_was_base_entity = is_base_entity
 
     return entities
 
@@ -208,6 +214,13 @@ def lex_preserves_entities(line: str, entities: set) -> bool:
     return any(e in lower for e in entities)
 
 
+def lex_has_filler(lex_line: str, query: str) -> bool:
+    """Does the lex line contain an INTERIOR_FILLER_WORDS word absent from the query?"""
+    query_words = set(query.lower().split())
+    return any(w in INTERIOR_FILLER_WORDS and w not in query_words
+               for w in lex_line.lower().split())
+
+
 def lex_is_generic(lex_line: str) -> bool:
     """Is this lex line a useless generic filler phrase?"""
     lower = lex_line.lower().strip()
@@ -280,13 +293,14 @@ def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool
     
     # --- Diversity (0-30) ---
     diversity_score = 0
+    div_threshold = 3 if len(base_query.split()) >= 5 else 2
     if len(expected_items) >= 2:
         diversity_score += 15
         # Check for diversity among items
         div_score = 15
         for i, a in enumerate(expected_items):
             for b in expected_items[i+1:]:
-                if not is_diverse(a, b, 2):
+                if not is_diverse(a, b, div_threshold):
                     div_score -= 5
                     deductions.append(f"{only_type} duplicate: {a[:20]}...")
         diversity_score += max(0, div_score)
@@ -315,6 +329,11 @@ def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool
             quality_score += 5
         else:
             deductions.append(f"{generic} generic lex phrases")
+        # Penalty: lex lines containing filler words absent from the query
+        filler_count = sum(1 for l in expected_items if lex_has_filler(l, base_query))
+        if filler_count > 0:
+            quality_score -= filler_count * 3
+            deductions.append(f"{filler_count} lex line(s) with filler words")
     
     elif only_type == "vec":
         # Vec should be natural language sentences
@@ -444,10 +463,11 @@ def score_expansion_detailed(query: str, expansion: str) -> dict:
     if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
         diversity_score += 5
 
+    div_threshold = 3 if len(query.split()) >= 5 else 2
     lex_div = 5
     for i, a in enumerate(parsed["lex"]):
         for b in parsed["lex"][i+1:]:
-            if not is_diverse(a, b, 2):
+            if not is_diverse(a, b, div_threshold):
                 lex_div -= 2
                 deductions.append(f"lex duplicate: {a[:20]}...")
     diversity_score += max(0, lex_div)
@@ -455,7 +475,7 @@ def score_expansion_detailed(query: str, expansion: str) -> dict:
     vec_div = 5
     for i, a in enumerate(parsed["vec"]):
         for b in parsed["vec"][i+1:]:
-            if not is_diverse(a, b, 3):
+            if not is_diverse(a, b, div_threshold):
                 vec_div -= 2
                 deductions.append(f"vec duplicate: {a[:20]}...")
     diversity_score += max(0, vec_div)
@@ -517,13 +537,18 @@ def score_expansion_detailed(query: str, expansion: str) -> dict:
         else:
             deductions.append("lex missing key terms")
 
-    # Bonus: lex uses quoted phrases for multi-word entities (+3)
-    if entities and parsed["lex"]:
-        multi_word_entities = [e for e in entities if " " in e or len(e) > 6]
-        if multi_word_entities:
-            lex_joined = " ".join(parsed["lex"])
-            if '"' in lex_joined:
-                quality_score += 3
+    # Penalty: lex lines containing filler words absent from the query
+    if parsed["lex"]:
+        filler_count = sum(1 for l in parsed["lex"] if lex_has_filler(l, query))
+        if filler_count > 0:
+            quality_score -= filler_count * 3
+            deductions.append(f"{filler_count} lex line(s) with filler words")
+
+    # Bonus: lex uses quoted phrases for multi-word queries (+3)
+    if parsed["lex"] and len(query.split()) >= 2:
+        lex_joined = " ".join(parsed["lex"])
+        if '"' in lex_joined:
+            quality_score += 3
 
     # --- Entity Preservation (-45 to +20) ---
     entity_score = 0