Browse Source

Strict format validation: every line must be lex:/vec:/hyde:

Any line that doesn't start with a valid prefix now returns 0.0
instead of just counting as a penalty. This prevents any prose,
explanations, bullet points, or other invalid content.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 4 tháng trước cách đây
mục cha
commit
dc8f5a2335
1 tập tin đã thay đổi với 22 bổ sung8 xóa
  1. 22 8
      finetune/rl.py

+ 22 - 8
finetune/rl.py

@@ -200,23 +200,24 @@ def score_expansion(query: str, expansion: str) -> float:
                                         '\nassistant\n', '\nuser\n', '<|endoftext|>']):
         return 0.0  # Zero reward for chat template leakage
 
-    # HARD FAIL: Must start with valid prefix (prevents verbose explanations)
-    first_line = text.split("\n")[0].strip() if text else ""
-    if not first_line.startswith(("lex:", "vec:", "hyde:")):
-        return 0.0  # Zero reward for wrong format
+    # HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
+    for line in text.split("\n"):
+        line = line.strip()
+        if not line:
+            continue  # Skip empty lines
+        if not line.startswith(("lex:", "vec:", "hyde:")):
+            return 0.0  # Zero reward for any invalid line
 
     parsed = parse_expansion(expansion)
 
     # FORMAT (0-30)
+    # Note: invalid lines already cause hard fail above, so parsed["invalid"] is always empty here
     format_score = 0
     if parsed["lex"]:
         format_score += 10
     if parsed["vec"]:
         format_score += 10
-    if not parsed["invalid"]:
-        format_score += 10
-    else:
-        format_score += max(0, 10 - len(parsed["invalid"]) * 5)
+    format_score += 10  # No invalid lines (guaranteed by hard fail above)
 
     # DIVERSITY (0-30)
     diversity_score = 0
@@ -391,6 +392,19 @@ def main():
             status = "✓" if score == 0.0 else "✗ FAIL"
             print(f"    {status} '{test[:40]}...' -> {score:.2f}")
 
+        # Test 5: Invalid line format (MUST be 0.0)
+        print(f"\n  Invalid line format tests (all should be 0.00):")
+        invalid_tests = [
+            "lex: auth\nThis is some explanation\nvec: more",
+            "lex: auth\nvec: search\nHere's why I chose these",
+            "Authentication is important\nlex: auth",
+            "lex: auth\n- bullet point",
+        ]
+        for test in invalid_tests:
+            score = score_expansion("auth", test)
+            status = "✓" if score == 0.0 else "✗ FAIL"
+            print(f"    {status} '{test[:40]}...' -> {score:.2f}")
+
         return
 
     # Login