| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389 |
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- # "trl>=0.12.0",
- # "peft>=0.7.0",
- # "transformers>=4.45.0",
- # "accelerate>=0.24.0",
- # "huggingface_hub>=0.20.0",
- # "datasets",
- # "bitsandbytes",
- # "torch",
- # ]
- # ///
- """
- GRPO training for QMD query expansion (Qwen3-1.7B).
- Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs:
- hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py
- """
- import os
- import re
- from collections import Counter
- import torch
- from datasets import load_dataset
- from huggingface_hub import login
- from peft import LoraConfig, PeftModel, get_peft_model
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from trl import GRPOTrainer, GRPOConfig
# --- Config (inlined from configs/grpo.yaml) ---
# Hugging Face Hub identifiers consumed by main().
BASE_MODEL = "Qwen/Qwen3-1.7B"
SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
DATASET = "tobil/qmd-query-expansion-train-v2"

# =============================================================================
# Reward function (inlined from reward.py — single source of truth)
# =============================================================================

# Words exempt from the repetition penalty in word_repetition_penalty().
STOPWORDS = frozenset(
    "the a an is are to for of in "
    "and or it this that be with as on by".split()
)

# Words removed from a query before looking for key terms / named entities.
KEY_TERM_STOPWORDS = frozenset(
    "what is how to the a an in on for of "
    "and or with my your do does can i me we "
    "who where when why which find get show tell".split()
)

# Boilerplate phrases checked by lex_is_generic().
GENERIC_LEX_PHRASES = frozenset((
    'find information about',
    'search for',
    'look up',
    'get information',
    'learn about',
    'information on',
    'details about',
    'find out about',
    'what is',
    'how to',
    'guide to',
    'help with',
))

# Substrings whose presence in a completion is treated as chat-template leakage.
CHAT_TEMPLATE_TOKENS = frozenset((
    '<|im_start|>',
    '<|im_end|>',
    '<|endoftext|>',
    '\nassistant\n',
    '\nuser\n',
))
def parse_expansion(text: str) -> dict:
    """Bucket the lines of a model completion by their type prefix.

    Each non-blank line is stripped and classified by its leading ``lex:``,
    ``vec:`` or ``hyde:`` prefix; the prefix is removed and the remainder is
    stripped again. Lines with none of these prefixes are collected verbatim
    (after stripping) under ``"invalid"``.

    Returns:
        A dict with the keys ``"lex"``, ``"vec"``, ``"hyde"`` and
        ``"invalid"``, each mapping to a list of strings in input order.
    """
    prefixes = ("lex:", "vec:", "hyde:")
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    for raw in text.strip().split("\n"):
        entry = raw.strip()
        if not entry:
            continue
        for prefix in prefixes:
            if entry.startswith(prefix):
                # The bucket key is the prefix without its trailing colon.
                buckets[prefix[:-1]].append(entry[len(prefix):].strip())
                break
        else:
            buckets["invalid"].append(entry)
    return buckets
def clean_model_output(text: str) -> tuple[str, bool]:
    """Strip end-of-turn markers and think blocks from a raw completion.

    Returns:
        ``(cleaned_text, used_thinking)``. ``used_thinking`` is True only when
        both an opening ``<think>`` and a closing ``</think>`` tag occur in
        the text; in that case every ``<think>...</think>`` span is removed.
        An unclosed ``<think>`` is left in place.
    """
    cleaned = text.replace('<|im_end|>', '').strip()
    has_think_block = '<think>' in cleaned and '</think>' in cleaned
    if not has_think_block:
        return cleaned, False
    # DOTALL lets a think block span several lines.
    without_think = re.sub(r'<think>.*?</think>', '', cleaned, flags=re.DOTALL)
    return without_think.strip(), True
def extract_named_entities(query: str) -> set:
    """Heuristically collect entity-like words from a query, lower-cased.

    The query is split on whitespace and surrounding punctuation is stripped
    from each word. A word counts as an entity when any of these holds:

    * it is all upper-case and at least two characters long;
    * it is not the first word, starts with an upper-case letter and is not
      in KEY_TERM_STOPWORDS;
    * it is at least two characters long and contains one of ``.+-#@``;
    * it starts with an upper-case letter and has another upper-case letter
      later in the word;
    * the preceding word was an entity and this word is not in
      KEY_TERM_STOPWORDS.

    A word that is empty after stripping breaks the "preceding word" chain.
    """
    found = set()
    previous_hit = False
    for position, token in enumerate(query.split()):
        core = token.strip('.,!?:;()[]"\'')
        if not core:
            previous_hit = False
            continue

        lowered = core.lower()
        not_stopword = lowered not in KEY_TERM_STOPWORDS
        starts_upper = core[0].isupper()

        hit = (
            (core.isupper() and len(core) >= 2)
            or (position > 0 and starts_upper and not_stopword)
            or (len(core) >= 2 and any(ch in core for ch in '.+-#@'))
            or (len(core) > 1 and starts_upper
                and any(ch.isupper() for ch in core[1:]))
            or (previous_hit and not_stopword)
        )
        if hit:
            found.add(lowered)
        previous_hit = hit
    return found
