Преглед на файлове

Fix GRPO training: apply chat template to prompts

The SFT model was trained on prompts rendered with the chat template, but GRPO
was passing raw prompts. Prompts are now formatted with
tokenizer.apply_chat_template() (with add_generation_prompt=True) so the model
sees the same format it learned during SFT.

Also update extract_query_from_prompt to strip chat template artifacts: the
extracted query is now cut at the first <|im_end|> marker, if one is present.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke преди 3 месеца
родител
ревизия
9b3a209a97
променени са 2 файла, в които са добавени 32 реда и са изтрити 16 реда
  1. 16 8
      finetune/train_1.7B_grpo.py
  2. 16 8
      finetune/train_4B_grpo.py

+ 16 - 8
finetune/train_1.7B_grpo.py

@@ -290,7 +290,11 @@ def score_expansion(query: str, expansion: str) -> float:
 
 def extract_query_from_prompt(prompt: str) -> str:
     if "Expand this search query:" in prompt:
-        return prompt.split("Expand this search query:")[-1].strip()
+        query = prompt.split("Expand this search query:")[-1].strip()
+        # Remove chat template artifacts if present
+        if "<|im_end|>" in query:
+            query = query.split("<|im_end|>")[0].strip()
+        return query
     return prompt.strip()
 
 
@@ -323,23 +327,27 @@ def main():
         print("Logging in to HuggingFace Hub...")
         login(token=hf_token)
 
+    # Load tokenizer first (needed for chat template)
+    print(f"Loading tokenizer from {BASE_MODEL}...")
+    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
     # Load dataset
     print("Loading dataset...")
     dataset = load_dataset(DATASET, split="train")
 
     def extract_prompt(example):
-        return {"prompt": example["messages"][0]["content"]}
+        # Apply chat template so model sees the same format as SFT training
+        content = example["messages"][0]["content"]
+        messages = [{"role": "user", "content": content}]
+        formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        return {"prompt": formatted}
 
     dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
     dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
     print(f"Using {len(dataset)} prompts for GRPO")
 
-    # Load tokenizer and model
-    print(f"Loading tokenizer from {BASE_MODEL}...")
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
     print(f"Loading SFT model from {SFT_MODEL}...")
     base_model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL,

+ 16 - 8
finetune/train_4B_grpo.py

@@ -290,7 +290,11 @@ def score_expansion(query: str, expansion: str) -> float:
 
 def extract_query_from_prompt(prompt: str) -> str:
     if "Expand this search query:" in prompt:
-        return prompt.split("Expand this search query:")[-1].strip()
+        query = prompt.split("Expand this search query:")[-1].strip()
+        # Remove chat template artifacts if present
+        if "<|im_end|>" in query:
+            query = query.split("<|im_end|>")[0].strip()
+        return query
     return prompt.strip()
 
 
@@ -323,23 +327,27 @@ def main():
         print("Logging in to HuggingFace Hub...")
         login(token=hf_token)
 
+    # Load tokenizer first (needed for chat template)
+    print(f"Loading tokenizer from {BASE_MODEL}...")
+    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
     # Load dataset
     print("Loading dataset...")
     dataset = load_dataset(DATASET, split="train")
 
     def extract_prompt(example):
-        return {"prompt": example["messages"][0]["content"]}
+        # Apply chat template so model sees the same format as SFT training
+        content = example["messages"][0]["content"]
+        messages = [{"role": "user", "content": content}]
+        formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        return {"prompt": formatted}
 
     dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
     dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
     print(f"Using {len(dataset)} prompts for GRPO")
 
-    # Load tokenizer and model
-    print(f"Loading tokenizer from {BASE_MODEL}...")
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
     print(f"Loading SFT model from {SFT_MODEL}...")
     base_model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL,