| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 |
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- # "trl>=0.12.0",
- # "peft>=0.7.0",
- # "transformers>=4.45.0",
- # "accelerate>=0.24.0",
- # "huggingface_hub>=0.20.0",
- # "trackio",
- # "datasets",
- # "bitsandbytes",
- # ]
- # ///
- """
- GRPO training for Qwen3-1.7B query expansion model.
- Trains on top of merged SFT weights with reward function.
- """
- import os
- import re
- from collections import Counter
- import torch
- import trackio
- from datasets import load_dataset
- from huggingface_hub import login
- from peft import LoraConfig, PeftModel, get_peft_model
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from trl import GRPOTrainer, GRPOConfig
# ==================== REWARD FUNCTION ====================

# Function words that word_repetition_penalty ignores when counting repeated
# words in hyde text.
STOPWORDS = {
    "a", "an", "and", "are", "as", "be", "by", "for", "in", "is",
    "it", "of", "on", "or", "that", "the", "this", "to", "with",
}

# Question/filler words excluded from a query's "key terms" (get_key_terms)
# and never treated as named entities (extract_named_entities).
KEY_TERM_STOPWORDS = {
    "a", "an", "and", "can", "do", "does", "find", "for", "get", "how",
    "i", "in", "is", "me", "my", "of", "on", "or", "show", "tell",
    "the", "to", "we", "what", "when", "where", "which", "who", "why",
    "with", "your",
}

# Boilerplate phrases that make a lex line "generic" when little else
# remains after removing them (see lex_is_generic).
GENERIC_LEX_PHRASES = {
    "details about",
    "find information about",
    "find out about",
    "get information",
    "guide to",
    "help with",
    "how to",
    "information on",
    "learn about",
    "look up",
    "search for",
    "what is",
}
def extract_named_entities(query: str) -> set:
    """Collect likely named-entity tokens from ``query``, lowercased.

    Each whitespace-separated token is trimmed of surrounding punctuation and
    counted as an entity if any of these holds (checked in this order):

    - all cased characters are uppercase and it has 2+ chars (``API``, ``C++``);
    - it is not the first word, starts with a capital, and is not in
      ``KEY_TERM_STOPWORDS``;
    - it has 2+ chars and contains one of ``. + - # @`` (``node.js``);
    - it starts with a capital and has another capital later (``JavaScript``),
      even as the first word;
    - the previous token was an entity and this one is not in
      ``KEY_TERM_STOPWORDS`` (extends multi-word names).

    A token that is pure punctuation breaks a run of entity tokens.
    """
    found = set()
    previous_hit = False
    for position, raw in enumerate(query.split()):
        token = raw.strip('.,!?:;()[]"\'')
        if not token:
            previous_hit = False
            continue
        # Same tests and short-circuit order as an if/elif chain.
        hit = (
            (token.isupper() and len(token) >= 2)
            or (position > 0 and token[0].isupper() and token.lower() not in KEY_TERM_STOPWORDS)
            or (any(ch in token for ch in '.+-#@') and len(token) >= 2)
            or (len(token) > 1 and any(ch.isupper() for ch in token[1:]) and token[0].isupper())
            or (previous_hit and token.lower() not in KEY_TERM_STOPWORDS)
        )
        if hit:
            found.add(token.lower())
        previous_hit = hit
    return found
- def get_key_terms(query: str) -> set:
- words = set(query.lower().split())
- return words - KEY_TERM_STOPWORDS
def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
    """Return True if ``lex_line`` keeps at least one content word of ``query``.

    Query key terms come from ``get_key_terms``. Lex words are lowercased,
    split on whitespace and trimmed of surrounding punctuation. A keyword list
    such as ``"docker, compose"`` therefore still matches the key term
    ``"docker"``. A query with no key terms counts as trivially preserved.
    """
    key_terms = get_key_terms(query)
    if not key_terms:
        return True
    lex_words = {word.strip('.,!?:;()[]"\'') for word in lex_line.lower().split()}
    lex_words.discard("")  # tokens that were pure punctuation
    return not key_terms.isdisjoint(lex_words)
def lex_preserves_entities(lex_line: str, entities: set) -> bool:
    """Return True if any entity occurs in ``lex_line`` (case-insensitive).

    Matching is by substring, so multi-character entities such as ``c++`` or
    ``node.js`` are found. Short entities can also match inside longer words.
    An empty ``entities`` collection is trivially preserved.
    """
    if entities:
        haystack = lex_line.lower()
        for entity in entities:
            if entity in haystack:
                return True
        return False
    return True
