|
|
@@ -5,180 +5,228 @@
|
|
|
# "peft>=0.7.0",
|
|
|
# "transformers>=4.45.0",
|
|
|
# "accelerate>=0.24.0",
|
|
|
+# "huggingface_hub>=0.20.0",
|
|
|
# "trackio",
|
|
|
# "datasets",
|
|
|
# "bitsandbytes",
|
|
|
-# "sentence-transformers",
|
|
|
# ]
|
|
|
# ///
|
|
|
"""
|
|
|
GRPO (Group Relative Policy Optimization) training for QMD query expansion.
|
|
|
|
|
|
-Reward Type 2: Format + Diversity
|
|
|
-- Rewards correct lex/vec/hyde format
|
|
|
-- Penalizes repetition between lines
|
|
|
-- Rewards semantic diversity of expansions
|
|
|
+Uses the comprehensive scoring system from SCORING.md:
|
|
|
+- Format (30%): Must have lex: and vec: prefixes
|
|
|
+- Diversity (30%): No echoing query, diverse expansions
|
|
|
+- Hyde (20%): Concise, no newlines, no repetition
|
|
|
+- Quality (20%): lex=keywords, vec=natural language
|
|
|
|
|
|
Usage:
|
|
|
uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
|
|
|
"""
|
|
|
|
|
|
+import os
|
|
|
import re
|
|
|
import torch
|
|
|
import trackio
|
|
|
+from collections import Counter
|
|
|
from datasets import load_dataset
|
|
|
-from peft import LoraConfig, PeftModel
|
|
|
+from huggingface_hub import login
|
|
|
+from peft import LoraConfig, PeftModel, get_peft_model
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
from trl import GRPOTrainer, GRPOConfig
|
|
|
-from sentence_transformers import SentenceTransformer
|
|
|
+
|
|
|
+STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
|
|
|
|
|
|
# ============================================================================
|
|
|
-# Reward Function: Format + Diversity
|
|
|
+# Scoring Functions (from SCORING.md)
|
|
|
# ============================================================================
|
|
|
|
|
|
def parse_expansion(text: str) -> dict:
    """Parse a model completion into its lex/vec/hyde expansion lines.

    Each non-blank line is classified by its prefix (``lex:``, ``vec:`` or
    ``hyde:``); the prefix is removed and the remaining payload is stripped.

    Args:
        text: Raw completion text, one expansion per line.

    Returns:
        A dict with the keys ``"lex"``, ``"vec"``, ``"hyde"`` and
        ``"invalid"``. The first three hold the stripped payloads in input
        order; ``"invalid"`` holds every other non-blank line (stripped),
        including lines that carry a known prefix but no payload.
    """
    result = {"lex": [], "vec": [], "hyde": [], "invalid": []}

    for raw_line in text.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue

        # Splitting on the first ":" is equivalent to the startswith() checks
        # for the three known prefixes (matching stays case-sensitive).
        kind, sep, payload = line.partition(":")
        payload = payload.strip()

        # A bare prefix such as "lex:" carries no expansion. Counting it as
        # valid would let an empty line collect format credit, so it is
        # reported as invalid instead.
        if sep and kind in ("lex", "vec", "hyde") and payload:
            result[kind].append(payload)
        else:
            result["invalid"].append(line)

    return result
|
|
|
|
|
|
|
|
|
-def compute_format_reward(text: str) -> float:
|
|
|
- """
|
|
|
- Reward for correct format:
|
|
|
- - Has at least 1 lex line: +0.2
|
|
|
- - Has at least 1 vec line: +0.2
|
|
|
- - Has hyde line: +0.1
|
|
|
- - Correct line format (type: content): +0.1 per line (max 0.3)
|
|
|
- - No garbage/malformed lines: +0.2
|
|
|
- """
|
|
|
- reward = 0.0
|
|
|
- parsed = parse_expansion(text)
|
|
|
def edit_distance_simple(a: str, b: str) -> int:
    """Return a cheap word-level distance between two strings.

    Both strings are lowercased and split on whitespace into sets of words;
    the result is the number of words that occur in exactly one of the two
    sets. Word order and repeated words are ignored, so this is a set
    difference count rather than a true edit distance.
    """
    tokens_a = frozenset(a.lower().split())
    tokens_b = frozenset(b.lower().split())
    return len(tokens_a.symmetric_difference(tokens_b))
