瀏覽代碼

Add wall-clock checkpoints and full eval defaults

Tobi Lütke 3 月之前
父節點
當前提交
cbeeb1f89b

+ 4 - 0
finetune/configs/grpo.yaml

@@ -30,6 +30,10 @@ training:
   learning_rate: 0.0000005
   max_grad_norm: 0.5
   max_steps: 200
+  # Save checkpoints every 30 minutes
+  save_interval_minutes: 30
+  # Fallback time-step save cadence if needed (not used for wall-clock mode)
+  save_steps: 50
 
 grpo:
   num_generations: 4

+ 5 - 0
finetune/configs/sft.yaml

@@ -23,6 +23,11 @@ training:
   max_length: 512
   warmup_ratio: 0.03
   lr_scheduler: "cosine"
+  # Save checkpoints every 30 minutes
+  save_interval_minutes: 30
+  # Fallback time-step save cadence if needed (not used for wall-clock mode)
+  save_steps: 200
+  save_total_limit: 3
 
 lora:
   rank: 16

+ 4 - 0
finetune/configs/sft_local.yaml

@@ -21,6 +21,10 @@ training:
   warmup_ratio: 0.03
   lr_scheduler: "cosine"
   ddp_find_unused_parameters: false
+  # Save checkpoints every 30 minutes
+  save_interval_minutes: 30
+  # Fallback time-step save cadence if needed (not used for wall-clock mode)
+  save_steps: 200
 
 lora:
   rank: 16

+ 3 - 5
finetune/data/train/dataset_info.json

@@ -1,11 +1,9 @@
 {
   "dataset_name": "qmd-query-expansion",
-  "train_samples": 5440,
-  "val_samples": 605,
-  "short_query_pct": 11.1,
+  "train_samples": 2806,
+  "val_samples": 312,
+  "short_query_pct": 15.5,
   "columns": [
-    "prompt",
-    "completion",
     "text",
     "messages"
   ]

+ 9 - 1
finetune/dataset/prepare_data.py

@@ -48,10 +48,17 @@ def format_for_training(ex: TrainingExample) -> dict:
     tokenizer = get_tokenizer()
     output_text = output_items_to_text(ex.output)
 
+    user_prompt = f"/no_think Expand this search query: {ex.query}"
+    if ex.intent:
+        user_prompt = (
+            f"/no_think Expand this search query: {ex.query}\n"
+            f"Query intent: {ex.intent.strip()}"
+        )
+
     messages = [
         {
             "role": "user",
-            "content": f"/no_think Expand this search query: {ex.query}",
+            "content": user_prompt,
         },
         {"role": "assistant", "content": output_text},
     ]
@@ -165,6 +172,7 @@ def main():
         "train_samples": len(train_data),
         "val_samples": len(val_data),
         "short_query_pct": round(100 * short_final / len(all_examples), 1),
+        "columns": ["text", "messages"],
     }
     with open(output_dir / "dataset_info.json", "w") as f:
         json.dump(dataset_info, f, indent=2)

+ 18 - 20
finetune/eval.py

@@ -12,14 +12,15 @@ Minimal QMD query expansion evaluator.
 
 Usage:
     uv run eval.py ./outputs/sft