def lex_is_generic(lex_line: str) -> bool:
    """Return True if ``lex_line`` is mostly a boilerplate phrase.

    For each phrase in ``GENERIC_LEX_PHRASES`` that appears in the lowercased
    line, or whose first word the line starts with, the first occurrence of
    each phrase word is removed. Removal is by substring, not whole words. If
    fewer than 3 characters remain, the line is generic. For example,
    ``"lookup"`` is generic because removing "look" and then "up" leaves
    nothing.
    """
    normalized = lex_line.lower().strip()

    def leftover_after(phrase: str) -> str:
        # Strip each phrase word once, re-trimming whitespace after each pass.
        rest = normalized
        for piece in phrase.split():
            rest = rest.replace(piece, '', 1).strip()
        return rest

    return any(
        (phrase in normalized or normalized.startswith(phrase.split()[0]))
        and len(leftover_after(phrase)) < 3
        for phrase in GENERIC_LEX_PHRASES
    )
def parse_expansion(text: str) -> dict:
    """Bucket each non-blank line of ``text`` by its ``lex:``/``vec:``/``hyde:`` prefix.

    Returns:
        A dict with keys ``"lex"``, ``"vec"`` and ``"hyde"``, holding the line
        bodies after the prefix (stripped), and ``"invalid"``, holding whole
        stripped lines with none of those prefixes. Input order is preserved
        within each list.
    """
    buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    prefixes = (("lex:", "lex"), ("vec:", "vec"), ("hyde:", "hyde"))
    for raw_line in text.strip().split("\n"):
        stripped = raw_line.strip()
        if not stripped:
            continue
        for prefix, key in prefixes:
            if stripped.startswith(prefix):
                buckets[key].append(stripped[len(prefix):].strip())
                break
        else:
            buckets["invalid"].append(stripped)
    return buckets
def edit_distance_simple(a: str, b: str) -> int:
    """Count the distinct words that appear in exactly one of ``a`` and ``b``.

    Despite the name, this is not a character edit distance. Both strings are
    lowercased and split on whitespace into word sets, and the function
    returns the size of their symmetric difference. Word order and
    repetition are ignored.
    """
    left = set(a.lower().split())
    right = set(b.lower().split())
    return len(left - right) + len(right - left)
def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
    """Return True if ``a`` and ``b`` differ enough to count as distinct expansions.

    Comparison is case-insensitive and ignores surrounding whitespace.
    Strings that are equal, or where one contains the other, are never
    diverse. Otherwise they must differ by at least ``min_distance`` words,
    as measured by ``edit_distance_simple``.
    """
    left, right = a.lower().strip(), b.lower().strip()
    overlapping = left == right or left in right or right in left
    return not overlapping and edit_distance_simple(left, right) >= min_distance
def echoes_query(expansion: str, query: str) -> bool:
    """Return True when ``expansion`` merely restates ``query``.

    Case and surrounding whitespace are ignored. An expansion echoes the query
    when it is identical to it, or when it contains the query verbatim and is
    fewer than 10 characters longer.

    An empty query never counts as echoed. ``"" in s`` is True for every
    string, so without this guard every expansion under 10 characters would
    be flagged whenever the query is unknown (``QMDRewardFunction`` passes
    ``""`` when no prompt is available).
    """
    exp = expansion.lower().strip()
    q = query.lower().strip()
    if not q:
        return False
    return exp == q or (q in exp and len(exp) < len(q) + 10)
def word_repetition_penalty(text: str) -> int:
    """Return a penalty for words used three or more times in ``text``.

    Words are ``\\w+`` runs in the lowercased text. A word used three or more
    times costs 2 points for each use beyond the second. Words in
    ``STOPWORDS`` and words of two characters or fewer are exempt.
    """
    tally = Counter(re.findall(r'\b\w+\b', text.lower()))
    return sum(
        2 * (n - 2)
        for token, n in tally.items()
        if n >= 3 and token not in STOPWORDS and len(token) > 2
    )
def score_expansion(query: str, expansion: str) -> float:
    """Score a generated query expansion for use as an RL reward.

    The completion is cleaned first: ``<|im_end|>`` tokens are removed, and a
    complete ``<think>...</think>`` block is stripped (which forfeits the
    20-point no-thinking bonus).

    The function hard-fails (returns 0.0) if the cleaned text:
    - contains chat-template artifacts;
    - has any non-blank line without a ``lex:``/``vec:``/``hyde:`` prefix; or
    - has no such line at all.

    Otherwise the cleaned text is parsed and points are summed:
    - format: 0-30;
    - diversity: 0-30;
    - hyde: 0-20;
    - quality: 0-20;
    - named-entity preservation: at most 20, and can be negative;
    - no-thinking bonus: 0 or 20.

    The total is divided by 140 (120 when there is no hyde line) and clamped
    to [0.0, 1.0].

    Args:
        query: The original search query. May be empty when unknown.
        expansion: Raw model completion text.

    Returns:
        A reward in [0.0, 1.0].
    """
    text = expansion.strip()
    # Strip end token if present.
    text = text.replace('<|im_end|>', '').strip()

    # Bonus for not using thinking mode. A complete think block is removed so
    # the remaining lines can still be checked, but the bonus is lost.
    skipped_think = 20
    if '<think>' in text and '</think>' in text:
        text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
        skipped_think = 0

    # HARD FAIL: chat template artifacts.
    if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
        return 0.0

    # HARD FAIL: every non-blank line must start with lex:, vec:, or hyde:,
    # and at least one such line must exist. Without the second check, an
    # empty completion passes vacuously and still collects about 0.4-0.5 from
    # the base points below.
    has_expansion_line = False
    for line in text.split("\n"):
        line = line.strip()
        if not line:
            continue
        if not line.startswith(("lex:", "vec:", "hyde:")):
            return 0.0
        has_expansion_line = True
    if not has_expansion_line:
        return 0.0

    # Parse the cleaned text, not the raw completion, so stripped <think>
    # content and <|im_end|> tokens cannot leak into the scored lines.
    parsed = parse_expansion(text)
    lex_lines = parsed["lex"]
    vec_lines = parsed["vec"]
    hyde_lines = parsed["hyde"]

    # FORMAT (0-30): 10 for passing the prefix check, plus 10 each for
    # having lex and vec lines.
    format_score = 10
    if lex_lines:
        format_score += 10
    if vec_lines:
        format_score += 10

    # DIVERSITY (0-30)
    def _count_near_duplicates(lines: list, min_distance: int) -> int:
        # Number of line pairs (i < j) that is_diverse rejects.
        return sum(
            1
            for i, first in enumerate(lines)
            for second in lines[i + 1:]
            if not is_diverse(first, second, min_distance)
        )

    diversity_score = 0
    if lex_lines and vec_lines:
        diversity_score += 10
    if len(lex_lines) + len(vec_lines) >= 2:
        diversity_score += 5
    # Each near-duplicate pair within a type costs 2 of that type's 5 points.
    diversity_score += max(0, 5 - 2 * _count_near_duplicates(lex_lines, 2))
    diversity_score += max(0, 5 - 2 * _count_near_duplicates(vec_lines, 3))
    # Each lex/vec line that just restates the query costs 3 of 5 points.
    echo_hits = sum(1 for exp in lex_lines + vec_lines if echoes_query(exp, query))
    diversity_score += max(0, 5 - 3 * echo_hits)

    # HYDE (0-20): only the first hyde line is scored.
    hyde_score = 0
    if hyde_lines:
        hyde_text = hyde_lines[0]
        hyde_score += 5
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            hyde_score += 5
        elif hyde_len < 50:
            hyde_score += 2
        # parse_expansion splits on newlines, so this always holds. The points
        # are kept so that the 140-point normalizer stays unchanged.
        if "\n" not in hyde_text:
            hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # QUALITY (0-20)
    quality_score = 5
    if lex_lines and vec_lines:
        avg_lex = sum(len(line) for line in lex_lines) / len(lex_lines)
        avg_vec = sum(len(line) for line in vec_lines) / len(vec_lines)
        # Keyword lines should be no longer than natural-language lines.
        if avg_lex <= avg_vec:
            quality_score += 5
    if vec_lines:
        natural = sum(1 for line in vec_lines if " " in line and len(line) > 15)
        quality_score += 5 if natural == len(vec_lines) else 2
    if lex_lines:
        lex_with_terms = sum(1 for line in lex_lines if lex_preserves_key_terms(line, query))
        if lex_with_terms == len(lex_lines):
            quality_score += 5
        elif lex_with_terms > 0:
            quality_score += 2

    # NAMED ENTITY PRESERVATION: a flat 10 when the query has no entities.
    # With entities and lex lines, the score ranges up to 20 and can go
    # negative. With entities but no lex lines, it is 0.
    entity_score = 0
    entities = extract_named_entities(query)
    if not entities:
        entity_score = 10
    elif lex_lines:
        lex_with_entities = sum(1 for line in lex_lines if lex_preserves_entities(line, entities))
        if lex_with_entities == len(lex_lines):
            entity_score += 15
        elif lex_with_entities > 0:
            entity_score += 5
        else:
            entity_score -= 30
        entity_score -= 15 * sum(1 for line in lex_lines if lex_is_generic(line))
        if vec_lines and any(lex_preserves_entities(line, entities) for line in vec_lines):
            entity_score += 5

    total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
    # Maximum attainable: 30 + 30 + 20 + 20 + 20 + 20 = 140 with hyde, 120 without.
    max_possible = 140 if hyde_lines else 120
    return max(0.0, min(1.0, total / max_possible))
