# /// script # requires-python = ">=3.10" # dependencies = [ # "trl>=0.12.0", # "peft>=0.7.0", # "transformers>=4.45.0", # "accelerate>=0.24.0", # "huggingface_hub>=0.20.0", # "datasets", # "bitsandbytes", # "torch", # ] # /// """ GRPO training for QMD query expansion (Qwen3-1.7B). Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs: hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py """ import os import re from collections import Counter import torch from datasets import load_dataset from huggingface_hub import login from peft import LoraConfig, PeftModel, get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer from trl import GRPOTrainer, GRPOConfig # --- Config (inlined from configs/grpo.yaml) --- BASE_MODEL = "Qwen/Qwen3-1.7B" SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft" OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo" DATASET = "tobil/qmd-query-expansion-train-v2" # ============================================================================= # Reward function (inlined from reward.py — single source of truth) # ============================================================================= STOPWORDS = frozenset({ 'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by', }) KEY_TERM_STOPWORDS = frozenset({ 'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of', 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we', 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell', }) GENERIC_LEX_PHRASES = frozenset({ 'find information about', 'search for', 'look up', 'get information', 'learn about', 'information on', 'details about', 'find out about', 'what is', 'how to', 'guide to', 'help with', }) CHAT_TEMPLATE_TOKENS = frozenset({ '<|im_start|>', '<|im_end|>', '<|endoftext|>', '\nassistant\n', '\nuser\n', }) def parse_expansion(text: str) -> dict: result = {"lex": [], "vec": [], "hyde": [], "invalid": []} for line in text.strip().split("\n"): line = line.strip() if not line: continue if line.startswith("lex:"): result["lex"].append(line[4:].strip()) elif line.startswith("vec:"): result["vec"].append(line[4:].strip()) elif line.startswith("hyde:"): result["hyde"].append(line[5:].strip()) else: result["invalid"].append(line) return result def clean_model_output(text: str) -> tuple[str, bool]: text = text.replace('<|im_end|>', '').strip() used_thinking = '' in text and '' in text if used_thinking: text = re.sub(r'.*?', '', text, flags=re.DOTALL).strip() return text, used_thinking def extract_named_entities(query: str) -> set: entities = set() words = query.split() prev_was_entity = False for i, word in enumerate(words): clean = word.strip('.,!?:;()[]"\'') if not clean: prev_was_entity = False continue is_entity = False if clean.isupper() and len(clean) >= 2: entities.add(clean.lower()) is_entity = True elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS: entities.add(clean.lower()) is_entity = True elif any(c in clean for c in '.+-#@') and len(clean) >= 2: entities.add(clean.lower()) is_entity = True elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper(): entities.add(clean.lower()) is_entity = True elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS: entities.add(clean.lower()) is_entity = True prev_was_entity = is_entity return entities def get_key_terms(query: str) -> set: return set(query.lower().split()) - KEY_TERM_STOPWORDS def lex_preserves_key_terms(lex_line: str, query: str) -> bool: key_terms = get_key_terms(query) if not key_terms: return True return bool(key_terms & set(lex_line.lower().split())) def lex_preserves_entities(line: str, entities: set) -> bool: if not entities: return True lower = line.lower() return any(e in lower for e in entities) def lex_is_generic(lex_line: str) -> bool: lower = lex_line.lower().strip() for phrase in GENERIC_LEX_PHRASES: if phrase in lower or lower.startswith(phrase.split()[0]): remaining = lower for word in phrase.split(): remaining = remaining.replace(word, '', 1).strip() if len(remaining) < 3: return True return False def word_set_distance(a: str, b: str) -> int: return len(set(a.lower().split()) ^ set(b.lower().split())) def is_diverse(a: str, b: str, min_distance: int = 2) -> bool: a, b = a.lower().strip(), b.lower().strip() if a == b or a in b or b in a: return False return word_set_distance(a, b) >= min_distance def echoes_query(expansion: str, query: str) -> bool: exp, q = expansion.lower().strip(), query.lower().strip() return exp == q or (q in exp and len(exp) < len(q) + 10) def word_repetition_penalty(text: str) -> int: counts = Counter(re.findall(r'\b\w+\b', text.lower())) return sum((c - 2) * 2 for w, c in counts.items() if c >= 3 and w not in STOPWORDS and len(w) > 2) def score_expansion(query: str, expansion: str) -> float: """Score expansion as float in [0.0, 1.0] for RL reward.""" text, used_thinking = clean_model_output(expansion.strip()) # Hard fail: chat template leakage if any(tok in text for tok in CHAT_TEMPLATE_TOKENS): return 0.0 # Hard fail: invalid lines for line in text.split("\n"): line = line.strip() if line and not line.startswith(("lex:", "vec:", "hyde:")): return 0.0 parsed = parse_expansion(text) # Format (0-30) format_score = 10 # no invalid lines if parsed["lex"]: format_score += 10 if parsed["vec"]: format_score += 10 # Diversity (0-30) diversity_score = 0 types_present = sum(1 for t in ("lex", "vec") if parsed[t]) if types_present >= 2: diversity_score += 10 if len(parsed["lex"]) + len(parsed["vec"]) >= 2: diversity_score += 5 lex_div = 5 for i, a in enumerate(parsed["lex"]): for b in parsed["lex"][i+1:]: if not is_diverse(a, b, 2): lex_div -= 2 diversity_score += max(0, lex_div) vec_div = 5 for i, a in enumerate(parsed["vec"]): for b in parsed["vec"][i+1:]: if not is_diverse(a, b, 3): vec_div -= 2 diversity_score += max(0, vec_div) echo = 5 for exp in parsed["lex"] + parsed["vec"]: if echoes_query(exp, query): echo -= 3 diversity_score += max(0, echo) # HyDE (0-20) hyde_score = 0 if parsed["hyde"]: hyde_text = parsed["hyde"][0] hyde_score += 5 hyde_len = len(hyde_text) if 50 <= hyde_len <= 200: hyde_score += 5 elif hyde_len < 50: hyde_score += 2 if "\n" not in hyde_text: hyde_score += 5 hyde_score += max(0, 5 - word_repetition_penalty(hyde_text)) # Quality (0-20) quality_score = 5 if parsed["lex"] and parsed["vec"]: avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"]) avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"]) if avg_lex <= avg_vec: quality_score += 5 if parsed["vec"]: natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15) quality_score += 5 if natural == len(parsed["vec"]) else 2 if parsed["lex"]: with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query)) if with_terms == len(parsed["lex"]): quality_score += 5 elif with_terms > 0: quality_score += 2 # Entity (-45 to +20) entity_score = 0 entities = extract_named_entities(query) if entities and parsed["lex"]: with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities)) if with_entities == len(parsed["lex"]): entity_score += 15 elif with_entities > 0: entity_score += 5 else: entity_score -= 30 generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l)) if generic_count: entity_score -= generic_count * 15 if parsed["vec"]: vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities)) if vec_with > 0: entity_score += 5 elif not entities: entity_score = 10 # Think bonus (0-20) think_bonus = 0 if used_thinking else 20 total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus max_possible = 140 if parsed["hyde"] else 120 return max(0.0, min(1.0, total / max_possible)) def extract_query_from_prompt(prompt: str) -> str: if "Expand this search query:" in prompt: query = prompt.split("Expand this search query:")[-1].strip() if "<|im_end|>" in query: query = query.split("<|im_end|>")[0].strip() return query return prompt.strip() class QMDRewardFunction: __name__ = "qmd_scoring_reward" def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]: rewards = [] for i, completion in enumerate(completions): query = "" if prompts and i < len(prompts): query = extract_query_from_prompt(prompts[i]) rewards.append(score_expansion(query, completion)) return rewards # ============================================================================= # Main training # ============================================================================= def main(): hf_token = os.environ.get("HF_TOKEN") if hf_token: login(token=hf_token) print(f"Loading tokenizer from {BASE_MODEL}...") tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Load and format dataset print(f"Loading dataset: {DATASET}...") dataset = load_dataset(DATASET, split="train") def extract_prompt(example): content = example["messages"][0]["content"] messages = [{"role": "user", "content": content}] formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) return {"prompt": formatted} dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names) dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset)))) print(f"Using {len(dataset)} prompts for GRPO") # Load base model, merge SFT adapter print(f"Loading base model {BASE_MODEL}...") base_model = AutoModelForCausalLM.from_pretrained( BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto", ) print(f"Merging SFT adapter {SFT_MODEL}...") model = PeftModel.from_pretrained(base_model, SFT_MODEL) model = model.merge_and_unload() print("SFT adapter merged.") # Fresh LoRA for GRPO (small: rank 4, q/v only) grpo_lora = LoraConfig( r=4, lora_alpha=8, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], ) model = get_peft_model(model, grpo_lora) model.print_trainable_parameters() config = GRPOConfig( output_dir="qmd-query-expansion-1.7B-grpo", push_to_hub=True, hub_model_id=OUTPUT_MODEL, num_generations=4, max_completion_length=200, beta=0.04, # KL regularization — prevents drift from SFT checkpoint num_train_epochs=1, per_device_train_batch_size=2, gradient_accumulation_steps=8, learning_rate=5e-7, max_grad_norm=0.5, max_steps=200, logging_steps=10, save_strategy="epoch", bf16=True, report_to="none", ) print("Initializing GRPO trainer...") trainer = GRPOTrainer( model=model, processing_class=tokenizer, args=config, train_dataset=dataset, reward_funcs=[QMDRewardFunction()], ) print("Starting GRPO training...") trainer.train() print("Pushing to Hub...") trainer.push_to_hub() print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}") if __name__ == "__main__": main()