Explorar el Código

Fix GRPO reward function to handle think blocks and end tokens

- Strip <|im_end|> token from completions (model output includes it)
- Change think_penalty to skipped_think bonus (+20 for not using think)
- Adjust max_possible to account for bonus (120/140)
- Fix typo in chat template artifact check

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke hace 3 meses
padre
commit
891f3262cf
Se han modificado 2 ficheros con 24 adiciones y 8 borrados
  1. 12 4
      finetune/train_1.7B_grpo.py
  2. 12 4
      finetune/train_4B_grpo.py

+ 12 - 4
finetune/train_1.7B_grpo.py

@@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
     """Score expansion. Returns 0.0-1.0 for RL reward."""
     text = expansion.strip()
 
+    # Strip end token if present
+    text = text.replace('<|im_end|>', '').strip()
+
+    # Check for <think>...</think> blocks - strip and mark as not skipped
+    skipped_think = 20  # Bonus for not using thinking mode
+    if '<think>' in text and '</think>' in text:
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+        skipped_think = 0  # No bonus if thinking was used
+
     # HARD FAIL: Chat template artifacts
-    if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
-                                        '\nassistant\n', '\nuser\n', '<|endoftext|>']):
+    if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
         return 0.0
 
     # HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
@@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
     elif not entities:
         entity_score = 10
 
-    total = format_score + diversity_score + hyde_score + quality_score + entity_score
-    max_possible = 120 if parsed["hyde"] else 100
+    total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
+    max_possible = 140 if parsed["hyde"] else 120  # +20 for skipped_think bonus
     return max(0.0, min(1.0, total / max_possible))
 
 

+ 12 - 4
finetune/train_4B_grpo.py

@@ -166,9 +166,17 @@ def score_expansion(query: str, expansion: str) -> float:
     """Score expansion. Returns 0.0-1.0 for RL reward."""
     text = expansion.strip()
 
+    # Strip end token if present
+    text = text.replace('<|im_end|>', '').strip()
+
+    # Check for <think>...</think> blocks - strip and mark as not skipped
+    skipped_think = 20  # Bonus for not using thinking mode
+    if '<think>' in text and '</think>' in text:
+        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
+        skipped_think = 0  # No bonus if thinking was used
+
     # HARD FAIL: Chat template artifacts
-    if any(token in text for token in ['<|im_start|>', '<|im_end|>', '<think>', '</think>',
-                                        '\nassistant\n', '\nuser\n', '<|endoftext|>']):
+    if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
         return 0.0
 
     # HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
@@ -275,8 +283,8 @@ def score_expansion(query: str, expansion: str) -> float:
     elif not entities:
         entity_score = 10
 
-    total = format_score + diversity_score + hyde_score + quality_score + entity_score
-    max_possible = 120 if parsed["hyde"] else 100
+    total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
+    max_possible = 140 if parsed["hyde"] else 120  # +20 for skipped_think bonus
     return max(0.0, min(1.0, total / max_possible))