def get_key_terms(query: str) -> set:
    """Return the lower-cased, whitespace-split words of *query* that are not
    in KEY_TERM_STOPWORDS. Punctuation attached to a word is kept as-is."""
    return {
        word
        for word in query.lower().split()
        if word not in KEY_TERM_STOPWORDS
    }
def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
    """Tell whether *lex_line* shares at least one key term with *query*.

    Key terms come from get_key_terms(); a query with no key terms is
    trivially preserved. Matching is on whole whitespace-separated words,
    case-insensitively.
    """
    wanted = get_key_terms(query)
    if not wanted:
        return True
    return not wanted.isdisjoint(lex_line.lower().split())
def lex_preserves_entities(line: str, entities: set) -> bool:
    """Tell whether *line* mentions at least one of *entities*.

    Matching is a substring test against the lower-cased line, so the
    entities are expected to be lower-case already. An empty entity set is
    trivially preserved.
    """
    if not entities:
        return True
    haystack = line.lower()
    for entity in entities:
        if entity in haystack:
            return True
    return False
def lex_is_generic(lex_line: str) -> bool:
    """Tell whether a lex line is little more than a boilerplate phrase.

    For every phrase in GENERIC_LEX_PHRASES that occurs in the line, or whose
    first word the line starts with, the first occurrence of each phrase word
    is deleted from the line (substring replacement). If fewer than three
    characters are left, the line is considered generic.
    """
    normalized = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        phrase_words = phrase.split()
        applies = phrase in normalized or normalized.startswith(phrase_words[0])
        if not applies:
            continue
        leftover = normalized
        for piece in phrase_words:
            leftover = leftover.replace(piece, '', 1).strip()
        if len(leftover) < 3:
            return True
    return False
def word_set_distance(a: str, b: str) -> int:
    """Count the words that appear in exactly one of the two strings.

    Both strings are lower-cased and split on whitespace; duplicates within a
    string are ignored.
    """
    words_a = set(a.lower().split())
    words_b = set(b.lower().split())
    return len(words_a.symmetric_difference(words_b))
def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
    """Tell whether two expansions differ enough to count as distinct.

    After lower-casing and stripping, the pair is rejected when one string
    contains the other (which includes the two being equal). Otherwise it is
    accepted when word_set_distance() is at least *min_distance*.
    """
    left = a.lower().strip()
    right = b.lower().strip()
    # Substring containment also covers the equal-strings case.
    if left in right or right in left:
        return False
    return word_set_distance(left, right) >= min_distance
def echoes_query(expansion: str, query: str) -> bool:
    """Tell whether an expansion merely repeats the query.

    Comparison is case-insensitive on stripped strings. An expansion echoes
    the query when it equals it, or when it contains the query and is fewer
    than 10 characters longer than it.
    """
    candidate = expansion.lower().strip()
    target = query.lower().strip()
    if candidate == target:
        return True
    return target in candidate and len(candidate) < len(target) + 10
def word_repetition_penalty(text: str) -> int:
    """Return a penalty for words that are repeated three or more times.

    Words are matched case-insensitively with ``\\b\\w+\\b``. Words in
    STOPWORDS and words of one or two characters are ignored. Every other
    word occurring ``c >= 3`` times contributes ``(c - 2) * 2``.
    """
    tally = Counter(re.findall(r'\b\w+\b', text.lower()))
    penalty = 0
    for word, count in tally.items():
        if count < 3 or len(word) <= 2 or word in STOPWORDS:
            continue
        penalty += (count - 2) * 2
    return penalty
