|
|
@@ -150,6 +150,7 @@ def compute_length_reward(text: str) -> float:
|
|
|
|
|
|
class QMDRewardFunction:
|
|
|
"""Combined reward function for QMD query expansion."""
|
|
|
+ __name__ = "qmd_format_diversity_reward"
|
|
|
|
|
|
def __init__(self):
|
|
|
# Load a small embedding model for diversity computation
|
|
|
@@ -157,7 +158,7 @@ class QMDRewardFunction:
|
|
|
self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
|
|
|
print("Embedding model loaded.")
|
|
|
|
|
|
- def __call__(self, completions: list[str], prompts: list[str] = None) -> list[float]:
|
|
|
+ def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
|
|
|
"""Compute rewards for a batch of completions."""
|
|
|
rewards = []
|
|
|
|
|
|
@@ -236,6 +237,20 @@ def main():
|
|
|
model = model.merge_and_unload() # Merge LoRA weights
|
|
|
print("Model loaded and LoRA merged.")
|
|
|
|
|
|
+ # Add new LoRA adapter for GRPO training
|
|
|
+ from peft import get_peft_model
|
|
|
+ grpo_lora_config = LoraConfig(
|
|
|
+ r=8,
|
|
|
+ lora_alpha=16,
|
|
|
+ lora_dropout=0.05,
|
|
|
+ bias="none",
|
|
|
+ task_type="CAUSAL_LM",
|
|
|
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
|
|
+ )
|
|
|
+ model = get_peft_model(model, grpo_lora_config)
|
|
|
+ model.print_trainable_parameters()
|
|
|
+ print("Added new LoRA adapter for GRPO.")
|
|
|
+
|
|
|
# Initialize reward function
|
|
|
reward_fn = QMDRewardFunction()
|
|
|
|
|
|
@@ -247,8 +262,7 @@ def main():
|
|
|
|
|
|
# GRPO specific
|
|
|
num_generations=4, # Generate 4 completions per prompt
|
|
|
- max_new_tokens=256,
|
|
|
- temperature=0.8,
|
|
|
+ max_completion_length=256,
|
|
|
|
|
|
# Training
|
|
|
num_train_epochs=args.epochs,
|
|
|
@@ -270,10 +284,10 @@ def main():
|
|
|
print("Initializing GRPO trainer...")
|
|
|
trainer = GRPOTrainer(
|
|
|
model=model,
|
|
|
- tokenizer=tokenizer,
|
|
|
- config=config,
|
|
|
+ processing_class=tokenizer,
|
|
|
+ args=config,
|
|
|
train_dataset=dataset,
|
|
|
- reward_funcs=reward_fn,
|
|
|
+ reward_funcs=[reward_fn],
|
|
|
)
|
|
|
|
|
|
# Train
|