Kaynağa Gözat

Add HuggingFace login and comprehensive scoring to GRPO v2 training

- Add explicit HF_TOKEN login before training
- Use SCORING.md criteria as RL reward function
- Conservative training: LR 1e-6, LoRA rank 4
- Reward scores: good=0.94, bad=0.38

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Tobi Lutke 4 ay önce
ebeveyn
işleme
d32e13c172
1 değiştirilmiş dosya ile 199 ekleme ve 133 silme
  1. 199 133
      finetune/train_grpo.py

+ 199 - 133
finetune/train_grpo.py

@@ -5,180 +5,228 @@
 #     "peft>=0.7.0",
 #     "transformers>=4.45.0",
 #     "accelerate>=0.24.0",
+#     "huggingface_hub>=0.20.0",
 #     "trackio",
 #     "datasets",
 #     "bitsandbytes",
-#     "sentence-transformers",
 # ]
 # ///
 """
 GRPO (Group Relative Policy Optimization) training for QMD query expansion.
 
-Reward Type 2: Format + Diversity
-- Rewards correct lex/vec/hyde format
-- Penalizes repetition between lines
-- Rewards semantic diversity of expansions
+Uses the comprehensive scoring system from SCORING.md:
+- Format (30%): Must have lex: and vec: prefixes
+- Diversity (30%): No echoing query, diverse expansions
+- Hyde (20%): Concise, no newlines, no repetition
+- Quality (20%): lex=keywords, vec=natural language
 
 Usage:
     uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
 """
 
+import os
 import re
 import torch
 import trackio
+from collections import Counter
 from datasets import load_dataset
-from peft import LoraConfig, PeftModel
+from huggingface_hub import login
+from peft import LoraConfig, PeftModel, get_peft_model
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from trl import GRPOTrainer, GRPOConfig
-from sentence_transformers import SentenceTransformer
+
+STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
 
 # ============================================================================
-# Reward Function: Format + Diversity
+# Scoring Functions (from SCORING.md)
 # ============================================================================
 
 def parse_expansion(text: str) -> dict:
-    """Parse expansion output into lex/vec/hyde components."""
-    result = {"lex": [], "vec": [], "hyde": []}
+    """Parse expansion into structured format."""
+    lines = text.strip().split("\n")
+    result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
 
-    for line in text.strip().split("\n"):
+    for line in lines:
         line = line.strip()
+        if not line:
+            continue
         if line.startswith("lex:"):
             result["lex"].append(line[4:].strip())
         elif line.startswith("vec:"):
             result["vec"].append(line[4:].strip())
         elif line.startswith("hyde:"):
             result["hyde"].append(line[5:].strip())
+        else:
+            result["invalid"].append(line)
 
     return result
 
 
-def compute_format_reward(text: str) -> float:
-    """
-    Reward for correct format:
-    - Has at least 1 lex line: +0.2
-    - Has at least 1 vec line: +0.2
-    - Has hyde line: +0.1
-    - Correct line format (type: content): +0.1 per line (max 0.3)
-    - No garbage/malformed lines: +0.2
-    """
-    reward = 0.0
-    parsed = parse_expansion(text)
+def edit_distance_simple(a: str, b: str) -> int:
+    """Simple word-level edit distance."""
+    words_a = set(a.lower().split())
+    words_b = set(b.lower().split())
+    return len(words_a ^ words_b)
 
-    # Check required components
-    if parsed["lex"]:
-        reward += 0.2
-    if parsed["vec"]:
-        reward += 0.2
-    if parsed["hyde"]:
-        reward += 0.1
 
-    # Check line format
-    lines = text.strip().split("\n")
-    valid_lines = 0
-    for line in lines:
-        if re.match(r'^(lex|vec|hyde):\s*.+', line.strip()):
-            valid_lines += 1
+def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
+    """Check if two strings are sufficiently different."""
+    a, b = a.lower().strip(), b.lower().strip()
+    if a == b:
+        return False
+    if a in b or b in a:
+        return False
+    return edit_distance_simple(a, b) >= min_distance
 