def score_expansion(query: str, expansion: str) -> float:
    """Score a query expansion as a float in [0.0, 1.0] for use as an RL reward.

    The completion is cleaned with clean_model_output() and parsed with
    parse_expansion(). It scores 0.0 outright when it

    * contains any CHAT_TEMPLATE_TOKENS substring,
    * contains a non-blank line without a ``lex:``/``vec:``/``hyde:`` prefix, or
    * contains no ``lex:``, ``vec:`` or ``hyde:`` line at all (an empty
      completion must not collect the baseline points of each category).

    Otherwise points are summed over six categories (format, diversity, HyDE,
    quality, entity preservation, and a bonus for not emitting a think block),
    divided by the maximum attainable total (140 with a hyde line, 120
    without) and clamped to [0.0, 1.0].

    Args:
        query: The user's original search query.
        expansion: The raw model completion.

    Returns:
        The normalized reward.
    """
    text, used_thinking = clean_model_output(expansion.strip())

    # Hard fail: chat template leakage
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return 0.0

    parsed = parse_expansion(text)
    lex_lines = parsed["lex"]
    vec_lines = parsed["vec"]
    hyde_lines = parsed["hyde"]

    # Hard fail: invalid lines (parse_expansion collects every non-blank line
    # that lacks a lex:/vec:/hyde: prefix under "invalid").
    if parsed["invalid"]:
        return 0.0

    # Hard fail: nothing was produced. Without this, an empty completion would
    # still earn the baseline points below (e.g. 60/120 for an entity-free
    # query), rewarding the policy for emitting nothing.
    if not (lex_lines or vec_lines or hyde_lines):
        return 0.0

    # Format (10-30): 10 for having no invalid lines, +10 each for lex and vec.
    format_score = 10
    if lex_lines:
        format_score += 10
    if vec_lines:
        format_score += 10

    # Diversity (0-30)
    diversity_score = 0
    types_present = sum(1 for kind in ("lex", "vec") if parsed[kind])
    if types_present >= 2:
        diversity_score += 10
    if len(lex_lines) + len(vec_lines) >= 2:
        diversity_score += 5

    # Each near-duplicate pair costs 2 points, floored at 0 per type.
    lex_div = 5
    for i, first in enumerate(lex_lines):
        for second in lex_lines[i + 1:]:
            if not is_diverse(first, second, 2):
                lex_div -= 2
    diversity_score += max(0, lex_div)

    vec_div = 5
    for i, first in enumerate(vec_lines):
        for second in vec_lines[i + 1:]:
            if not is_diverse(first, second, 3):
                vec_div -= 2
    diversity_score += max(0, vec_div)

    # Each expansion that just echoes the query costs 3 points, floored at 0.
    echo = 5
    for candidate in lex_lines + vec_lines:
        if echoes_query(candidate, query):
            echo -= 3
    diversity_score += max(0, echo)

    # HyDE (0-20): only the first hyde line is scored.
    hyde_score = 0
    if hyde_lines:
        hyde_text = hyde_lines[0]
        hyde_score += 5
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            hyde_score += 5
        elif hyde_len < 50:
            hyde_score += 2
        # parse_expansion splits on newlines, so this holds for parsed lines;
        # kept so the total matches the documented 20-point maximum.
        if "\n" not in hyde_text:
            hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # Quality (5-20)
    quality_score = 5
    if lex_lines and vec_lines:
        avg_lex = sum(len(line) for line in lex_lines) / len(lex_lines)
        avg_vec = sum(len(line) for line in vec_lines) / len(vec_lines)
        # Keyword (lex) lines should not be longer than semantic (vec) lines.
        if avg_lex <= avg_vec:
            quality_score += 5
    if vec_lines:
        natural = sum(1 for line in vec_lines if " " in line and len(line) > 15)
        quality_score += 5 if natural == len(vec_lines) else 2
    if lex_lines:
        with_terms = sum(
            1 for line in lex_lines if lex_preserves_key_terms(line, query)
        )
        if with_terms == len(lex_lines):
            quality_score += 5
        elif with_terms > 0:
            quality_score += 2

    # Entity (at most +20; negative when lex lines drop every query entity
    # and/or are generic — 15 points are deducted per generic lex line).
    entity_score = 0
    entities = extract_named_entities(query)
    if entities and lex_lines:
        with_entities = sum(
            1 for line in lex_lines if lex_preserves_entities(line, entities)
        )
        if with_entities == len(lex_lines):
            entity_score += 15
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score -= 30
        generic_count = sum(1 for line in lex_lines if lex_is_generic(line))
        if generic_count:
            entity_score -= generic_count * 15
        if vec_lines:
            vec_with = sum(
                1 for line in vec_lines if lex_preserves_entities(line, entities)
            )
            if vec_with > 0:
                entity_score += 5
    elif not entities:
        # Queries without entities get a flat, neutral credit.
        entity_score = 10

    # Think bonus (0 or 20): awarded when the model did NOT emit a think block.
    think_bonus = 0 if used_thinking else 20

    total = (
        format_score
        + diversity_score
        + hyde_score
        + quality_score
        + entity_score
        + think_bonus
    )
    # The 20 HyDE points are only attainable when a hyde line is present.
    max_possible = 140 if hyde_lines else 120
    return max(0.0, min(1.0, total / max_possible))
