train_4B_grpo.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.14.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "huggingface_hub>=0.20.0",
  9. # "trackio",
  10. # "datasets",
  11. # "bitsandbytes",
  12. # ]
  13. # ///
  14. """
  15. GRPO training for Qwen3-4B query expansion model.
  16. Trains on top of merged SFT weights with reward function.
  17. """
  18. import os
  19. import re
  20. from collections import Counter
  21. import torch
  22. import trackio
  23. from datasets import load_dataset
  24. from huggingface_hub import login
  25. from peft import LoraConfig, PeftModel, get_peft_model
  26. from transformers import AutoModelForCausalLM, AutoTokenizer
  27. from trl import GRPOTrainer, GRPOConfig
  28. # ==================== REWARD FUNCTION ====================
  29. STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  30. KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  31. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
  32. 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}
  33. GENERIC_LEX_PHRASES = {
  34. 'find information about', 'search for', 'look up', 'get information',
  35. 'learn about', 'information on', 'details about', 'find out about',
  36. 'what is', 'how to', 'guide to', 'help with'
  37. }
  38. def extract_named_entities(query: str) -> set:
  39. """Extract named entities from query using simple heuristics."""
  40. entities = set()
  41. words = query.split()
  42. prev_was_entity = False
  43. for i, word in enumerate(words):
  44. clean = word.strip('.,!?:;()[]"\'')
  45. if not clean:
  46. prev_was_entity = False
  47. continue
  48. is_entity = False
  49. if clean.isupper() and len(clean) >= 2:
  50. entities.add(clean.lower())
  51. is_entity = True
  52. elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
  53. entities.add(clean.lower())
  54. is_entity = True
  55. elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
  56. entities.add(clean.lower())
  57. is_entity = True
  58. elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
  59. entities.add(clean.lower())
  60. is_entity = True
  61. elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
  62. entities.add(clean.lower())
  63. is_entity = True
  64. prev_was_entity = is_entity
  65. return entities
  66. def get_key_terms(query: str) -> set:
  67. words = set(query.lower().split())
  68. return words - KEY_TERM_STOPWORDS
  69. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  70. key_terms = get_key_terms(query)
  71. if not key_terms:
  72. return True
  73. lex_words = set(lex_line.lower().split())
  74. return bool(key_terms & lex_words)
  75. def lex_preserves_entities(lex_line: str, entities: set) -> bool:
  76. if not entities:
  77. return True
  78. lex_lower = lex_line.lower()
  79. return any(entity in lex_lower for entity in entities)
  80. def lex_is_generic(lex_line: str) -> bool:
  81. lex_lower = lex_line.lower().strip()
  82. for phrase in GENERIC_LEX_PHRASES:
  83. if phrase in lex_lower or lex_lower.startswith(phrase.split()[0]):
  84. remaining = lex_lower
  85. for word in phrase.split():
  86. remaining = remaining.replace(word, '', 1).strip()
  87. if len(remaining) < 3:
  88. return True
  89. return False
  90. def parse_expansion(text: str) -> dict:
  91. lines = text.strip().split("\n")
  92. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  93. for line in lines:
  94. line = line.strip()
  95. if not line:
  96. continue
  97. if line.startswith("lex:"):
  98. result["lex"].append(line[4:].strip())
  99. elif line.startswith("vec:"):
  100. result["vec"].append(line[4:].strip())
  101. elif line.startswith("hyde:"):
  102. result["hyde"].append(line[5:].strip())
  103. else:
  104. result["invalid"].append(line)
  105. return result
  106. def edit_distance_simple(a: str, b: str) -> int:
  107. words_a = set(a.lower().split())
  108. words_b = set(b.lower().split())
  109. return len(words_a ^ words_b)
  110. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  111. a, b = a.lower().strip(), b.lower().strip()
  112. if a == b:
  113. return False
  114. if a in b or b in a:
  115. return False
  116. return edit_distance_simple(a, b) >= min_distance
  117. def echoes_query(expansion: str, query: str) -> bool:
  118. exp = expansion.lower().strip()
  119. q = query.lower().strip()
  120. if exp == q:
  121. return True
  122. if q in exp and len(exp) < len(q) + 10:
  123. return True
  124. return False
  125. def word_repetition_penalty(text: str) -> int:
  126. words = re.findall(r'\b\w+\b', text.lower())
  127. counts = Counter(words)
  128. penalty = 0
  129. for word, count in counts.items():
  130. if count >= 3 and word not in STOPWORDS and len(word) > 2:
  131. penalty += (count - 2) * 2
  132. return penalty
  133. def score_expansion(query: str, expansion: str) -> float:
  134. """Score expansion. Returns 0.0-1.0 for RL reward."""
  135. text = expansion.strip()
  136. # Strip end token if present
  137. text = text.replace('<|im_end|>', '').strip()
  138. # Check for <think>...</think> blocks - strip and mark as not skipped
  139. skipped_think = 20 # Bonus for not using thinking mode
  140. if '<think>' in text and '</think>' in text:
  141. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
  142. skipped_think = 0 # No bonus if thinking was used
  143. # HARD FAIL: Chat template artifacts
  144. if any(token in text for token in ['<|im_start|>', '\nassistant\n', '\nuser\n', '<|endoftext|>']):
  145. return 0.0
  146. # HARD FAIL: EVERY line must start with lex:, vec:, or hyde:
  147. for line in text.split("\n"):
  148. line = line.strip()
  149. if not line:
  150. continue
  151. if not line.startswith(("lex:", "vec:", "hyde:")):
  152. return 0.0
  153. parsed = parse_expansion(expansion)
  154. # FORMAT (0-30)
  155. format_score = 0
  156. if parsed["lex"]:
  157. format_score += 10
  158. if parsed["vec"]:
  159. format_score += 10
  160. format_score += 10
  161. # DIVERSITY (0-30)
  162. diversity_score = 0
  163. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  164. if types_present >= 2:
  165. diversity_score += 10
  166. total_expansions = len(parsed["lex"]) + len(parsed["vec"])
  167. if total_expansions >= 2:
  168. diversity_score += 5
  169. lex_score = 5
  170. for i, a in enumerate(parsed["lex"]):
  171. for b in parsed["lex"][i+1:]:
  172. if not is_diverse(a, b, 2):
  173. lex_score -= 2
  174. diversity_score += max(0, lex_score)
  175. vec_score = 5
  176. for i, a in enumerate(parsed["vec"]):
  177. for b in parsed["vec"][i+1:]:
  178. if not is_diverse(a, b, 3):
  179. vec_score -= 2
  180. diversity_score += max(0, vec_score)
  181. echo_score = 5
  182. for exp in parsed["lex"] + parsed["vec"]:
  183. if echoes_query(exp, query):
  184. echo_score -= 3
  185. diversity_score += max(0, echo_score)
  186. # HYDE (0-20)
  187. hyde_score = 0
  188. if parsed["hyde"]:
  189. hyde_text = parsed["hyde"][0]
  190. hyde_score += 5
  191. hyde_len = len(hyde_text)
  192. if 50 <= hyde_len <= 200:
  193. hyde_score += 5
  194. elif hyde_len < 50:
  195. hyde_score += 2
  196. if "\n" not in hyde_text:
  197. hyde_score += 5
  198. rep_penalty = word_repetition_penalty(hyde_text)
  199. hyde_score += max(0, 5 - rep_penalty)
  200. # QUALITY (0-20)
  201. quality_score = 5
  202. if parsed["lex"] and parsed["vec"]:
  203. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  204. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  205. if avg_lex <= avg_vec:
  206. quality_score += 5
  207. if parsed["vec"]:
  208. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  209. if natural == len(parsed["vec"]):
  210. quality_score += 5
  211. else:
  212. quality_score += 2
  213. if parsed["lex"]:
  214. lex_with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  215. if lex_with_terms == len(parsed["lex"]):
  216. quality_score += 5
  217. elif lex_with_terms > 0:
  218. quality_score += 2
  219. # NAMED ENTITY PRESERVATION
  220. entity_score = 0
  221. entities = extract_named_entities(query)
  222. if entities and parsed["lex"]:
  223. lex_with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
  224. if lex_with_entities == len(parsed["lex"]):
  225. entity_score += 15
  226. elif lex_with_entities > 0:
  227. entity_score += 5
  228. else:
  229. entity_score -= 30
  230. generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
  231. entity_score -= generic_count * 15
  232. if parsed["vec"]:
  233. vec_with_entities = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
  234. if vec_with_entities > 0:
  235. entity_score += 5
  236. elif not entities:
  237. entity_score = 10
  238. total = format_score + diversity_score + hyde_score + quality_score + entity_score + skipped_think
  239. max_possible = 140 if parsed["hyde"] else 120 # +20 for skipped_think bonus
  240. return max(0.0, min(1.0, total / max_possible))
  241. def extract_query_from_prompt(prompt: str) -> str:
  242. if "Expand this search query:" in prompt:
  243. return prompt.split("Expand this search query:")[-1].strip()
  244. return prompt.strip()
  245. class QMDRewardFunction:
  246. __name__ = "qmd_scoring_reward"
  247. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  248. rewards = []
  249. for i, completion in enumerate(completions):
  250. query = ""
  251. if prompts and i < len(prompts):
  252. query = extract_query_from_prompt(prompts[i])
  253. score = score_expansion(query, completion)
  254. rewards.append(score)
  255. return rewards
  256. # ==================== MAIN ====================
  257. def main():
  258. # Config
  259. SFT_MODEL = "tobil/qmd-query-expansion-4B-sft"
  260. BASE_MODEL = "Qwen/Qwen3-4B"
  261. OUTPUT_MODEL = "tobil/qmd-query-expansion-4B-grpo"
  262. DATASET = "tobil/qmd-query-expansion-train-v2"
  263. # Login
  264. hf_token = os.environ.get("HF_TOKEN")
  265. if hf_token:
  266. print("Logging in to HuggingFace Hub...")
  267. login(token=hf_token)
  268. # Load dataset
  269. print("Loading dataset...")
  270. dataset = load_dataset(DATASET, split="train")
  271. def extract_prompt(example):
  272. return {"prompt": example["messages"][0]["content"]}
  273. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  274. dataset = dataset.shuffle(seed=42).select(range(min(1000, len(dataset))))
  275. print(f"Using {len(dataset)} prompts for GRPO")
  276. # Load tokenizer and model
  277. print(f"Loading tokenizer from {BASE_MODEL}...")
  278. tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
  279. if tokenizer.pad_token is None:
  280. tokenizer.pad_token = tokenizer.eos_token
  281. print(f"Loading SFT model from {SFT_MODEL}...")
  282. base_model = AutoModelForCausalLM.from_pretrained(
  283. BASE_MODEL,
  284. torch_dtype=torch.bfloat16,
  285. device_map="auto",
  286. )
  287. model = PeftModel.from_pretrained(base_model, SFT_MODEL)
  288. model = model.merge_and_unload()
  289. print("Model loaded and LoRA merged.")
  290. # Add LoRA for GRPO
  291. grpo_lora_config = LoraConfig(
  292. r=4,
  293. lora_alpha=8,
  294. lora_dropout=0.05,
  295. bias="none",
  296. task_type="CAUSAL_LM",
  297. target_modules=["q_proj", "v_proj"],
  298. )
  299. model = get_peft_model(model, grpo_lora_config)
  300. model.print_trainable_parameters()
  301. # GRPO config
  302. config = GRPOConfig(
  303. output_dir="qmd-query-expansion-4B-grpo",
  304. push_to_hub=True,
  305. hub_model_id=OUTPUT_MODEL,
  306. num_generations=4,
  307. max_completion_length=200,
  308. num_train_epochs=1,
  309. per_device_train_batch_size=1, # Smaller for 4B model
  310. gradient_accumulation_steps=16, # Compensate with more accumulation
  311. learning_rate=5e-7,
  312. max_grad_norm=0.5,
  313. max_steps=200,
  314. logging_steps=10,
  315. save_strategy="epoch",
  316. report_to="trackio",
  317. project="qmd-query-expansion",
  318. run_name="qwen3-4b-grpo",
  319. )
  320. # Train
  321. print("Initializing GRPO trainer...")
  322. trainer = GRPOTrainer(
  323. model=model,
  324. processing_class=tokenizer,
  325. args=config,
  326. train_dataset=dataset,
  327. reward_funcs=[QMDRewardFunction()],
  328. )
  329. print("Starting GRPO training...")
  330. trainer.train()
  331. print("Pushing to Hub...")
  332. trainer.push_to_hub()
  333. trackio.finish()
  334. print(f"Complete! Model at: https://huggingface.co/{OUTPUT_MODEL}")
  335. if __name__ == "__main__":
  336. main()