-    reward += min(0.3, valid_lines * 0.1)
 
-    # Penalize malformed lines
-    malformed = len(lines) - valid_lines
-    if malformed == 0:
-        reward += 0.2
-    else:
-        reward -= malformed * 0.1
+def echoes_query(expansion: str, query: str) -> bool:
+    """Check if expansion is just echoing the query."""
+    exp = expansion.lower().strip()
+    q = query.lower().strip()
+    if exp == q:
+        return True
+    if q in exp and len(exp) < len(q) + 10:
+        return True
+    return False
+
 
-    return max(0.0, min(1.0, reward))
+def word_repetition_penalty(text: str) -> int:
+    """Count penalty for repeated words (excluding stopwords)."""
+    words = re.findall(r'\b\w+\b', text.lower())
+    counts = Counter(words)
+    penalty = 0
+    for word, count in counts.items():
+        if count >= 3 and word not in STOPWORDS and len(word) > 2:
+            penalty += (count - 2) * 2
+    return penalty
 
 
-def compute_diversity_reward(text: str, embedder) -> float:
+def score_expansion(query: str, expansion: str) -> float:
     """
-    Reward for diverse expansions:
-    - Penalize exact duplicates
-    - Reward semantic distance between expansions
+    Score an expansion based on SCORING.md criteria.
+    Returns normalized score 0.0-1.0 for RL reward.
     """
-    parsed = parse_expansion(text)
-    all_expansions = parsed["lex"] + parsed["vec"] + parsed["hyde"]
-
-    if len(all_expansions) < 2:
-        return 0.0
-
-    # Penalize exact duplicates
-    unique = set(e.lower() for e in all_expansions)
-    duplicate_penalty = (len(all_expansions) - len(unique)) * 0.2
-
-    # Compute semantic diversity
-    if len(unique) >= 2:
-        try:
-            embeddings = embedder.encode(list(unique))
-            # Compute pairwise cosine similarities
-            from torch.nn.functional import cosine_similarity
-            emb_tensor = torch.tensor(embeddings)
-
-            similarities = []
-            for i in range(len(emb_tensor)):
-                for j in range(i + 1, len(emb_tensor)):
-                    sim = cosine_similarity(
-                        emb_tensor[i].unsqueeze(0),
-                        emb_tensor[j].unsqueeze(0)
-                    ).item()
-                    similarities.append(sim)
-
-            # Lower similarity = higher diversity = higher reward
-            avg_similarity = sum(similarities) / len(similarities) if similarities else 1.0
-            diversity_reward = 1.0 - avg_similarity  # 0 = identical, 1 = orthogonal
-        except Exception:
-            diversity_reward = 0.0
+    parsed = parse_expansion(expansion)
+
+    # === FORMAT (0-30) ===
+    format_score = 0
+    if parsed["lex"]:
+        format_score += 10
+    if parsed["vec"]:
+        format_score += 10
+    if not parsed["invalid"]:
+        format_score += 10
     else:
-        diversity_reward = 0.0
+        format_score += max(0, 10 - len(parsed["invalid"]) * 5)
+
+    # === DIVERSITY (0-30) ===
+    diversity_score = 0
+
+    # 2+ different types
+    types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
+    if types_present >= 2:
+        diversity_score += 10
+
+    # 2+ total expansions
+    total_expansions = len(parsed["lex"]) + len(parsed["vec"])
+    if total_expansions >= 2:
+        diversity_score += 5
+
+    # Lex diversity
+    lex_score = 5
+    for i, a in enumerate(parsed["lex"]):
+        for b in parsed["lex"][i+1:]:
+            if not is_diverse(a, b, 2):
+                lex_score -= 2
+    diversity_score += max(0, lex_score)
+
+    # Vec diversity
+    vec_score = 5
+    for i, a in enumerate(parsed["vec"]):
+        for b in parsed["vec"][i+1:]:
+            if not is_diverse(a, b, 3):
+                vec_score -= 2
+    diversity_score += max(0, vec_score)
+
+    # Don't echo query
+    echo_score = 5
+    for exp in parsed["lex"] + parsed["vec"]:
+        if echoes_query(exp, query):
+            echo_score -= 3  # Heavier penalty for echoing
+    diversity_score += max(0, echo_score)
+
+    # === HYDE (0-20) ===
+    hyde_score = 0
+    if parsed["hyde"]:
+        hyde_text = parsed["hyde"][0]
+        hyde_score += 5  # Present
+
+        # Length check (50-200 chars ideal)
+        hyde_len = len(hyde_text)
+        if 50 <= hyde_len <= 200:
+            hyde_score += 5
+        elif hyde_len < 50:
+            hyde_score += 2
+
+        # No newlines
+        if "\n" not in hyde_text:
+            hyde_score += 5
+
+        # No repetition
+        rep_penalty = word_repetition_penalty(hyde_text)
+        hyde_score += max(0, 5 - rep_penalty)
+
+    # === QUALITY (0-20) ===
+    quality_score = 10  # Base
+
+    # Lex should be shorter than vec
+    if parsed["lex"] and parsed["vec"]:
+        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
+        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
+        if avg_lex <= avg_vec:
+            quality_score += 5
+
+    # Vec should be natural language
+    if parsed["vec"]:
+        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
+        if natural == len(parsed["vec"]):
+            quality_score += 5
+        else:
+            quality_score += 2
 
-    return max(0.0, diversity_reward - duplicate_penalty)
+    # === TOTAL ===
+    total = format_score + diversity_score + hyde_score + quality_score
+    max_possible = 100 if parsed["hyde"] else 80
 
+    # Normalize to 0-1
+    return total / max_possible
 
-def compute_length_reward(text: str) -> float:
-    """Reward appropriate length (not too short, not too long)."""
-    lines = [l for l in text.strip().split("\n") if l.strip()]
 
-    # Ideal: 3-6 lines
-    if 3 <= len(lines) <= 6:
-        return 0.2
-    elif 2 <= len(lines) <= 7:
-        return 0.1
-    else:
-        return 0.0
+def extract_query_from_prompt(prompt: str) -> str:
+    """Extract the query from the prompt template."""
+    # Prompt format: "Expand this search query:\n\n{query}"
+    if "Expand this search query:" in prompt:
+        return prompt.split("Expand this search query:")[-1].strip()
+    return prompt.strip()
 
 
 class QMDRewardFunction:
-    """Combined reward function for QMD query expansion."""
-    __name__ = "qmd_format_diversity_reward"
-
-    def __init__(self):
-        # Load a small embedding model for diversity computation
-        print("Loading embedding model for diversity reward...")
-        self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
-        print("Embedding model loaded.")
+    """Reward function using comprehensive SCORING.md criteria."""
+    __name__ = "qmd_scoring_reward"
 
     def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
         """Compute rewards for a batch of completions."""
         rewards = []
 
-        for completion in completions:
-            # Extract just the generated part (after prompt)
-            text = completion
-
-            # Compute component rewards
-            format_r = compute_format_reward(text)
-            diversity_r = compute_diversity_reward(text, self.embedder)
-            length_r = compute_length_reward(text)
-
-            # Weighted combination
-            total = (
-                0.5 * format_r +      # Format is most important
-                0.35 * diversity_r +  # Diversity is second
-                0.15 * length_r       # Length is minor
-            )
+        for i, completion in enumerate(completions):
+            # Get the query from prompt if available
+            query = ""
+            if prompts and i < len(prompts):
+                query = extract_query_from_prompt(prompts[i])
 
-            rewards.append(total)
+            # Score using comprehensive system
+            score = score_expansion(query, completion)
+            rewards.append(score)
 
         return rewards
 
