Explorar o código

Fix GRPO model loading to use SFT base first

The GRPO adapter was trained on merged SFT weights, so loading it
directly on the base model results in a 0% score. Added a --sft-model
parameter to evals/run.py to load and merge SFT first, then apply the GRPO adapter.

With correct loading: GRPO scores 89.7% (all 26 queries Excellent).

Updated README with correct GRPO score and loading instructions.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke hai 4 meses
pai
achega
f96766cce8
Modificáronse 2 ficheiros con 53 adicións e 32 borrados
  1. 26 5
      finetune/README.md
  2. 27 27
      finetune/evals/run.py

+ 26 - 5
finetune/README.md

@@ -29,7 +29,20 @@ hyde: To configure authentication, set the AUTH_SECRET environment variable and
 | Model | HuggingFace | Score | Status |
 |-------|-------------|-------|--------|
 | **Qwen3-0.6B v4 (SFT)** | [tobil/qmd-query-expansion-0.6B-v4](https://huggingface.co/tobil/qmd-query-expansion-0.6B-v4) | **98.8%** | Recommended |
-| Qwen3-0.6B v4 (GRPO) | [tobil/qmd-query-expansion-0.6B-v4-grpo](https://huggingface.co/tobil/qmd-query-expansion-0.6B-v4-grpo) | 0% | Failed - catastrophic drift |
+| Qwen3-0.6B v4 (GRPO) | [tobil/qmd-query-expansion-0.6B-v4-grpo](https://huggingface.co/tobil/qmd-query-expansion-0.6B-v4-grpo) | 89.7% | Requires SFT base (see note) |
+
+**Note on GRPO model**: The GRPO adapter was trained on top of the merged SFT model, so you must load the SFT adapter and merge it into the base model before applying the GRPO adapter:
+
+```python
+from peft import PeftModel
+from transformers import AutoModelForCausalLM
+
+# Load base → merge SFT → apply GRPO
+model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
+model = PeftModel.from_pretrained(model, "tobil/qmd-query-expansion-0.6B-v4")
+model = model.merge_and_unload()
+model = PeftModel.from_pretrained(model, "tobil/qmd-query-expansion-0.6B-v4-grpo")
+```
 
 ## Prompt Format
 
@@ -171,13 +184,21 @@ All 21 test queries rated "Excellent":
 | `git rebase vs merge` | 100% | Excellent |
 | `react useEffect cleanup` | 100% | Excellent |
 
-### GRPO v4 (0% - Failed)
+### GRPO v4 (89.7% - with SFT base)
 
-The GRPO training caused catastrophic drift. The model now generates verbose explanations instead of structured `lex:/vec:/hyde:` format.
+All 26 test queries rated "Excellent" when the model is loaded correctly (SFT adapter merged first, then GRPO adapter applied).
+
+| Query | Score | Rating |
+|-------|-------|--------|
+| `AWS Lambda functions` | 96% | Excellent |
+| `typescript async await` | 92% | Excellent |
+| `kubernetes vs docker swarm` | 92% | Excellent |
+| `who is TDS motorsports` | 89% | Excellent |
 
-**Root cause**: Reward function didn't enforce format strictly enough. The model learned that verbose explanations could score higher than concise structured output.
+**Important**: Loading the GRPO adapter directly on the base model results in a 0% score (which looks like catastrophic drift) because the GRPO adapter was trained on merged SFT weights, not on the base weights.
 
 ## Known Issues
 
-- **GRPO drift**: RL training causes the model to lose SFT-learned formatting. Needs stricter format enforcement in reward function.
+- **GRPO loading**: Requires the SFT adapter to be loaded and merged before the GRPO adapter is applied (see the model card note above)
 - **Key term preservation**: Some lex lines still too generic (missing query key terms)
+- **Entity scoring**: Named entity detection is heuristic-based and may miss some cases

+ 27 - 27
finetune/evals/run.py

@@ -39,31 +39,30 @@ def load_queries(path: str) -> list[str]:
     return queries
 
 
-def load_model(model_path: str, base_model: str = None):
+def load_model(model_path: str, base_model: str = None, sft_model: str = None):
     """
     Load the model (supports Hub URLs, local adapters, and merged models).
 
     Args:
         model_path: HF Hub ID (tobil/model), local adapter dir, or merged model dir
         base_model: Base model for tokenizer (auto-detected if not provided)
+        sft_model: SFT adapter to load first (for GRPO models that need SFT base)
     """
-    model_path = Path(model_path) if not model_path.startswith(("http", "hf://")) and "/" not in model_path[:20] else model_path
-    is_local = isinstance(model_path, Path) and model_path.exists()
+    model_path_str = str(model_path)
+    is_local = Path(model_path_str).exists()
 
     # Check if it's an adapter or full model
     is_adapter = False
     if is_local:
-        adapter_config = model_path / "adapter_config.json"
-        is_adapter = adapter_config.exists()
+        adapter_config_path = Path(model_path_str) / "adapter_config.json"
+        is_adapter = adapter_config_path.exists()
         if is_adapter and not base_model:
-            # Read base model from adapter config
-            import json
-            with open(adapter_config) as f:
+            with open(adapter_config_path) as f:
                 config = json.load(f)
                 base_model = config.get("base_model_name_or_path", "Qwen/Qwen3-0.6B")
     else:
-        # For Hub models, assume adapter unless --base-model not provided
-        is_adapter = base_model is not None
+        # For Hub models, assume adapter
+        is_adapter = True
 
     # Default base model
     if not base_model:
@@ -74,22 +73,22 @@ def load_model(model_path: str, base_model: str = None):
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
+    print(f"Loading base model {base_model}...", file=sys.stderr)
+    model = AutoModelForCausalLM.from_pretrained(
+        base_model,
+        torch_dtype=torch.bfloat16,
+        device_map="auto",
+    )
+
+    # If SFT model specified, load and merge it first (for GRPO models)
+    if sft_model:
+        print(f"Loading SFT adapter {sft_model} and merging...", file=sys.stderr)
+        model = PeftModel.from_pretrained(model, sft_model)
+        model = model.merge_and_unload()
+
     if is_adapter:
-        print(f"Loading base model {base_model}...", file=sys.stderr)
-        base = AutoModelForCausalLM.from_pretrained(
-            base_model,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-        )
-        print(f"Loading adapter from {model_path}...", file=sys.stderr)
-        model = PeftModel.from_pretrained(base, str(model_path))
-    else:
-        print(f"Loading model from {model_path}...", file=sys.stderr)
-        model = AutoModelForCausalLM.from_pretrained(
-            str(model_path),
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-        )
+        print(f"Loading adapter from {model_path_str}...", file=sys.stderr)
+        model = PeftModel.from_pretrained(model, model_path_str)
 
     model.eval()
     return model, tokenizer
@@ -138,7 +137,8 @@ def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200)
 def main():
     parser = argparse.ArgumentParser(description="Generate query expansions")
     parser.add_argument("--model", required=True, help="Model path (Hub or local)")
-    parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B", help="Base model for tokenizer")
+    parser.add_argument("--base-model", default=None, help="Base model for tokenizer")
+    parser.add_argument("--sft-model", default=None, help="SFT adapter to load first (for GRPO models)")
     parser.add_argument("--queries", default="evals/queries.txt", help="Queries file")
     parser.add_argument("--output", help="Output JSONL file (default: evals/results_{model_name}.jsonl)")
     parser.add_argument("--max-tokens", type=int, default=200, help="Max tokens to generate")
@@ -156,7 +156,7 @@ def main():
     print(f"Loaded {len(queries)} queries from {args.queries}", file=sys.stderr)
 
     # Load model
-    model, tokenizer = load_model(args.model, args.base_model)
+    model, tokenizer = load_model(args.model, args.base_model, args.sft_model)
 
     # Generate expansions
     print(f"Generating expansions...", file=sys.stderr)