-    uv run eval.py tobil/qmd-query-expansion-1.7B --queries evals/queries.txt
+    uv run eval.py ./outputs/sft --queries evals/queries.txt
+
+By default, query file defaults to evals/queries.txt and runs all queries unless --max-queries is set.
 """
 
 import argparse
 import json
 import re
 import sys
-import os
 from pathlib import Path
 
 # Import reward scoring
@@ -27,18 +28,8 @@ sys.path.insert(0, str(Path(__file__).parent))
 from reward import score_expansion_detailed
 
 
-QUERIES = [
-    "how to configure authentication",
-    "docker compose networking",
-    "auth",
-    "who is TDS motorsports",
-    "React hooks tutorial",
-    "recent news about Shopify",
-    "how to implement caching with redis in nodejs",
-    "auth /only:lex",
-    "kubernetes pod deployment /only:vec",
-    "AWS Lambda cold start /only:hyde",
-]
+
+DEFAULT_QUERY_FILE = Path(__file__).parent / "evals" / "queries.txt"
 
 
 def load_model(model_path: str):
@@ -127,7 +118,11 @@ def generate_batch(
 def main():
     parser = argparse.ArgumentParser(description="Evaluate QMD model")
     parser.add_argument("model", help="Model path (local or HF)")
-    parser.add_argument("--queries", help="Queries file (one per line)")
+    parser.add_argument(
+        "--queries",
+        default=str(DEFAULT_QUERY_FILE),
+        help="Queries file (one per line) [default: evals/queries.txt]",
+    )
     parser.add_argument(
         "--max-new-tokens",
         type=int,
@@ -154,11 +149,14 @@ def main():
     )
     args = parser.parse_args()
 
-    # Load queries
-    queries = QUERIES
-    if args.queries:
-        with open(args.queries) as f:
-            queries = [l.strip() for l in f if l.strip() and not l.startswith("#")]
+    # Load queries (default to full evals/queries.txt)
+    query_file = Path(args.queries)
+    if not query_file.exists():
+        raise FileNotFoundError(f"Queries file not found: {query_file}")
+    with query_file.open(encoding="utf-8") as f:
+        queries = [
+            l.strip() for l in f if l.strip() and not l.strip().startswith("#")
+        ]
 
     if args.max_queries and args.max_queries > 0:
         queries = queries[: args.max_queries]

+ 70 - 6
finetune/train.py

@@ -32,9 +32,11 @@ import argparse
 import os
 import subprocess
 import sys
+import time
 from pathlib import Path
 
 import yaml
+from transformers import TrainerCallback
 
 
 def export_gguf(model, tokenizer, output_dir: str, model_name: str):
@@ -156,6 +158,24 @@ def export_gguf(model, tokenizer, output_dir: str, model_name: str):
     print(f"GGUF files saved to: {gguf_dir}")
 
 
+class TimedSaveCallback(TrainerCallback):
+    """Trigger periodic checkpoint saves based on elapsed wall-clock time."""
+
+    def __init__(self, interval_minutes: float):
+        self.interval_seconds = float(interval_minutes) * 60.0
+        self.last_save_time = time.time()
+
+    def on_step_end(self, args, state, control, **kwargs):
+        if not getattr(state, "is_world_process_zero", False):
+            return control
+
+        now = time.time()
+        if now - self.last_save_time >= self.interval_seconds:
+            control.should_save = True
+            self.last_save_time = now
+        return control
+
+
 def run_eval(model_path: str) -> float | None:
     """Run eval.py on the trained model and return average score."""
     print("\n" + "=" * 60)
@@ -188,9 +208,7 @@ def run_eval(model_path: str) -> float | None:
 def cmd_sft(args):
     """Run supervised fine-tuning."""
     import torch
-    import os
     from datasets import load_dataset
-    import torch
     import torch.distributed as dist
     from peft import LoraConfig
     from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -276,6 +294,22 @@ def cmd_sft(args):
             "{time}", now.strftime("%H:%M")
         )
 
+    save_interval_minutes = cfg["training"].get("save_interval_minutes")
+    save_steps = cfg["training"].get("save_steps", 200)
+    save_total_limit = cfg["training"].get("save_total_limit", 2)
+    if save_interval_minutes:
+        # Prefer wall-clock checkpointing (for long jobs / preemption safety)
+        save_steps = max(save_steps, 10_000_000)
+
+    callbacks = []
+    if save_interval_minutes:
+        try:
+            interval_value = float(save_interval_minutes)
+        except (TypeError, ValueError):
+            interval_value = None
+        if interval_value and interval_value > 0:
+            callbacks.append(TimedSaveCallback(interval_value))
+
     config = SFTConfig(
         output_dir=output_dir,
         push_to_hub=push_to_hub,
@@ -288,10 +322,10 @@ def cmd_sft(args):
         max_length=cfg["training"]["max_length"],
         logging_steps=10,
         save_strategy="steps",
-        save_steps=200,
-        save_total_limit=2,
+        save_steps=save_steps,
+        save_total_limit=save_total_limit,
         eval_strategy="steps",
-        eval_steps=200,
+        eval_steps=cfg["training"].get("eval_steps", 200),
         warmup_ratio=cfg["training"]["warmup_ratio"],
         lr_scheduler_type=cfg["training"]["lr_scheduler"],
         ddp_find_unused_parameters=cfg["training"].get(
@@ -329,6 +363,7 @@ def cmd_sft(args):
         args=config,
         peft_config=peft_config,
         processing_class=tokenizer,
+        callbacks=callbacks,
     )
 
     print("Starting SFT training...")
@@ -378,6 +413,7 @@ def cmd_sft(args):
 def cmd_grpo(args):
     """Run GRPO reinforcement learning on top of merged SFT weights."""
     import torch
+    import torch.distributed as dist
     import os
     from datasets import load_dataset
     from peft import LoraConfig, PeftModel, get_peft_model
@@ -494,6 +530,7 @@ def cmd_grpo(args):
         task_type="CAUSAL_LM",
         target_modules=cfg["lora"]["target_modules"],
         modules_to_save=["embed_tokens", "lm_head"],  # Critical for special tokens
+        ensure_weight_tying=True,
     )
     model = get_peft_model(model, grpo_lora_config)
     model.print_trainable_parameters()
@@ -510,6 +547,24 @@ def cmd_grpo(args):
     if isinstance(learning_rate, str):
         learning_rate = float(learning_rate)
 
+    save_interval_minutes = cfg["training"].get("save_interval_minutes")
+    save_steps = cfg["training"].get("save_steps", 200)
+    save_total_limit = cfg["training"].get("save_total_limit", 2)
+    save_strategy = cfg["training"].get("save_strategy", "epoch")
+    if save_interval_minutes:
+        # Prefer wall-clock checkpointing (for long jobs / preemption safety)
+        save_steps = max(save_steps, 10_000_000)
+        save_strategy = "steps"
+
+    callbacks = []
+    if save_interval_minutes:
+        try:
+            interval_value = float(save_interval_minutes)
+        except (TypeError, ValueError):
+            interval_value = None
+        if interval_value and interval_value > 0:
+            callbacks.append(TimedSaveCallback(interval_value))
+
     config = GRPOConfig(
         output_dir=output_dir,
         push_to_hub=push_to_hub,
@@ -524,7 +579,9 @@ def cmd_grpo(args):
         max_grad_norm=cfg["training"]["max_grad_norm"],
         max_steps=cfg["training"].get("max_steps", -1),
         logging_steps=10,
-        save_strategy="epoch",
+        save_strategy=save_strategy,
+        save_steps=save_steps,
+        save_total_limit=save_total_limit,
         bf16=True,
         skip_memory_metrics=True,
         report_to=report_to,
@@ -539,11 +596,18 @@ def cmd_grpo(args):
         args=config,
         train_dataset=dataset,
         reward_funcs=[QMDRewardFunction()],
+        callbacks=callbacks,
     )
 
     print("Starting GRPO training...")
     trainer.train()
 
+    is_main = os.environ.get("RANK", "0") == "0"
+    if dist.is_available() and dist.is_initialized():
+        dist.barrier()
+    if not is_main:
+        return
+
     if push_to_hub:
         print("Pushing to Hub...")
         trainer.push_to_hub()