Sfoglia il codice sorgente

Fix GRPO training script for TRL API compatibility

- Use max_completion_length instead of max_new_tokens
- Use processing_class instead of tokenizer
- Use args instead of config for GRPOTrainer
- Add __name__ attribute to reward function class
- Accept **kwargs in reward function for extra TRL args
- Add new LoRA adapter after merging SFT weights
- Drop the temperature argument from the GRPO config
- Pass reward_funcs to GRPOTrainer as a list

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 4 mesi fa
parent
commit
0353994e7d
1 file modificato con 20 aggiunte e 6 eliminazioni
  1. 20 6
      finetune/train_grpo.py

+ 20 - 6
finetune/train_grpo.py

@@ -150,6 +150,7 @@ def compute_length_reward(text: str) -> float:
 
 class QMDRewardFunction:
     """Combined reward function for QMD query expansion."""
+    __name__ = "qmd_format_diversity_reward"
 
     def __init__(self):
         # Load a small embedding model for diversity computation
@@ -157,7 +158,7 @@ class QMDRewardFunction:
         self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
         print("Embedding model loaded.")
 
-    def __call__(self, completions: list[str], prompts: list[str] = None) -> list[float]:
+    def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
         """Compute rewards for a batch of completions."""
         rewards = []
 
@@ -236,6 +237,20 @@ def main():
     model = model.merge_and_unload()  # Merge LoRA weights
     print("Model loaded and LoRA merged.")
 
+    # Add new LoRA adapter for GRPO training
+    from peft import get_peft_model
+    grpo_lora_config = LoraConfig(
+        r=8,
+        lora_alpha=16,
+        lora_dropout=0.05,
+        bias="none",
+        task_type="CAUSAL_LM",
+        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+    )
+    model = get_peft_model(model, grpo_lora_config)
+    model.print_trainable_parameters()
+    print("Added new LoRA adapter for GRPO.")
+
     # Initialize reward function
     reward_fn = QMDRewardFunction()
 
@@ -247,8 +262,7 @@ def main():
 
         # GRPO specific
         num_generations=4,  # Generate 4 completions per prompt
-        max_new_tokens=256,
-        temperature=0.8,
+        max_completion_length=256,
 
         # Training
         num_train_epochs=args.epochs,
@@ -270,10 +284,10 @@ def main():
     print("Initializing GRPO trainer...")
     trainer = GRPOTrainer(
         model=model,
-        tokenizer=tokenizer,
-        config=config,
+        processing_class=tokenizer,
+        args=config,
         train_dataset=dataset,
-        reward_funcs=reward_fn,
+        reward_funcs=[reward_fn],
     )
 
     # Train