|
|
|
|
|
|
- # Check required components
|
|
|
- if parsed["lex"]:
|
|
|
- reward += 0.2
|
|
|
- if parsed["vec"]:
|
|
|
- reward += 0.2
|
|
|
- if parsed["hyde"]:
|
|
|
- reward += 0.1
|
|
|
|
|
|
- # Check line format
|
|
|
- lines = text.strip().split("\n")
|
|
|
- valid_lines = 0
|
|
|
- for line in lines:
|
|
|
- if re.match(r'^(lex|vec|hyde):\s*.+', line.strip()):
|
|
|
- valid_lines += 1
|
|
|
def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
    """Tell whether two expansions differ enough to count as distinct.

    The comparison is case-insensitive and ignores surrounding whitespace.
    Strings that are equal, or where one contains the other, are never
    diverse. Otherwise they are diverse when their word-level distance is at
    least ``min_distance``.
    """
    left = a.lower().strip()
    right = b.lower().strip()

    # Containment in either direction also covers the equal-strings case.
    if left in right or right in left:
        return False

    return edit_distance_simple(left, right) >= min_distance
|
|
|
|
|
|
- reward += min(0.3, valid_lines * 0.1)
|
|
|
|
|
|
- # Penalize malformed lines
|
|
|
- malformed = len(lines) - valid_lines
|
|
|
- if malformed == 0:
|
|
|
- reward += 0.2
|
|
|
- else:
|
|
|
- reward -= malformed * 0.1
|
|
|
def echoes_query(expansion: str, query: str) -> bool:
    """Check whether an expansion merely repeats the original query.

    The comparison is case-insensitive and ignores surrounding whitespace.

    Args:
        expansion: A single lex/vec expansion payload.
        query: The user query the expansion was generated for.

    Returns:
        True when the expansion equals the query, or contains it while adding
        fewer than 10 extra characters. Always False for an empty query.
    """
    exp = expansion.lower().strip()
    q = query.lower().strip()

    # An empty string is a substring of everything, so without this guard
    # every expansion shorter than 10 characters would be flagged as an echo
    # whenever the query is unknown or blank.
    if not q:
        return False

    if exp == q:
        return True

    # The query plus only a few extra characters still counts as an echo.
    return q in exp and len(exp) < len(q) + 10
|
|
|
+
|
|
|
|
|
|
- return max(0.0, min(1.0, reward))
|
|
|
def word_repetition_penalty(text: str) -> int:
    """Compute a penalty for words that are repeated too often.

    The text is lowercased and tokenised into word characters. Every word
    that occurs three or more times, is longer than two characters and is
    not in ``STOPWORDS`` adds ``(occurrences - 2) * 2`` to the penalty.
    Returns 0 when nothing qualifies.
    """
    frequencies = Counter(re.findall(r'\b\w+\b', text.lower()))
    return sum(
        (count - 2) * 2
        for word, count in frequencies.items()
        if count >= 3 and word not in STOPWORDS and len(word) > 2
    )