def extract_query_from_prompt(prompt: str) -> str:
    """Recover the raw search query from a formatted prompt.

    When the prompt contains ``Expand this search query:``, the text after the
    last occurrence of that marker is taken, cut at the first ``<|im_end|>``
    if one follows, and stripped. Otherwise the stripped prompt is returned
    unchanged.
    """
    marker = "Expand this search query:"
    if marker not in prompt:
        return prompt.strip()
    tail = prompt.rpartition(marker)[2].strip()
    return tail.partition("<|im_end|>")[0].strip()
class QMDRewardFunction:
    """Callable that scores a batch of completions with score_expansion().

    Instances are passed to GRPOTrainer via ``reward_funcs`` in main().
    """

    # Name exposed on instances.
    # NOTE(review): presumably read by the trainer to label this reward in
    # its logs — confirm against the TRL version in use.
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
        """Return one reward per completion.

        Args:
            completions: Generated texts to score.
            prompts: Prompts aligned by index with *completions*. When
                missing, or shorter than *completions*, the unmatched
                completions are scored against an empty query.
            **kwargs: Accepted and ignored.
        """
        def query_for(index: int) -> str:
            if prompts and index < len(prompts):
                return extract_query_from_prompt(prompts[index])
            return ""

        return [
            score_expansion(query_for(index), completion)
            for index, completion in enumerate(completions)
        ]
# =============================================================================
# Main training
# =============================================================================
def main():
    """Run GRPO training on top of the merged SFT model and push the result.

    Steps: optional Hub login, tokenizer and prompt-dataset preparation,
    loading BASE_MODEL and merging the SFT_MODEL adapter into it, wrapping the
    merged model in a fresh LoRA adapter, training with GRPOTrainer using
    QMDRewardFunction as the only reward, and pushing to OUTPUT_MODEL.
    """
    # Log in only when a token is supplied through the environment.
    if token := os.environ.get("HF_TOKEN"):
        login(token=token)

    print(f"Loading tokenizer from {BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    if tokenizer.pad_token is None:
        # Fall back to EOS as the padding token when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    # Load and format dataset
    print(f"Loading dataset: {DATASET}...")
    dataset = load_dataset(DATASET, split="train")

    def to_prompt(example):
        # Uses the first message of each conversation as the user turn.
        # NOTE(review): assumes messages[0] is the user message — confirm
        # against the dataset schema.
        user_text = example["messages"][0]["content"]
        chat = [{"role": "user", "content": user_text}]
        rendered = tokenizer.apply_chat_template(
            chat, tokenize=False, add_generation_prompt=True,
        )
        return {"prompt": rendered}

    dataset = dataset.map(to_prompt, remove_columns=dataset.column_names)
    # Fixed-seed shuffle, then cap the training set at 1000 prompts.
    dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
    print(f"Using {len(dataset)} prompts for GRPO")

    # Load base model, merge SFT adapter
    print(f"Loading base model {BASE_MODEL}...")
    base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print(f"Merging SFT adapter {SFT_MODEL}...")
    policy = PeftModel.from_pretrained(base, SFT_MODEL)
    policy = policy.merge_and_unload()
    print("SFT adapter merged.")

    # Fresh LoRA for GRPO (small: rank 4, q/v only)
    lora_settings = LoraConfig(
        r=4,
        lora_alpha=8,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "v_proj"],
    )
    policy = get_peft_model(policy, lora_settings)
    policy.print_trainable_parameters()

    # NOTE(review): max_steps and num_train_epochs are both set, and
    # save_strategy is "epoch"; whether an epoch boundary (and thus an
    # intermediate checkpoint) is reached before max_steps depends on the
    # batch configuration — confirm this is intended.
    training_args = GRPOConfig(
        output_dir="qmd-query-expansion-1.7B-grpo",
        push_to_hub=True,
        hub_model_id=OUTPUT_MODEL,
        num_generations=4,
        max_completion_length=200,
        beta=0.04,  # KL regularization — prevents drift from SFT checkpoint
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-7,
        max_grad_norm=0.5,
        max_steps=200,
        logging_steps=10,
        save_strategy="epoch",
        bf16=True,
        report_to="none",
    )

    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=policy,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=dataset,
        reward_funcs=[QMDRewardFunction()],
    )

    print("Starting GRPO training...")
    trainer.train()

    print("Pushing to Hub...")
    trainer.push_to_hub()
    print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")


if __name__ == "__main__":
    main()
|