|
@@ -32,9 +32,11 @@ import argparse
|
|
|
import os
|
|
import os
|
|
|
import subprocess
|
|
import subprocess
|
|
|
import sys
|
|
import sys
|
|
|
|
|
+import time
|
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import yaml
|
|
import yaml
|
|
|
|
|
+from transformers import TrainerCallback
|
|
|
|
|
|
|
|
|
|
|
|
|
def export_gguf(model, tokenizer, output_dir: str, model_name: str):
|
|
def export_gguf(model, tokenizer, output_dir: str, model_name: str):
|
|
@@ -156,6 +158,24 @@ def export_gguf(model, tokenizer, output_dir: str, model_name: str):
|
|
|
print(f"GGUF files saved to: {gguf_dir}")
|
|
print(f"GGUF files saved to: {gguf_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TimedSaveCallback(TrainerCallback):
    """Request a checkpoint save each time a fixed interval has elapsed.

    The callback is consulted at the end of every training step. Once at
    least ``interval_minutes`` have passed since the previous request (or
    since the callback was constructed), it sets ``control.should_save`` and
    restarts the interval.

    Elapsed time is measured with ``time.monotonic()`` rather than
    ``time.time()`` so that system clock adjustments (NTP steps, manual
    changes) can neither trigger a spurious save nor stall saves.

    Attributes:
        interval_seconds: Minimum number of seconds between save requests.
        last_save_time: ``time.monotonic()`` reading taken when the last save
            was requested (or at construction). Only differences between
            such readings are meaningful; it is not an epoch timestamp.
    """

    def __init__(self, interval_minutes: float):
        """Start the interval timer.

        Args:
            interval_minutes: Minimum time between save requests, in minutes.
                Converted with ``float()``, so numeric strings are accepted.

        Raises:
            TypeError, ValueError: If ``interval_minutes`` cannot be
                converted to ``float``.
        """
        super().__init__()
        self.interval_seconds = float(interval_minutes) * 60.0
        # The first interval is counted from construction time.
        self.last_save_time = time.monotonic()

    def on_step_end(self, args, state, control, **kwargs):
        """Flag a save on ``control`` when the interval has elapsed.

        Args:
            args: Training arguments supplied by the trainer (unused).
            state: Trainer state; only ``is_world_process_zero`` is read.
            control: Control object whose ``should_save`` flag may be set.
            **kwargs: Extra callback arguments (unused).

        Returns:
            The same ``control`` object, possibly with ``should_save`` set.
        """
        # Only the process that reports itself as world process zero may
        # request a save; a missing attribute is treated as "not rank zero".
        # NOTE(review): other ranks never set should_save here. Confirm this
        # is safe if checkpointing is a collective operation in the backend
        # used (e.g. sharded optimizers/models).
        if not getattr(state, "is_world_process_zero", False):
            return control

        now = time.monotonic()
        if now - self.last_save_time >= self.interval_seconds:
            control.should_save = True
            # Restart the interval from the moment the save was requested.
            self.last_save_time = now
        return control
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def run_eval(model_path: str) -> float | None:
|
|
def run_eval(model_path: str) -> float | None:
|
|
|
"""Run eval.py on the trained model and return average score."""
|
|
"""Run eval.py on the trained model and return average score."""
|
|
|
print("\n" + "=" * 60)
|
|
print("\n" + "=" * 60)
|
|
@@ -188,9 +208,7 @@ def run_eval(model_path: str) -> float | None:
|
|
|
def cmd_sft(args):
|
|
def cmd_sft(args):
|
|
|
"""Run supervised fine-tuning."""
|
|
"""Run supervised fine-tuning."""
|
|
|
import torch
|
|
import torch
|
|
|
- import os
|
|
|
|
|
from datasets import load_dataset
|
|
from datasets import load_dataset
|
|
|
- import torch
|
|
|
|
|
import torch.distributed as dist
|
|
import torch.distributed as dist
|
|
|
from peft import LoraConfig
|
|
from peft import LoraConfig
|
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
@@ -276,6 +294,22 @@ def cmd_sft(args):
|
|
|
"{time}", now.strftime("%H:%M")
|
|
"{time}", now.strftime("%H:%M")
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
+ save_interval_minutes = cfg["training"].get("save_interval_minutes")
|
|
|
|
|
+ save_steps = cfg["training"].get("save_steps", 200)
|
|
|
|
|
+ save_total_limit = cfg["training"].get("save_total_limit", 2)
|
|
|
|
|
+ if save_interval_minutes:
|
|
|
|
|
+ # Prefer wall-clock checkpointing (for long jobs / preemption safety)
|
|
|
|
|
+ save_steps = max(save_steps, 10_000_000)
|
|
|
|
|
+
|
|
|
|
|
+ callbacks = []
|
|
|
|
|
+ if save_interval_minutes:
|
|
|
|
|
+ try:
|
|
|
|
|
+ interval_value = float(save_interval_minutes)
|
|
|
|
|
+ except (TypeError, ValueError):
|
|
|
|
|
+ interval_value = None
|
|
|
|
|
+ if interval_value and interval_value > 0:
|
|
|
|
|
+ callbacks.append(TimedSaveCallback(interval_value))
|
|
|
|
|
+
|
|
|
config = SFTConfig(
|
|
config = SFTConfig(
|
|
|
output_dir=output_dir,
|
|
output_dir=output_dir,
|
|
|
push_to_hub=push_to_hub,
|
|
push_to_hub=push_to_hub,
|
|
@@ -288,10 +322,10 @@ def cmd_sft(args):
|
|
|
max_length=cfg["training"]["max_length"],
|
|
max_length=cfg["training"]["max_length"],
|
|
|
logging_steps=10,
|
|
logging_steps=10,
|
|
|
save_strategy="steps",
|
|
save_strategy="steps",
|
|
|
- save_steps=200,
|
|
|
|
|
- save_total_limit=2,
|
|
|
|
|
|
|
+ save_steps=save_steps,
|
|
|
|
|
+ save_total_limit=save_total_limit,
|
|
|
eval_strategy="steps",
|
|
eval_strategy="steps",
|
|
|
- eval_steps=200,
|
|
|
|
|
|
|
+ eval_steps=cfg["training"].get("eval_steps", 200),
|
|
|
warmup_ratio=cfg["training"]["warmup_ratio"],
|
|
warmup_ratio=cfg["training"]["warmup_ratio"],
|
|
|
lr_scheduler_type=cfg["training"]["lr_scheduler"],
|
|
lr_scheduler_type=cfg["training"]["lr_scheduler"],
|
|
|
ddp_find_unused_parameters=cfg["training"].get(
|
|
ddp_find_unused_parameters=cfg["training"].get(
|
|
@@ -329,6 +363,7 @@ def cmd_sft(args):
|
|
|
args=config,
|
|
args=config,
|
|
|
peft_config=peft_config,
|
|
peft_config=peft_config,
|
|
|
processing_class=tokenizer,
|
|
processing_class=tokenizer,
|
|
|
|
|
+ callbacks=callbacks,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
print("Starting SFT training...")
|
|
print("Starting SFT training...")
|
|
@@ -378,6 +413,7 @@ def cmd_sft(args):
|
|
|
def cmd_grpo(args):
|
|
def cmd_grpo(args):
|
|
|
"""Run GRPO reinforcement learning on top of merged SFT weights."""
|
|
"""Run GRPO reinforcement learning on top of merged SFT weights."""
|
|
|
import torch
|
|
import torch
|
|
|
|
|
+ import torch.distributed as dist
|
|
|
import os
|
|
import os
|
|
|
from datasets import load_dataset
|
|
from datasets import load_dataset
|
|
|
from peft import LoraConfig, PeftModel, get_peft_model
|
|
from peft import LoraConfig, PeftModel, get_peft_model
|
|
@@ -494,6 +530,7 @@ def cmd_grpo(args):
|
|
|
task_type="CAUSAL_LM",
|
|
task_type="CAUSAL_LM",
|
|
|
target_modules=cfg["lora"]["target_modules"],
|
|
target_modules=cfg["lora"]["target_modules"],
|
|
|
modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
|
|
modules_to_save=["embed_tokens", "lm_head"], # Critical for special tokens
|
|
|
|
|
+ ensure_weight_tying=True,
|
|
|
)
|
|
)
|
|
|
model = get_peft_model(model, grpo_lora_config)
|
|
model = get_peft_model(model, grpo_lora_config)
|
|
|
model.print_trainable_parameters()
|
|
model.print_trainable_parameters()
|
|
@@ -510,6 +547,24 @@ def cmd_grpo(args):
|
|
|
if isinstance(learning_rate, str):
|
|
if isinstance(learning_rate, str):
|
|
|
learning_rate = float(learning_rate)
|
|
learning_rate = float(learning_rate)
|
|
|
|
|
|
|
|
|
|
+ save_interval_minutes = cfg["training"].get("save_interval_minutes")
|
|
|
|
|
+ save_steps = cfg["training"].get("save_steps", 200)
|
|
|
|
|
+ save_total_limit = cfg["training"].get("save_total_limit", 2)
|
|
|
|
|
+ save_strategy = cfg["training"].get("save_strategy", "epoch")
|
|
|
|
|
+ if save_interval_minutes:
|
|
|
|
|
+ # Prefer wall-clock checkpointing (for long jobs / preemption safety)
|
|
|
|
|
+ save_steps = max(save_steps, 10_000_000)
|
|
|
|
|
+ save_strategy = "steps"
|
|
|
|
|
+
|
|
|
|
|
+ callbacks = []
|
|
|
|
|
+ if save_interval_minutes:
|
|
|
|
|
+ try:
|
|
|
|
|
+ interval_value = float(save_interval_minutes)
|
|
|
|
|
+ except (TypeError, ValueError):
|
|
|
|
|
+ interval_value = None
|
|
|
|
|
+ if interval_value and interval_value > 0:
|
|
|
|
|
+ callbacks.append(TimedSaveCallback(interval_value))
|
|
|
|
|
+
|
|
|
config = GRPOConfig(
|
|
config = GRPOConfig(
|
|
|
output_dir=output_dir,
|
|
output_dir=output_dir,
|
|
|
push_to_hub=push_to_hub,
|
|
push_to_hub=push_to_hub,
|
|
@@ -524,7 +579,9 @@ def cmd_grpo(args):
|
|
|
max_grad_norm=cfg["training"]["max_grad_norm"],
|
|
max_grad_norm=cfg["training"]["max_grad_norm"],
|
|
|
max_steps=cfg["training"].get("max_steps", -1),
|
|
max_steps=cfg["training"].get("max_steps", -1),
|
|
|
logging_steps=10,
|
|
logging_steps=10,
|
|
|
- save_strategy="epoch",
|
|
|
|
|
|
|
+ save_strategy=save_strategy,
|
|
|
|
|
+ save_steps=save_steps,
|
|
|
|
|
+ save_total_limit=save_total_limit,
|
|
|
bf16=True,
|
|
bf16=True,
|
|
|
skip_memory_metrics=True,
|
|
skip_memory_metrics=True,
|
|
|
report_to=report_to,
|
|
report_to=report_to,
|
|
@@ -539,11 +596,18 @@ def cmd_grpo(args):
|
|
|
args=config,
|
|
args=config,
|
|
|
train_dataset=dataset,
|
|
train_dataset=dataset,
|
|
|
reward_funcs=[QMDRewardFunction()],
|
|
reward_funcs=[QMDRewardFunction()],
|
|
|
|
|
+ callbacks=callbacks,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
print("Starting GRPO training...")
|
|
print("Starting GRPO training...")
|
|
|
trainer.train()
|
|
trainer.train()
|
|
|
|
|
|
|
|
|
|
+ is_main = os.environ.get("RANK", "0") == "0"
|
|
|
|
|
+ if dist.is_available() and dist.is_initialized():
|
|
|
|
|
+ dist.barrier()
|
|
|
|
|
+ if not is_main:
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
if push_to_hub:
|
|
if push_to_hub:
|
|
|
print("Pushing to Hub...")
|
|
print("Pushing to Hub...")
|
|
|
trainer.push_to_hub()
|
|
trainer.push_to_hub()
|