|
|
|
|
|
|
|
|
|
-def compute_diversity_reward(text: str, embedder) -> float:
|
|
|
def score_expansion(query: str, expansion: str) -> float:
    """
    Score an expansion based on SCORING.md criteria.

    The score is the sum of four components, normalised to 0.0-1.0:

    - Format (0-30): lex lines present, vec lines present, few invalid lines.
    - Diversity (0-30): both types present, 2+ expansions, expansions of the
      same type differ from each other, expansions do not echo the query.
    - Hyde (0-20): only scored when a hyde line is present (first one used).
    - Quality (0-20): lex shorter than vec on average, vec reads like
      natural language.

    Args:
        query: The original search query. May be empty when unknown, in
            which case the echo check is skipped.
        expansion: The raw completion text to score.

    Returns:
        A reward in [0.0, 1.0]. The denominator is 100 when a hyde line is
        present and 80 otherwise. A completion without any lex or vec line
        scores 0.0.
    """
    parsed = parse_expansion(expansion)
    lex, vec, hyde = parsed["lex"], parsed["vec"], parsed["hyde"]

    # Without a single lex/vec line there is nothing to reward. Scoring such
    # output anyway hands out credit for vacuous properties (no invalid
    # lines, no duplicate pairs, no echoes, base quality), which previously
    # gave an empty completion a reward of about 0.44.
    if not lex and not vec:
        return 0.0

    # === FORMAT (0-30) ===
    format_score = 0
    if lex:
        format_score += 10
    if vec:
        format_score += 10
    # Full 10 points with no invalid lines, minus 5 per invalid line.
    format_score += max(0, 10 - len(parsed["invalid"]) * 5)

    # === DIVERSITY (0-30) ===
    diversity_score = 0

    # Both expansion types present
    if lex and vec:
        diversity_score += 10

    # 2+ total expansions
    if len(lex) + len(vec) >= 2:
        diversity_score += 5

    # Lex diversity: only credited when lex lines exist, minus 2 for every
    # pair that is too similar.
    if lex:
        lex_score = 5
        for i, a in enumerate(lex):
            for b in lex[i + 1:]:
                if not is_diverse(a, b, 2):
                    lex_score -= 2
        diversity_score += max(0, lex_score)

    # Vec diversity: same idea with a stricter distance threshold.
    if vec:
        vec_score = 5
        for i, a in enumerate(vec):
            for b in vec[i + 1:]:
                if not is_diverse(a, b, 3):
                    vec_score -= 2
        diversity_score += max(0, vec_score)

    # Don't echo query. With an empty query there is nothing to echo, so
    # the check is skipped rather than matching on the empty string.
    echo_score = 5
    if query.strip():
        for exp in lex + vec:
            if echoes_query(exp, query):
                echo_score -= 3  # Heavier penalty for echoing
    diversity_score += max(0, echo_score)

    # === HYDE (0-20) ===
    hyde_score = 0
    if hyde:
        hyde_text = hyde[0]
        hyde_score += 5  # Present

        # Length check (50-200 chars ideal); over 200 earns nothing
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            hyde_score += 5
        elif hyde_len < 50:
            hyde_score += 2

        # No newlines.
        # NOTE(review): parse_expansion splits on "\n", so a single parsed
        # hyde entry can never contain one and this credit is always given.
        # Kept so the 20-point hyde budget stays unchanged.
        if "\n" not in hyde_text:
            hyde_score += 5

        # No repetition
        rep_penalty = word_repetition_penalty(hyde_text)
        hyde_score += max(0, 5 - rep_penalty)

    # === QUALITY (0-20) ===
    quality_score = 10  # Base

    # Lex should be shorter than vec
    if lex and vec:
        avg_lex = sum(len(item) for item in lex) / len(lex)
        avg_vec = sum(len(item) for item in vec) / len(vec)
        if avg_lex <= avg_vec:
            quality_score += 5

    # Vec should be natural language: multi-word and longer than 15 chars
    if vec:
        natural = sum(1 for v in vec if " " in v and len(v) > 15)
        if natural == len(vec):
            quality_score += 5
        else:
            quality_score += 2

    # === TOTAL ===
    total = format_score + diversity_score + hyde_score + quality_score
    # Hyde is optional: its 20 points only count when a hyde line exists.
    max_possible = 100 if hyde else 80

    # Normalize to 0-1
    return total / max_possible
|
|
|
|
|
|
-def compute_length_reward(text: str) -> float:
|
|
|
- """Reward appropriate length (not too short, not too long)."""
|
|
|
- lines = [l for l in text.strip().split("\n") if l.strip()]
|
|
|
|
|
|
- # Ideal: 3-6 lines
|
|
|
- if 3 <= len(lines) <= 6:
|
|
|
- return 0.2
|
|
|
- elif 2 <= len(lines) <= 7:
|
|
|
- return 0.1
|
|
|
- else:
|
|
|
- return 0.0
|
|
|
def extract_query_from_prompt(prompt: str) -> str:
    """Pull the search query out of a prompt built from the template.

    The template is ``"Expand this search query:\\n\\n{query}"``. When the
    marker text is present, everything after its last occurrence is returned;
    otherwise the whole prompt is returned. The result is stripped of
    surrounding whitespace in both cases.
    """
    marker = "Expand this search query:"
    _, found, tail = prompt.rpartition(marker)
    query = tail if found else prompt
    return query.strip()
|
|
|
|
|
|
|
|
|
class QMDRewardFunction:
    """Callable reward that scores completions with the SCORING.md criteria."""

    # Name reported by instances of this callable.
    # NOTE(review): presumably read by the trainer to label the reward in
    # its logs; confirm against the TRL version in use.
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
        """Score each completion and return the rewards in the same order.

        The query for a completion is extracted from the prompt at the same
        index. When ``prompts`` is missing, empty, or shorter than
        ``completions``, an empty query is used for the uncovered entries.
        Extra keyword arguments are accepted and ignored.
        """
        scores = []

        for idx, text in enumerate(completions):
            if prompts and idx < len(prompts):
                query = extract_query_from_prompt(prompts[idx])
            else:
                query = ""

            scores.append(score_expansion(query, text))

        return scores
