grpo.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.14.0",
#     "peft>=0.7.0",
#     "transformers>=4.45.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "datasets",
#     "bitsandbytes",
#     "torch",
# ]
# ///
  14. """
  15. GRPO training for QMD query expansion (Qwen3-1.7B).
  16. Runs on top of merged SFT weights. Self-contained for HuggingFace Jobs:
  17. hf jobs uv run --flavor a10g-large --secrets HF_TOKEN --timeout 4h jobs/grpo.py
  18. """
  19. import os
  20. import re
  21. from collections import Counter
  22. import torch
  23. from datasets import load_dataset
  24. from huggingface_hub import login
  25. from peft import LoraConfig, PeftModel, get_peft_model
  26. from transformers import AutoModelForCausalLM, AutoTokenizer
  27. from trl import GRPOTrainer, GRPOConfig
  28. # --- Config (inlined from configs/grpo.yaml) ---
  29. BASE_MODEL = "Qwen/Qwen3-1.7B"
  30. SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
  31. OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
  32. DATASET = "tobil/qmd-query-expansion-train-v2"
  33. # =============================================================================
  34. # Reward function (inlined from reward.py — single source of truth)
  35. # =============================================================================
  36. STOPWORDS = frozenset({
  37. 'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
  38. 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
  39. })
  40. KEY_TERM_STOPWORDS = frozenset({
  41. 'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  42. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
  43. 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
  44. })
  45. GENERIC_LEX_PHRASES = frozenset({
  46. 'find information about', 'search for', 'look up', 'get information',
  47. 'learn about', 'information on', 'details about', 'find out about',
  48. 'what is', 'how to', 'guide to', 'help with',
  49. })
  50. CHAT_TEMPLATE_TOKENS = frozenset({
  51. '<|im_start|>', '<|im_end|>', '<|endoftext|>',
  52. '\nassistant\n', '\nuser\n',
  53. })
  54. def parse_expansion(text: str) -> dict:
  55. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  56. for line in text.strip().split("\n"):
  57. line = line.strip()
  58. if not line:
  59. continue
  60. if line.startswith("lex:"):
  61. result["lex"].append(line[4:].strip())
  62. elif line.startswith("vec:"):
  63. result["vec"].append(line[4:].strip())
  64. elif line.startswith("hyde:"):
  65. result["hyde"].append(line[5:].strip())
  66. else:
  67. result["invalid"].append(line)
  68. return result
  69. def clean_model_output(text: str) -> tuple[str, bool]:
  70. text = text.replace('<|im_end|>', '').strip()
  71. used_thinking = '<think>' in text and '</think>' in text
  72. if used_thinking:
  73. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
  74. return text, used_thinking
  75. def extract_named_entities(query: str) -> set:
  76. entities = set()
  77. words = query.split()
  78. prev_was_entity = False
  79. for i, word in enumerate(words):
  80. clean = word.strip('.,!?:;()[]"\'')
  81. if not clean:
  82. prev_was_entity = False
  83. continue
  84. is_entity = False
  85. if clean.isupper() and len(clean) >= 2:
  86. entities.add(clean.lower())
  87. is_entity = True
  88. elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
  89. entities.add(clean.lower())
  90. is_entity = True
  91. elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
  92. entities.add(clean.lower())
  93. is_entity = True
  94. elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
  95. entities.add(clean.lower())
  96. is_entity = True
  97. elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
  98. entities.add(clean.lower())
  99. is_entity = True
  100. prev_was_entity = is_entity
  101. return entities
  102. def get_key_terms(query: str) -> set:
  103. return set(query.lower().split()) - KEY_TERM_STOPWORDS
  104. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  105. key_terms = get_key_terms(query)
  106. if not key_terms:
  107. return True
  108. return bool(key_terms & set(lex_line.lower().split()))
  109. def lex_preserves_entities(line: str, entities: set) -> bool:
  110. if not entities:
  111. return True
  112. lower = line.lower()
  113. return any(e in lower for e in entities)
  114. def lex_is_generic(lex_line: str) -> bool:
  115. lower = lex_line.lower().strip()
  116. for phrase in GENERIC_LEX_PHRASES:
  117. if phrase in lower or lower.startswith(phrase.split()[0]):
  118. remaining = lower
  119. for word in phrase.split():
  120. remaining = remaining.replace(word, '', 1).strip()
  121. if len(remaining) < 3:
  122. return True
  123. return False
  124. def word_set_distance(a: str, b: str) -> int:
  125. return len(set(a.lower().split()) ^ set(b.lower().split()))
  126. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  127. a, b = a.lower().strip(), b.lower().strip()
  128. if a == b or a in b or b in a:
  129. return False
  130. return word_set_distance(a, b) >= min_distance
  131. def echoes_query(expansion: str, query: str) -> bool:
  132. exp, q = expansion.lower().strip(), query.lower().strip()
  133. return exp == q or (q in exp and len(exp) < len(q) + 10)
  134. def word_repetition_penalty(text: str) -> int:
  135. counts = Counter(re.findall(r'\b\w+\b', text.lower()))
  136. return sum((c - 2) * 2 for w, c in counts.items()
  137. if c >= 3 and w not in STOPWORDS and len(w) > 2)
  138. def score_expansion(query: str, expansion: str) -> float:
  139. """Score expansion as float in [0.0, 1.0] for RL reward."""
  140. text, used_thinking = clean_model_output(expansion.strip())
  141. # Hard fail: chat template leakage
  142. if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
  143. return 0.0
  144. # Hard fail: invalid lines
  145. for line in text.split("\n"):
  146. line = line.strip()
  147. if line and not line.startswith(("lex:", "vec:", "hyde:")):
  148. return 0.0
  149. parsed = parse_expansion(text)
  150. # Format (0-30)
  151. format_score = 10 # no invalid lines
  152. if parsed["lex"]:
  153. format_score += 10
  154. if parsed["vec"]:
  155. format_score += 10
  156. # Diversity (0-30)
  157. diversity_score = 0
  158. types_present = sum(1 for t in ("lex", "vec") if parsed[t])
  159. if types_present >= 2:
  160. diversity_score += 10
  161. if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
  162. diversity_score += 5
  163. lex_div = 5
  164. for i, a in enumerate(parsed["lex"]):
  165. for b in parsed["lex"][i+1:]:
  166. if not is_diverse(a, b, 2):
  167. lex_div -= 2
  168. diversity_score += max(0, lex_div)
  169. vec_div = 5
  170. for i, a in enumerate(parsed["vec"]):
  171. for b in parsed["vec"][i+1:]:
  172. if not is_diverse(a, b, 3):
  173. vec_div -= 2
  174. diversity_score += max(0, vec_div)
  175. echo = 5
  176. for exp in parsed["lex"] + parsed["vec"]:
  177. if echoes_query(exp, query):
  178. echo -= 3
  179. diversity_score += max(0, echo)
  180. # HyDE (0-20)
  181. hyde_score = 0
  182. if parsed["hyde"]:
  183. hyde_text = parsed["hyde"][0]
  184. hyde_score += 5
  185. hyde_len = len(hyde_text)
  186. if 50 <= hyde_len <= 200:
  187. hyde_score += 5
  188. elif hyde_len < 50:
  189. hyde_score += 2
  190. if "\n" not in hyde_text:
  191. hyde_score += 5
  192. hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
  193. # Quality (0-20)
  194. quality_score = 5
  195. if parsed["lex"] and parsed["vec"]:
  196. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  197. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  198. if avg_lex <= avg_vec:
  199. quality_score += 5
  200. if parsed["vec"]:
  201. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  202. quality_score += 5 if natural == len(parsed["vec"]) else 2
  203. if parsed["lex"]:
  204. with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  205. if with_terms == len(parsed["lex"]):
  206. quality_score += 5
  207. elif with_terms > 0:
  208. quality_score += 2
  209. # Entity (-45 to +20)
  210. entity_score = 0
  211. entities = extract_named_entities(query)
  212. if entities and parsed["lex"]:
  213. with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
  214. if with_entities == len(parsed["lex"]):
  215. entity_score += 15
  216. elif with_entities > 0:
  217. entity_score += 5
  218. else:
  219. entity_score -= 30
  220. generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
  221. if generic_count:
  222. entity_score -= generic_count * 15
  223. if parsed["vec"]:
  224. vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
  225. if vec_with > 0:
  226. entity_score += 5
  227. elif not entities:
  228. entity_score = 10
  229. # Think bonus (0-20)
  230. think_bonus = 0 if used_thinking else 20
  231. total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
  232. max_possible = 140 if parsed["hyde"] else 120
  233. return max(0.0, min(1.0, total / max_possible))
  234. def extract_query_from_prompt(prompt: str) -> str:
  235. if "Expand this search query:" in prompt:
  236. query = prompt.split("Expand this search query:")[-1].strip()
  237. if "<|im_end|>" in query:
  238. query = query.split("<|im_end|>")[0].strip()
  239. return query
  240. return prompt.strip()
  241. class QMDRewardFunction:
  242. __name__ = "qmd_scoring_reward"
  243. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  244. rewards = []
  245. for i, completion in enumerate(completions):
  246. query = ""
  247. if prompts and i < len(prompts):
  248. query = extract_query_from_prompt(prompts[i])
  249. rewards.append(score_expansion(query, completion))
  250. return rewards
  251. # =============================================================================
  252. # Main training
  253. # =============================================================================
  254. def main():
  255. hf_token = os.environ.get("HF_TOKEN")
  256. if hf_token:
  257. login(token=hf_token)
  258. print(f"Loading tokenizer from {BASE_MODEL}...")
  259. tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  260. if tokenizer.pad_token is None:
  261. tokenizer.pad_token = tokenizer.eos_token
  262. # Load and format dataset
  263. print(f"Loading dataset: {DATASET}...")
  264. dataset = load_dataset(DATASET, split="train")
  265. def extract_prompt(example):
  266. content = example["messages"][0]["content"]
  267. messages = [{"role": "user", "content": content}]
  268. formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  269. return {"prompt": formatted}
  270. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  271. dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
  272. print(f"Using {len(dataset)} prompts for GRPO")
  273. # Load base model, merge SFT adapter
  274. print(f"Loading base model {BASE_MODEL}...")
  275. base_model = AutoModelForCausalLM.from_pretrained(
  276. BASE_MODEL, torch_dtype=torch.bfloat16, device_map="auto",
  277. )
  278. print(f"Merging SFT adapter {SFT_MODEL}...")
  279. model = PeftModel.from_pretrained(base_model, SFT_MODEL)
  280. model = model.merge_and_unload()
  281. print("SFT adapter merged.")
  282. # Fresh LoRA for GRPO (small: rank 4, q/v only)
  283. grpo_lora = LoraConfig(
  284. r=4, lora_alpha=8, lora_dropout=0.05,
  285. bias="none", task_type="CAUSAL_LM",
  286. target_modules=["q_proj", "v_proj"],
  287. )
  288. model = get_peft_model(model, grpo_lora)
  289. model.print_trainable_parameters()
  290. config = GRPOConfig(
  291. output_dir="qmd-query-expansion-1.7B-grpo",
  292. push_to_hub=True,
  293. hub_model_id=OUTPUT_MODEL,
  294. num_generations=4,
  295. max_completion_length=200,
  296. beta=0.04, # KL regularization — prevents drift from SFT checkpoint
  297. num_train_epochs=1,
  298. per_device_train_batch_size=2,
  299. gradient_accumulation_steps=8,
  300. learning_rate=5e-7,
  301. max_grad_norm=0.5,
  302. max_steps=200,
  303. logging_steps=10,
  304. save_strategy="epoch",
  305. bf16=True,
  306. report_to="none",
  307. )
  308. print("Initializing GRPO trainer...")
  309. trainer = GRPOTrainer(
  310. model=model,
  311. processing_class=tokenizer,
  312. args=config,
  313. train_dataset=dataset,
  314. reward_funcs=[QMDRewardFunction()],
  315. )
  316. print("Starting GRPO training...")
  317. trainer.train()
  318. print("Pushing to Hub...")
  319. trainer.push_to_hub()
  320. print(f"Done! Model: https://huggingface.co/{OUTPUT_MODEL}")
  321. if __name__ == "__main__":
  322. main()