Parcourir la source

Fix GRPO model loading to use SFT base first

The GRPO adapter was trained on merged SFT weights, so loading it
directly on the base model results in 0% score. Added --sft-model
parameter to evals/run.py to load SFT first, then apply GRPO adapter.

With correct loading: GRPO scores 89.7% (all 26 queries Excellent).

Updated README with correct GRPO score and loading instructions.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke il y a 4 mois
Parent
commit
f96766cce8
2 fichiers modifiés avec 53 ajouts et 32 suppressions
  1. 26 5
      finetune/README.md
  2. 27 27
      finetune/evals/run.py

+ 26 - 5
finetune/README.md

@@ -29,7 +29,20 @@ hyde: To configure authentication, set the AUTH_SECRET environment variable and
 | Model | HuggingFace | Score | Status |
 |-------|-------------|-------|--------|
 | **Qwen3-0.6B v4 (SFT)** | [tobil/qmd-query-expansion-0.6B-v4](https://huggingface.co/tobil/qmd-query-expansion-0.6B-v4) | **98.8%** | Recommended |
-| Qwen3-0.6B v4 (GRPO) | [tobil/qmd-query-expansion-0.6B-v4-grpo](https://huggingface.co/tobil/qmd-query-expansion-0.6B-v4-grpo) | 0% | Failed - catastrophic drift |
+| Qwen3-0.6B v4 (GRPO) | [tobil/qmd-query-expansion-0.6B-v4-grpo](https://huggingface.co/tobil/qmd-query-expansion-0.6B-v4-grpo) | 89.7% | Requires SFT base (see note) |
+
+**Note on GRPO model**: The GRPO adapter was trained on top of the merged SFT model, so you must load the SFT adapter onto the base model and merge it before applying the GRPO adapter:
+
+```python
+from peft import PeftModel
+from transformers import AutoModelForCausalLM
+
+# Load base → merge SFT → apply GRPO
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
+model = PeftModel.from_pretrained(model, "tobil/qmd-query-expansion-0.6B-v4")
+model = model.merge_and_unload()
+model = PeftModel.from_pretrained(model, "tobil/qmd-query-expansion-0.6B-v4-grpo")
+```
 
 ## Prompt Format
 
@@ -171,13 +184,21 @@ All 21 test queries rated "Excellent":
 | `git rebase vs merge` | 100% | Excellent |
 | `react useEffect cleanup` | 100% | Excellent |
 
-### GRPO v4 (0% - Failed)
+### GRPO v4 (89.7% - with SFT base)
 
-The GRPO training caused catastrophic drift. The model now generates verbose explanations instead of structured `lex:/vec:/hyde:` format.
+All 26 test queries rated "Excellent" when loaded correctly (SFT first, then GRPO adapter).
+
+| Query | Score | Rating |
+|-------|-------|--------|
+| `AWS Lambda functions` | 96% | Excellent |
+| `typescript async await` | 92% | Excellent |
+| `kubernetes vs docker swarm` | 92% | Excellent |
+| `who is TDS motorsports` | 89% | Excellent |
 
-**Root cause**: Reward function didn't enforce format strictly enough. The model learned that verbose explanations could score higher than concise structured output.
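+**Important**: Loading the GRPO adapter directly onto the base model results in a 0% score. This was previously misdiagnosed as catastrophic drift; the actual cause is that the GRPO adapter was trained on merged SFT weights, so it only works when applied on top of them.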
 
 ## Known Issues
 
-- **GRPO drift**: RL training causes the model to lose SFT-learned formatting. Needs stricter format enforcement in reward function.
+- **GRPO loading**: The SFT adapter must be loaded and merged before the GRPO adapter is applied (see the note on the GRPO model above)
 - **Key term preservation**: Some lex lines still too generic (missing query key terms)
+- **Entity scoring**: Named entity detection is heuristic-based and may miss some cases

+ 27 - 27
finetune/evals/run.py

@@ -39,31 +39,30 @@ def load_queries(path: str) -> list[str]:
     return queries
 
 
-def load_model(model_path: str, base_model: str = None):
+def load_model(model_path: str, base_model: str = None, sft_model: str = None):
     """
     Load the model (supports Hub URLs, local adapters, and merged models).
 
     Args:
         model_path: HF Hub ID (tobil/model), local adapter dir, or merged model dir
         base_model: Base model for tokenizer (auto-detected if not provided)
+        sft_model: SFT adapter to load first (for GRPO models that need SFT base)
     """
-    model_path = Path(model_path) if not model_path.startswith(("http", "hf://")) and "/" not in model_path[:20] else model_path
-    is_local = isinstance(model_path, Path) and model_path.exists()
+    model_path_str = str(model_path)
+    is_local = Path(model_path_str).exists()
 
     # Check if it's an adapter or full model
     is_adapter = False
     if is_local:
-        adapter_config = model_path / "adapter_config.json"
-        is_adapter = adapter_config.exists()
+        adapter_config_path = Path(model_path_str) / "adapter_config.json"
+        is_adapter = adapter_config_path.exists()
         if is_adapter and not base_model:
-            # Read base model from adapter config
-            import json
-            with open(adapter_config) as f:
+            with open(adapter_config_path) as f:
                 config = json.load(f)
                 base_model = config.get("base_model_name_or_path", "Qwen/Qwen3-0.6B")
     else:
-        # For Hub models, assume adapter unless --base-model not provided
-        is_adapter = base_model is not None
+        # For Hub models, assume adapter
+        is_adapter = True
 
     # Default base model
     if not base_model:
@@ -74,22 +73,22 @@ def load_model(model_path: str, base_model: str = None):
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    print(f"Loading base model {base_model}...", file=sys.stderr)
+    model = AutoModelForCausalLM.from_pretrained(
+        base_model,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+
+    # If SFT model specified, load and merge it first (for GRPO models)
+    if sft_model:
+        print(f"Loading SFT adapter {sft_model} and merging...", file=sys.stderr)
+        model = PeftModel.from_pretrained(model, sft_model)
+        model = model.merge_and_unload()
+
     if is_adapter:
-        print(f"Loading base model {base_model}...", file=sys.stderr)
-        base = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-        )
-        print(f"Loading adapter from {model_path}...", file=sys.stderr)
-        model = PeftModel.from_pretrained(base, str(model_path))
-    else:
-        print(f"Loading model from {model_path}...", file=sys.stderr)
-        model = AutoModelForCausalLM.from_pretrained(
-            str(model_path),
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-        )
+        print(f"Loading adapter from {model_path_str}...", file=sys.stderr)
+        model = PeftModel.from_pretrained(model, model_path_str)
 
     model.eval()
     return model, tokenizer
@@ -138,7 +137,8 @@ def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200)
 def main():
     parser = argparse.ArgumentParser(description="Generate query expansions")
     parser.add_argument("--model", required=True, help="Model path (Hub or local)")
-    parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B", help="Base model for tokenizer")
+    parser.add_argument("--base-model", default=None, help="Base model for tokenizer")
+    parser.add_argument("--sft-model", default=None, help="SFT adapter to load first (for GRPO models)")
     parser.add_argument("--queries", default="evals/queries.txt", help="Queries file")
     parser.add_argument("--output", help="Output JSONL file (default: evals/results_{model_name}.jsonl)")
     parser.add_argument("--max-tokens", type=int, default=200, help="Max tokens to generate")
@@ -156,7 +156,7 @@ def main():
     print(f"Loaded {len(queries)} queries from {args.queries}", file=sys.stderr)
 
     # Load model
-    model, tokenizer = load_model(args.model, args.base_model)
+    model, tokenizer = load_model(args.model, args.base_model, args.sft_model)
 
     # Generate expansions
     print(f"Generating expansions...", file=sys.stderr)