|
|
|
|
|
|
@@ -194,9 +242,11 @@ def main():
|
|
|
help="SFT model to use as starting point")
|
|
|
parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
|
|
|
help="Base model (for loading tokenizer)")
|
|
|
- parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo",
|
|
|
+ parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo-v2",
|
|
|
help="Output model name on Hub")
|
|
|
parser.add_argument("--epochs", type=int, default=1)
|
|
|
+ parser.add_argument("--lr", type=float, default=1e-6,
|
|
|
+ help="Learning rate (lower for stability)")
|
|
|
parser.add_argument("--dry-run", action="store_true")
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
@@ -206,8 +256,17 @@ def main():
|
|
|
print(f" Base Model: {args.base_model}")
|
|
|
print(f" Output: {args.output}")
|
|
|
print(f" Epochs: {args.epochs}")
|
|
|
+ print(f" LR: {args.lr}")
|
|
|
return
|
|
|
|
|
|
+ # Login to HuggingFace Hub
|
|
|
+ hf_token = os.environ.get("HF_TOKEN")
|
|
|
+ if hf_token:
|
|
|
+ print("Logging in to HuggingFace Hub...")
|
|
|
+ login(token=hf_token)
|
|
|
+ else:
|
|
|
+ print("Warning: HF_TOKEN not set, will try cached login")
|
|
|
+
|
|
|
# Load dataset (just prompts needed for GRPO)
|
|
|
print("Loading dataset...")
|
|
|
dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
|
|
|
@@ -217,7 +276,7 @@ def main():
|
|
|
return {"prompt": example["messages"][0]["content"]}
|
|
|
|
|
|
dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
|
|
|
- dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset)))) # Use subset for GRPO
|
|
|
+ dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset))))
|
|
|
print(f"Using {len(dataset)} prompts for GRPO")
|
|
|
|
|
|
# Load tokenizer
|
|
|
@@ -234,18 +293,17 @@ def main():
|
|
|
device_map="auto",
|
|
|
)
|
|
|
model = PeftModel.from_pretrained(base_model, args.sft_model)
|
|
|
- model = model.merge_and_unload() # Merge LoRA weights
|
|
|
+ model = model.merge_and_unload()
|
|
|
print("Model loaded and LoRA merged.")
|
|
|
|
|
|
- # Add new LoRA adapter for GRPO training
|
|
|
- from peft import get_peft_model
|
|
|
+ # Add new LoRA adapter for GRPO training (smaller rank for stability)
|
|
|
grpo_lora_config = LoraConfig(
|
|
|
- r=8,
|
|
|
- lora_alpha=16,
|
|
|
+ r=4, # Smaller rank for more stable RL
|
|
|
+ lora_alpha=8,
|
|
|
lora_dropout=0.05,
|
|
|
bias="none",
|
|
|
task_type="CAUSAL_LM",
|
|
|
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
|
|
+ target_modules=["q_proj", "v_proj"], # Fewer modules for stability
|
|
|
)
|
|
|
model = get_peft_model(model, grpo_lora_config)
|
|
|
model.print_trainable_parameters()
|
|
|
@@ -254,21 +312,29 @@ def main():
|
|
|
# Initialize reward function
|
|
|
reward_fn = QMDRewardFunction()
|
|
|
|
|
|
- # GRPO config
|
|
|
+ # Test reward function
|
|
|
+ print("\nTesting reward function...")
|
|
|
+ test_good = "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."
|
|
|
+ test_bad = "auth is important for security"
|
|
|
+ print(f" Good output score: {score_expansion('auth', test_good):.2f}")
|
|
|
+ print(f" Bad output score: {score_expansion('auth', test_bad):.2f}")
|
|
|
+
|
|
|
+ # GRPO config with conservative settings
|
|
|
config = GRPOConfig(
|
|
|
- output_dir="qmd-expansion-grpo",
|
|
|
+ output_dir="qmd-expansion-grpo-v2",
|
|
|
push_to_hub=True,
|
|
|
hub_model_id=args.output,
|
|
|
|
|
|
- # GRPO specific
|
|
|
- num_generations=4, # Generate 4 completions per prompt
|
|
|
- max_completion_length=256,
|
|
|
+ # GRPO specific - conservative
|
|
|
+ num_generations=4,
|
|
|
+ max_completion_length=200, # Shorter to avoid rambling
|
|
|
|
|
|
- # Training
|
|
|
+ # Training - very conservative
|
|
|
num_train_epochs=args.epochs,
|
|
|
per_device_train_batch_size=2,
|
|
|
- gradient_accumulation_steps=4,
|
|
|
- learning_rate=5e-6, # Lower LR for RL
|
|
|
+ gradient_accumulation_steps=8,
|
|
|
+ learning_rate=args.lr,
|
|
|
+ max_grad_norm=0.5, # Clip gradients more aggressively
|
|
|
|
|
|
# Logging
|
|
|
logging_steps=10,
|
|
|
@@ -276,8 +342,8 @@ def main():
|
|
|
|
|
|
# Monitoring
|
|
|
report_to="trackio",
|
|
|
- project="qmd-query-expansion-grpo",
|
|
|
- run_name="grpo-format-diversity",
|
|
|
+ project="qmd-query-expansion-grpo-v2",
|
|
|
+ run_name="grpo-scoring-v2",
|
|
|
)
|
|
|
|
|
|
# Create trainer
|