@@ -194,9 +242,11 @@ def main():
                         help="SFT model to use as starting point")
     parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
                         help="Base model (for loading tokenizer)")
-    parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo",
+    parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo-v2",
                         help="Output model name on Hub")
     parser.add_argument("--epochs", type=int, default=1)
+    parser.add_argument("--lr", type=float, default=1e-6,
+                        help="Learning rate (lower for stability)")
     parser.add_argument("--dry-run", action="store_true")
     args = parser.parse_args()
 
@@ -206,8 +256,17 @@ def main():
         print(f"  Base Model: {args.base_model}")
         print(f"  Output: {args.output}")
         print(f"  Epochs: {args.epochs}")
+        print(f"  LR: {args.lr}")
         return
 
+    # Login to HuggingFace Hub
+    hf_token = os.environ.get("HF_TOKEN")
+    if hf_token:
+        print("Logging in to HuggingFace Hub...")
+        login(token=hf_token)
+    else:
+        print("Warning: HF_TOKEN not set, will try cached login")
+
     # Load dataset (just prompts needed for GRPO)
     print("Loading dataset...")
     dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
@@ -217,7 +276,7 @@ def main():
         return {"prompt": example["messages"][0]["content"]}
 
     dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
-    dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset))))  # Use subset for GRPO
+    dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset))))
     print(f"Using {len(dataset)} prompts for GRPO")
 
     # Load tokenizer
@@ -234,18 +293,17 @@ def main():
         device_map="auto",
     )
     model = PeftModel.from_pretrained(base_model, args.sft_model)
-    model = model.merge_and_unload()  # Merge LoRA weights
+    model = model.merge_and_unload()
     print("Model loaded and LoRA merged.")
 
-    # Add new LoRA adapter for GRPO training
-    from peft import get_peft_model
+    # Add new LoRA adapter for GRPO training (smaller rank for stability)
     grpo_lora_config = LoraConfig(
-        r=8,
-        lora_alpha=16,
+        r=4,  # Smaller rank for more stable RL
+        lora_alpha=8,
         lora_dropout=0.05,
         bias="none",
         task_type="CAUSAL_LM",
-        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+        target_modules=["q_proj", "v_proj"],  # Fewer modules for stability
     )
     model = get_peft_model(model, grpo_lora_config)
     model.print_trainable_parameters()
@@ -254,21 +312,29 @@ def main():
     # Initialize reward function
     reward_fn = QMDRewardFunction()
 
-    # GRPO config
+    # Test reward function
+    print("\nTesting reward function...")
+    test_good = "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."
+    test_bad = "auth is important for security"
+    print(f"  Good output score: {score_expansion('auth', test_good):.2f}")
+    print(f"  Bad output score: {score_expansion('auth', test_bad):.2f}")
+
+    # GRPO config with conservative settings
     config = GRPOConfig(
-        output_dir="qmd-expansion-grpo",
+        output_dir="qmd-expansion-grpo-v2",
         push_to_hub=True,
         hub_model_id=args.output,
 
-        # GRPO specific
-        num_generations=4,  # Generate 4 completions per prompt
-        max_completion_length=256,
+        # GRPO specific - conservative
+        num_generations=4,
+        max_completion_length=200,  # Shorter to avoid rambling
 
-        # Training
+        # Training - very conservative
         num_train_epochs=args.epochs,
         per_device_train_batch_size=2,
-        gradient_accumulation_steps=4,
-        learning_rate=5e-6,  # Lower LR for RL
+        gradient_accumulation_steps=8,
+        learning_rate=args.lr,
+        max_grad_norm=0.5,  # Clip gradients more aggressively
 
         # Logging
         logging_steps=10,
@@ -276,8 +342,8 @@ def main():
 
         # Monitoring
         report_to="trackio",
-        project="qmd-query-expansion-grpo",
-        run_name="grpo-format-diversity",
+        project="qmd-query-expansion-grpo-v2",
+        run_name="grpo-scoring-v2",
     )
 
     # Create trainer