train_1.7B_grpo.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "huggingface_hub>=0.20.0",
  9. # "trackio",
  10. # "datasets",
  11. # "bitsandbytes",
  12. # ]
  13. # ///
  14. """
  15. GRPO training for Qwen3-1.7B query expansion model.
  16. Trains on top of merged SFT weights with reward function.
  17. """
  18. import os
  19. import re
  20. from collections import Counter
  21. import torch
  22. import trackio
  23. from datasets import load_dataset
  24. from huggingface_hub import login
  25. from peft import LoraConfig, PeftModel, get_peft_model
  26. from transformers import AutoModelForCausalLM, AutoTokenizer
  27. from trl import GRPOTrainer, GRPOConfig
  # ==================== REWARD FUNCTION ====================
  # Filler words that are never penalised as "repeated" in a hyde passage
  # (consumed by word_repetition_penalty).
  STOPWORDS = {
      'a', 'an', 'are', 'as', 'be', 'by', 'and', 'for', 'in', 'is', 'it',
      'of', 'on', 'or', 'that', 'the', 'this', 'to', 'with',
  }

  # Words that carry no search intent. They are dropped when extracting key
  # terms from a query and are never treated as entity continuations.
  KEY_TERM_STOPWORDS = {
      # articles, prepositions, conjunctions
      'a', 'an', 'the', 'in', 'on', 'for', 'of', 'to', 'with', 'and', 'or',
      # auxiliaries and pronouns
      'is', 'do', 'does', 'can', 'my', 'your', 'i', 'me', 'we',
      # question words
      'what', 'how', 'who', 'where', 'when', 'why', 'which',
      # search verbs
      'find', 'get', 'show', 'tell',
  }

  # Boilerplate phrases; a lex line that is little more than one of these is
  # considered generic (see lex_is_generic).
  GENERIC_LEX_PHRASES = {
      'find information about', 'search for', 'look up', 'get information',
      'learn about', 'information on', 'details about', 'find out about',
      'what is', 'how to', 'guide to', 'help with',
  }
  def extract_named_entities(query: str) -> set:
      """Collect likely named entities from *query*, lower-cased.

      Pure string heuristics, no NLP model. After trimming surrounding
      punctuation, a token counts as an entity when it is any of:

      * an all-caps token of 2+ characters (``AWS``);
      * capitalised, not the first word, and not in KEY_TERM_STOPWORDS;
      * 2+ characters containing one of ``. + - # @`` (``node.js``, ``c#``);
      * CamelCase (capitalised with another capital later on);
      * a non-stopword immediately following another entity token, so
        multi-word names stay together. Note this also captures ordinary
        lower-case words that follow an entity.

      A token that trims to nothing breaks the entity chain.
      """
      found = set()
      previous_matched = False
      for position, raw in enumerate(query.split()):
          token = raw.strip('.,!?:;()[]"\'')
          if not token:
              previous_matched = False
              continue
          lowered = token.lower()
          matched = (
              (token.isupper() and len(token) >= 2)
              or (position > 0 and token[0].isupper()
                  and lowered not in KEY_TERM_STOPWORDS)
              or (any(ch in token for ch in '.+-#@') and len(token) >= 2)
              or (len(token) > 1 and token[0].isupper()
                  and any(ch.isupper() for ch in token[1:]))
              or (previous_matched and lowered not in KEY_TERM_STOPWORDS)
          )
          if matched:
              found.add(lowered)
          previous_matched = matched
      return found
  def get_key_terms(query: str) -> set:
      """Return the content words of *query*.

      Words are lower-cased whitespace tokens with surrounding punctuation
      trimmed (the same character set extract_named_entities uses). Empty
      tokens and anything in KEY_TERM_STOPWORDS are removed.

      Trimming matters: without it a query such as ``"what is python?"``
      yields the key term ``"python?"``, which no lex line can ever contain.
      """
      terms = {word.strip('.,!?:;()[]"\'') for word in query.lower().split()}
      return terms - KEY_TERM_STOPWORDS - {''}
  def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
      """Return True if *lex_line* contains at least one key term of *query*.

      Both sides are compared as whole lower-cased words with surrounding
      punctuation trimmed. So ``"python?"`` in the query matches ``python``
      in the lex line, and ``"docker,"`` in the lex line matches ``docker``.
      A query with no key terms imposes no constraint and returns True.
      """
      punctuation = '.,!?:;()[]"\''
      key_terms = {term.strip(punctuation) for term in get_key_terms(query)}
      # Trimming can expose a stopword ("is?" -> "is") or leave an empty string.
      key_terms -= KEY_TERM_STOPWORDS | {''}
      if not key_terms:
          return True
      lex_words = {word.strip(punctuation) for word in lex_line.lower().split()}
      return not key_terms.isdisjoint(lex_words)
  def lex_preserves_entities(lex_line: str, entities: set) -> bool:
      """Return True if any entity occurs in *lex_line* (case-insensitive).

      Matching is by substring, so short entities can match inside longer
      words. An empty entity set is trivially preserved.
      """
      if entities:
          haystack = lex_line.lower()
          return any(name in haystack for name in entities)
      return True
  def lex_is_generic(lex_line: str) -> bool:
      """Return True if *lex_line* is essentially a boilerplate phrase.

      A phrase from GENERIC_LEX_PHRASES is a candidate when the line contains
      it, or when the line starts with the phrase's first word. The line is
      generic if, after deleting the first occurrence of each phrase word (as
      a substring, re-stripping after each deletion), fewer than 3 characters
      remain.
      """
      normalized = lex_line.lower().strip()

      def leftover_after(phrase: str) -> str:
          # What is left of the line once the phrase's words are removed.
          rest = normalized
          for token in phrase.split():
              rest = rest.replace(token, '', 1).strip()
          return rest

      return any(
          (phrase in normalized or normalized.startswith(phrase.split()[0]))
          and len(leftover_after(phrase)) < 3
          for phrase in GENERIC_LEX_PHRASES
      )
  def parse_expansion(text: str) -> dict:
      """Split a model completion into its typed expansion lines.

      Returns a dict with these keys:

      * ``"lex"``, ``"vec"``, ``"hyde"``: the payload of each line with that
        prefix, with the prefix and surrounding whitespace removed, in order
        of appearance;
      * ``"invalid"``: every other non-blank line, stripped.

      Blank lines are ignored.
      """
      buckets = {"lex": [], "vec": [], "hyde": [], "invalid": []}
      for raw_line in text.strip().split("\n"):
          stripped = raw_line.strip()
          if not stripped:
              continue
          # "lex: foo" -> ("lex", ":", " foo"); only a known tag before the
          # first colon makes the line valid.
          kind, sep, payload = stripped.partition(":")
          if sep and kind in ("lex", "vec", "hyde"):
              buckets[kind].append(payload.strip())
          else:
              buckets["invalid"].append(stripped)
      return buckets
  def edit_distance_simple(a: str, b: str) -> int:
      """Return a cheap word-level distance between *a* and *b*.

      Despite the name this is not an edit distance. It is the number of
      distinct lower-cased words that appear in exactly one of the two
      strings; word order and repetition are ignored.
      """
      left = set(a.lower().split())
      right = set(b.lower().split())
      return len(left.symmetric_difference(right))
  def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
      """Return True if *a* and *b* differ enough to count as distinct.

      Comparison is case-insensitive and ignores surrounding whitespace. Two
      strings are not diverse if either contains the other (which includes
      equality). Otherwise at least *min_distance* distinct words must appear
      in only one of them.
      """
      first = a.lower().strip()
      second = b.lower().strip()
      if first in second or second in first:
          return False
      differing = set(first.split()) ^ set(second.split())
      return len(differing) >= min_distance
  def echoes_query(expansion: str, query: str) -> bool:
      """Return True when *expansion* merely repeats *query*.

      Comparison is case-insensitive and ignores surrounding whitespace. An
      expansion echoes the query if it equals the query, or contains it
      verbatim while being fewer than 10 characters longer.

      An empty query never counts as echoed. ``"" in s`` is always true, so
      without this guard every short expansion would be penalised whenever
      the query could not be recovered from the prompt.
      """
      exp = expansion.lower().strip()
      q = query.lower().strip()
      if not q:
          return False
      return exp == q or (q in exp and len(exp) < len(q) + 10)
  def word_repetition_penalty(text: str) -> int:
      """Return a penalty for words that are overused in *text*.

      Words are ``\\w+`` runs, lower-cased. Each word longer than 2 characters
      that is not in STOPWORDS and occurs 3+ times adds 2 points for every
      occurrence beyond the second.
      """
      tally = Counter(re.findall(r'\b\w+\b', text.lower()))
      return sum(
          2 * (occurrences - 2)
          for token, occurrences in tally.items()
          if occurrences >= 3 and len(token) > 2 and token not in STOPWORDS
      )
  # Substrings that only appear when the model leaks chat-template markup.
  _CHAT_ARTIFACTS = (
      '<|im_start|>', '<|im_end|>', '<think>', '</think>',
      '\nassistant\n', '\nuser\n', '<|endoftext|>',
  )


  def _format_points(parsed: dict) -> int:
      """FORMAT (0-30): 10 for passing the format check, +10 each for lex and vec."""
      points = 10  # every line is well-formed; guaranteed by the hard fail
      if parsed["lex"]:
          points += 10
      if parsed["vec"]:
          points += 10
      return points


  def _pairwise_diversity_points(items: list, min_distance: int) -> int:
      """Start at 5, subtract 2 per non-diverse pair of *items*, floor at 0."""
      points = 5
      for i, first in enumerate(items):
          for second in items[i + 1:]:
              if not is_diverse(first, second, min_distance):
                  points -= 2
      return max(0, points)


  def _diversity_points(parsed: dict, query: str) -> int:
      """DIVERSITY (0-30): both types, 2+ expansions, no near-duplicates, no echoes."""
      lex, vec = parsed["lex"], parsed["vec"]
      points = 0
      if lex and vec:
          points += 10
      if len(lex) + len(vec) >= 2:
          points += 5
      points += _pairwise_diversity_points(lex, 2)
      points += _pairwise_diversity_points(vec, 3)
      echo_points = 5
      for candidate in lex + vec:
          if echoes_query(candidate, query):
              echo_points -= 3
      points += max(0, echo_points)
      return points


  def _hyde_points(parsed: dict) -> int:
      """HYDE (0-20): presence, passage length, single line, low word repetition."""
      if not parsed["hyde"]:
          return 0
      hyde_text = parsed["hyde"][0]  # only the first hyde line is scored
      points = 5
      length = len(hyde_text)
      if 50 <= length <= 200:
          points += 5
      elif length < 50:
          points += 2
      # NOTE: parse_expansion splits on newlines, so this always holds today;
      # kept so the score stays meaningful if parsing ever changes.
      if "\n" not in hyde_text:
          points += 5
      points += max(0, 5 - word_repetition_penalty(hyde_text))
      return points


  def _quality_points(parsed: dict, query: str) -> int:
      """QUALITY (5-20): lex shorter than vec, vec reads as prose, lex keeps key terms."""
      lex, vec = parsed["lex"], parsed["vec"]
      points = 5
      if lex and vec:
          avg_lex = sum(len(line) for line in lex) / len(lex)
          avg_vec = sum(len(line) for line in vec) / len(vec)
          if avg_lex <= avg_vec:
              points += 5
      if vec:
          natural = sum(1 for line in vec if " " in line and len(line) > 15)
          points += 5 if natural == len(vec) else 2
      if lex:
          with_terms = sum(1 for line in lex if lex_preserves_key_terms(line, query))
          if with_terms == len(lex):
              points += 5
          elif with_terms > 0:
              points += 2
      return points


  def _entity_points(parsed: dict, query: str) -> int:
      """NAMED ENTITIES (at most 20, may be negative): lex lines must keep query entities."""
      entities = extract_named_entities(query)
      if not entities:
          return 10  # nothing to preserve: flat credit
      lex = parsed["lex"]
      if not lex:
          return 0
      with_entities = sum(1 for line in lex if lex_preserves_entities(line, entities))
      if with_entities == len(lex):
          points = 15
      elif with_entities > 0:
          points = 5
      else:
          points = -30
      points -= 15 * sum(1 for line in lex if lex_is_generic(line))
      if any(lex_preserves_entities(line, entities) for line in parsed["vec"]):
          points += 5
      return points


  def score_expansion(query: str, expansion: str) -> float:
      """Score a query expansion; returns a reward in [0.0, 1.0].

      Hard failures score 0.0:

      * leaked chat-template tokens;
      * any non-blank line without a ``lex:``, ``vec:`` or ``hyde:`` prefix;
      * no expansion lines at all. Without this check an empty completion
        still collects baseline format/diversity/quality (and no-entity)
        credit and scores about 0.3-0.4.

      Otherwise the score is the sum of five sections: format (0-30),
      diversity (0-30), hyde (0-20), quality (5-20) and named-entity
      preservation (at most 20, may be negative). The sum is divided by 120
      when a hyde line is present and by 100 otherwise, then clipped to
      [0, 1].
      """
      text = expansion.strip()
      if any(token in text for token in _CHAT_ARTIFACTS):
          return 0.0
      for line in text.split("\n"):
          line = line.strip()
          if line and not line.startswith(("lex:", "vec:", "hyde:")):
              return 0.0

      parsed = parse_expansion(expansion)
      if not (parsed["lex"] or parsed["vec"] or parsed["hyde"]):
          return 0.0

      total = (
          _format_points(parsed)
          + _diversity_points(parsed, query)
          + _hyde_points(parsed)
          + _quality_points(parsed, query)
          + _entity_points(parsed, query)
      )
      max_possible = 120 if parsed["hyde"] else 100
      return max(0.0, min(1.0, total / max_possible))
  def extract_query_from_prompt(prompt: str) -> str:
      """Return the search query embedded in a training prompt.

      The query is the text after the last ``"Expand this search query:"``
      marker, stripped. If the marker is absent, the whole prompt is returned
      stripped.
      """
      # rpartition yields ('', '', prompt) when the marker is missing.
      _, _, tail = prompt.rpartition("Expand this search query:")
      return tail.strip()
  class QMDRewardFunction:
      """GRPO reward callable that scores each completion with score_expansion.

      The trainer passes ``completions`` and ``prompts`` as keyword arguments,
      with extra dataset columns in ``**kwargs``. ``__name__`` labels the
      reward, since instances do not get one automatically.
      """

      __name__ = "qmd_scoring_reward"

      @staticmethod
      def _as_text(item) -> str:
          """Return plain text for a standard or conversational sample.

          Strings pass through unchanged. A conversational sample (a list of
          message dicts) yields the ``content`` of its last message. Anything
          else is converted with ``str``.
          """
          if isinstance(item, str):
              return item
          if isinstance(item, list):
              return item[-1].get("content", "") if item else ""
          return str(item)

      def __call__(self, completions: list, prompts: list | None = None, **kwargs) -> list[float]:
          """Return one reward in [0, 1] per completion.

          Args:
              completions: generated outputs, either strings or conversational
                  lists of message dicts.
              prompts: the matching prompts, in the same formats. When a
                  prompt is missing the query is treated as empty.
              **kwargs: extra per-sample columns from the trainer; ignored.
          """
          rewards = []
          for i, completion in enumerate(completions):
              query = ""
              if prompts and i < len(prompts):
                  query = extract_query_from_prompt(self._as_text(prompts[i]))
              rewards.append(score_expansion(query, self._as_text(completion)))
          return rewards
  # ==================== MAIN ====================
  def main():
      """Run GRPO on top of the merged SFT model and publish the result.

      Pipeline:

      1. Log in to the Hub only if ``$HF_TOKEN`` is set.
      2. Load the training prompts and subsample up to 1000 of them.
      3. Load the Qwen3-1.7B base model and merge the SFT LoRA adapter into it.
      4. Attach a fresh rank-4 LoRA adapter on q_proj and v_proj.
      5. Train with GRPOTrainer and QMDRewardFunction.
      6. Push the result to the Hub and close the trackio run.
      """
      # ---- configuration ----
      SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
      BASE_MODEL = "Qwen/Qwen3-1.7B"
      OUTPUT_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
      DATASET = "tobil/qmd-query-expansion-train-v2"

      # ---- authentication (skipped when no token is set) ----
      if hf_token := os.environ.get("HF_TOKEN"):
          print("Logging in to HuggingFace Hub...")
          login(token=hf_token)

      # ---- prompts ----
      print("Loading dataset...")
      source = load_dataset(DATASET, split="train")

      def extract_prompt(example):
          # Keep only the text of the first chat message (presumably the user
          # turn -- TODO confirm against the dataset schema).
          return {"prompt": example["messages"][0]["content"]}

      mapped = source.map(extract_prompt, remove_columns=source.column_names)
      keep = min(1000, len(mapped))
      dataset = mapped.shuffle(seed=42).select(range(keep))
      print(f"Using {len(dataset)} prompts for GRPO")

      # ---- tokenizer ----
      print(f"Loading tokenizer from {BASE_MODEL}...")
      tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
      if tokenizer.pad_token is None:
          tokenizer.pad_token = tokenizer.eos_token

      # ---- model: base + SFT adapter, merged into plain weights ----
      print(f"Loading SFT model from {SFT_MODEL}...")
      merged = PeftModel.from_pretrained(
          AutoModelForCausalLM.from_pretrained(
              BASE_MODEL,
              torch_dtype=torch.bfloat16,
              device_map="auto",
          ),
          SFT_MODEL,
      ).merge_and_unload()
      print("Model loaded and LoRA merged.")

      # ---- new LoRA adapter: the only trainable weights during GRPO ----
      model = get_peft_model(
          merged,
          LoraConfig(
              r=4,
              lora_alpha=8,
              lora_dropout=0.05,
              bias="none",
              task_type="CAUSAL_LM",
              target_modules=["q_proj", "v_proj"],
          ),
      )
      model.print_trainable_parameters()

      # ---- trainer configuration ----
      config = GRPOConfig(
          # output / Hub
          output_dir="qmd-query-expansion-1.7B-grpo",
          push_to_hub=True,
          hub_model_id=OUTPUT_MODEL,
          # sampling
          num_generations=4,
          max_completion_length=200,
          # optimisation (max_steps takes precedence over num_train_epochs)
          num_train_epochs=1,
          max_steps=200,
          per_device_train_batch_size=2,
          gradient_accumulation_steps=8,
          learning_rate=5e-7,
          max_grad_norm=0.5,
          # logging / checkpoints
          logging_steps=10,
          save_strategy="epoch",
          report_to="trackio",
          project="qmd-query-expansion",
          run_name="qwen3-1.7b-grpo",
      )

      # ---- training ----
      print("Initializing GRPO trainer...")
      trainer = GRPOTrainer(
          model=model,
          processing_class=tokenizer,
          args=config,
          train_dataset=dataset,
          reward_funcs=[QMDRewardFunction()],
      )
      print("Starting GRPO training...")
      trainer.train()

      # ---- publish ----
      print("Pushing to Hub...")
      trainer.push_to_hub()
      trackio.finish()
      print(f"Complete! Model at: https://huggingface.co/{OUTPUT_MODEL}")


  if __name__ == "__main__":
      main()