def extract_query_from_prompt(prompt: str) -> str:
    """Return the search query embedded in a training prompt.

    If the prompt contains ``"Expand this search query:"``, the function
    returns the text after the marker's last occurrence. Otherwise it returns
    the whole prompt. Either way the result is stripped.
    """
    # rpartition returns ('', '', prompt) when the marker is absent, so the
    # tail is the whole prompt in that case.
    return prompt.rpartition("Expand this search query:")[2].strip()
class QMDRewardFunction:
    """Callable reward passed to ``GRPOTrainer`` via ``reward_funcs``.

    Scores each completion with ``score_expansion`` against the query
    extracted from the matching prompt. When ``prompts`` is missing or
    shorter than ``completions``, the query is the empty string.
    """

    # Instances have no __name__ of their own. This supplies one, presumably
    # for the trainer to label the reward in logs (NOTE(review): confirm).
    __name__ = "qmd_scoring_reward"

    def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
        available = prompts or []
        return [
            score_expansion(
                extract_query_from_prompt(available[idx]) if idx < len(available) else "",
                completion,
            )
            for idx, completion in enumerate(completions)
        ]
# ==================== MAIN ====================
def _load_grpo_prompts(dataset_id):
    """Load the train split and keep up to 1000 shuffled first-message prompts."""
    print("Loading dataset...")
    raw = load_dataset(dataset_id, split="train")

    def to_prompt(example):
        # Keep only the first message's content, as a plain-text prompt.
        # NOTE(review): a plain string (not a message list) presumably means
        # the trainer applies no chat template. Confirm this matches the
        # format the SFT model was trained on.
        return {"prompt": example["messages"][0]["content"]}

    prompts = raw.map(to_prompt, remove_columns=raw.column_names)
    prompts = prompts.shuffle(seed=42).select(range(min(1000, len(prompts))))
    print(f"Using {len(prompts)} prompts for GRPO")
    return prompts


def _load_policy_model(base_id, sft_adapter_id):
    """Merge the SFT LoRA adapter into the base model, then wrap it in a new GRPO LoRA."""
    print(f"Loading SFT model from {sft_adapter_id}...")
    base = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    merged = PeftModel.from_pretrained(base, sft_adapter_id).merge_and_unload()
    print("Model loaded and LoRA merged.")

    # Small rank-4 adapter on attention q/v projections for the RL stage.
    policy = get_peft_model(
        merged,
        LoraConfig(
            r=4,
            lora_alpha=8,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "v_proj"],
        ),
    )
    policy.print_trainable_parameters()
    return policy


def main():
    """Run GRPO on top of the merged SFT model and push the result to the Hub.

    Steps:
    1. Log in when HF_TOKEN is set.
    2. Load up to 1000 shuffled prompts.
    3. Merge the SFT adapter into Qwen/Qwen3-1.7B and attach a rank-4 LoRA.
    4. Train with ``QMDRewardFunction``.
    5. Push to the Hub and close the trackio run.
    """
    sft_adapter_id = "tobil/qmd-query-expansion-1.7B-sft"
    base_model_id = "Qwen/Qwen3-1.7B"
    hub_repo_id = "tobil/qmd-query-expansion-1.7B-grpo"
    dataset_id = "tobil/qmd-query-expansion-train-v2"

    # Authenticate only when a token is supplied via the environment.
    token = os.environ.get("HF_TOKEN")
    if token:
        print("Logging in to HuggingFace Hub...")
        login(token=token)

    train_prompts = _load_grpo_prompts(dataset_id)

    print(f"Loading tokenizer from {base_model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    policy = _load_policy_model(base_model_id, sft_adapter_id)

    # NOTE(review): max_steps=200 presumably overrides num_train_epochs. If
    # training ends before an epoch completes, save_strategy="epoch" may never
    # write a checkpoint. Confirm.
    grpo_args = GRPOConfig(
        output_dir="qmd-query-expansion-1.7B-grpo",
        push_to_hub=True,
        hub_model_id=hub_repo_id,
        num_generations=4,
        max_completion_length=200,
        num_train_epochs=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=5e-7,
        max_grad_norm=0.5,
        max_steps=200,
        logging_steps=10,
        save_strategy="epoch",
        report_to="trackio",
        project="qmd-query-expansion",
        run_name="qwen3-1.7b-grpo",
    )

    print("Initializing GRPO trainer...")
    trainer = GRPOTrainer(
        model=policy,
        processing_class=tokenizer,
        args=grpo_args,
        train_dataset=train_prompts,
        reward_funcs=[QMDRewardFunction()],
    )
    print("Starting GRPO training...")
    trainer.train()

    # NOTE(review): policy is a PEFT model, so this likely pushes only the
    # GRPO adapter. That adapter must be applied on top of the merged SFT
    # weights, not the raw base model. Verify how consumers load it.
    print("Pushing to Hub...")
    trainer.push_to_hub()
    trackio.finish()
    print(f"Complete! Model at: https://huggingface.co/{hub_repo_id}")


if __name__ == "__main__":
    main()
|