rl.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "huggingface_hub>=0.20.0",
  9. # "trackio",
  10. # "datasets",
  11. # "bitsandbytes",
  12. # "pyyaml",
  13. # ]
  14. # ///
  15. """
  16. GRPO (Group Relative Policy Optimization) training for QMD query expansion.
  17. Uses the scoring system from SCORING.md as the reward function.
  18. Usage:
  19. uv run rl.py --config configs/grpo_v4.yaml
  20. uv run rl.py --config configs/grpo_v4.yaml --dry-run
  21. """
  22. import os
  23. import re
  24. import argparse
  25. import yaml
  26. import torch
  27. import trackio
  28. from collections import Counter
  29. from datasets import load_dataset
  30. from huggingface_hub import login
  31. from peft import LoraConfig, PeftModel, get_peft_model
  32. from transformers import AutoModelForCausalLM, AutoTokenizer
  33. from trl import GRPOTrainer, GRPOConfig
  34. STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  35. KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  36. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
  37. 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}
  38. # Generic filler phrases that should never be in lex queries
  39. GENERIC_LEX_PHRASES = {
  40. 'find information about', 'search for', 'look up', 'get information',
  41. 'learn about', 'information on', 'details about', 'find out about',
  42. 'what is', 'how to', 'guide to', 'help with'
  43. }
  44. def extract_named_entities(query: str) -> set:
  45. """Extract named entities from query using simple heuristics.
  46. Named entities are:
  47. - Capitalized words (except first word which may just be sentence start)
  48. - All-caps words/acronyms (TDS, API, GPU)
  49. - Technical terms with special chars (node.js, C++, .NET)
  50. - Words following acronyms/proper nouns (TDS motorsports -> both words)
  51. """
  52. entities = set()
  53. words = query.split()
  54. prev_was_entity = False
  55. for i, word in enumerate(words):
  56. # Clean punctuation but keep internal special chars
  57. clean = word.strip('.,!?:;()[]"\'')
  58. if not clean:
  59. prev_was_entity = False
  60. continue
  61. is_entity = False
  62. # All-caps words (acronyms): TDS, API, GPU, etc.
  63. if clean.isupper() and len(clean) >= 2:
  64. entities.add(clean.lower())
  65. is_entity = True
  66. # Capitalized words (not first word, not common words)
  67. elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
  68. entities.add(clean.lower())
  69. is_entity = True
  70. # Technical terms with special chars: node.js, C++, .NET
  71. elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
  72. entities.add(clean.lower())
  73. is_entity = True
  74. # CamelCase: JavaScript, TypeScript, etc.
  75. elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
  76. entities.add(clean.lower())
  77. is_entity = True
  78. # Word following an entity is likely part of compound name (TDS motorsports)
  79. elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
  80. entities.add(clean.lower())
  81. is_entity = True
  82. prev_was_entity = is_entity
  83. return entities
  84. def get_key_terms(query: str) -> set:
  85. """Get key terms (non-stopwords) from query."""
  86. words = set(query.lower().split())
  87. return words - KEY_TERM_STOPWORDS
  88. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  89. """Check if lex line preserves key terms from query."""
  90. key_terms = get_key_terms(query)
  91. if not key_terms:
  92. return True
  93. lex_words = set(lex_line.lower().split())
  94. return bool(key_terms & lex_words)
  95. def lex_preserves_entities(lex_line: str, entities: set) -> bool:
  96. """Check if lex line contains at least one named entity."""
  97. if not entities:
  98. return True # No entities to preserve
  99. lex_lower = lex_line.lower()
  100. return any(entity in lex_lower for entity in entities)
  101. def lex_is_generic(lex_line: str) -> bool:
  102. """Check if lex line is a generic filler phrase."""
  103. lex_lower = lex_line.lower().strip()
  104. for phrase in GENERIC_LEX_PHRASES:
  105. if phrase in lex_lower or lex_lower.startswith(phrase.split()[0]):
  106. # Also check if it's ONLY the generic phrase with no specifics
  107. remaining = lex_lower
  108. for word in phrase.split():
  109. remaining = remaining.replace(word, '', 1).strip()
  110. if len(remaining) < 3: # Nothing specific left
  111. return True
  112. return False
  113. def parse_expansion(text: str) -> dict:
  114. lines = text.strip().split("\n")
  115. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  116. for line in lines:
  117. line = line.strip()
  118. if not line:
  119. continue
  120. if line.startswith("lex:"):
  121. result["lex"].append(line[4:].strip())
  122. elif line.startswith("vec:"):
  123. result["vec"].append(line[4:].strip())
  124. elif line.startswith("hyde:"):
  125. result["hyde"].append(line[5:].strip())
  126. else:
  127. result["invalid"].append(line)
  128. return result
  129. def edit_distance_simple(a: str, b: str) -> int:
  130. words_a = set(a.lower().split())
  131. words_b = set(b.lower().split())
  132. return len(words_a ^ words_b)
  133. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  134. a, b = a.lower().strip(), b.lower().strip()
  135. if a == b:
  136. return False
  137. if a in b or b in a:
  138. return False
  139. return edit_distance_simple(a, b) >= min_distance
  140. def echoes_query(expansion: str, query: str) -> bool:
  141. exp = expansion.lower().strip()
  142. q = query.lower().strip()
  143. if exp == q:
  144. return True
  145. if q in exp and len(exp) < len(q) + 10:
  146. return True
  147. return False
  148. def word_repetition_penalty(text: str) -> int:
  149. words = re.findall(r'\b\w+\b', text.lower())
  150. counts = Counter(words)
  151. penalty = 0
  152. for word, count in counts.items():
  153. if count >= 3 and word not in STOPWORDS and len(word) > 2:
  154. penalty += (count - 2) * 2
  155. return penalty
  156. def score_expansion(query: str, expansion: str) -> float:
  157. """Score expansion. Returns 0.0-1.0 for RL reward."""
  158. text = expansion.strip()
  159. # HARD FAIL: Must start with valid prefix (prevents verbose explanations)
  160. first_line = text.split("\n")[0].strip() if text else ""
  161. if not first_line.startswith(("lex:", "vec:", "hyde:")):
  162. return 0.0 # Zero reward for wrong format
  163. parsed = parse_expansion(expansion)
  164. # FORMAT (0-30)
  165. format_score = 0
  166. if parsed["lex"]:
  167. format_score += 10
  168. if parsed["vec"]:
  169. format_score += 10
  170. if not parsed["invalid"]:
  171. format_score += 10
  172. else:
  173. format_score += max(0, 10 - len(parsed["invalid"]) * 5)
  174. # DIVERSITY (0-30)
  175. diversity_score = 0
  176. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  177. if types_present >= 2:
  178. diversity_score += 10
  179. total_expansions = len(parsed["lex"]) + len(parsed["vec"])
  180. if total_expansions >= 2:
  181. diversity_score += 5
  182. lex_score = 5
  183. for i, a in enumerate(parsed["lex"]):
  184. for b in parsed["lex"][i+1:]:
  185. if not is_diverse(a, b, 2):
  186. lex_score -= 2
  187. diversity_score += max(0, lex_score)
  188. vec_score = 5
  189. for i, a in enumerate(parsed["vec"]):
  190. for b in parsed["vec"][i+1:]:
  191. if not is_diverse(a, b, 3):
  192. vec_score -= 2
  193. diversity_score += max(0, vec_score)
  194. echo_score = 5
  195. for exp in parsed["lex"] + parsed["vec"]:
  196. if echoes_query(exp, query):
  197. echo_score -= 3
  198. diversity_score += max(0, echo_score)
  199. # HYDE (0-20)
  200. hyde_score = 0
  201. if parsed["hyde"]:
  202. hyde_text = parsed["hyde"][0]
  203. hyde_score += 5
  204. hyde_len = len(hyde_text)
  205. if 50 <= hyde_len <= 200:
  206. hyde_score += 5
  207. elif hyde_len < 50:
  208. hyde_score += 2
  209. if "\n" not in hyde_text:
  210. hyde_score += 5
  211. rep_penalty = word_repetition_penalty(hyde_text)
  212. hyde_score += max(0, 5 - rep_penalty)
  213. # QUALITY (0-20)
  214. quality_score = 5
  215. if parsed["lex"] and parsed["vec"]:
  216. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  217. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  218. if avg_lex <= avg_vec:
  219. quality_score += 5
  220. if parsed["vec"]:
  221. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  222. if natural == len(parsed["vec"]):
  223. quality_score += 5
  224. else:
  225. quality_score += 2
  226. if parsed["lex"]:
  227. lex_with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  228. if lex_with_terms == len(parsed["lex"]):
  229. quality_score += 5
  230. elif lex_with_terms > 0:
  231. quality_score += 2
  232. # NAMED ENTITY PRESERVATION (critical for quality)
  233. # This score can go heavily negative to punish missing entities
  234. entity_score = 0
  235. entities = extract_named_entities(query)
  236. if entities and parsed["lex"]:
  237. # Count lex lines that preserve at least one entity
  238. lex_with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
  239. if lex_with_entities == len(parsed["lex"]):
  240. entity_score += 15 # All lex lines have entities - great!
  241. elif lex_with_entities > 0:
  242. entity_score += 5 # Some have entities
  243. else:
  244. entity_score -= 30 # NO lex lines have entities - HEAVY penalty!
  245. # Penalize generic filler phrases in lex (these are useless for BM25)
  246. generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
  247. entity_score -= generic_count * 15 # -15 per generic phrase
  248. # Bonus for entities in vec too (less critical but nice)
  249. if parsed["vec"]:
  250. vec_with_entities = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
  251. if vec_with_entities > 0:
  252. entity_score += 5
  253. elif not entities:
  254. # No entities in query - give base score
  255. entity_score = 10
  256. # Entity score CAN go negative to heavily penalize missing entities
  257. total = format_score + diversity_score + hyde_score + quality_score + entity_score
  258. max_possible = 120 if parsed["hyde"] else 100
  259. return max(0.0, min(1.0, total / max_possible)) # Clamp to 0.0-1.0
  260. def extract_query_from_prompt(prompt: str) -> str:
  261. if "Expand this search query:" in prompt:
  262. return prompt.split("Expand this search query:")[-1].strip()
  263. return prompt.strip()
  264. class QMDRewardFunction:
  265. __name__ = "qmd_scoring_reward"
  266. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  267. rewards = []
  268. for i, completion in enumerate(completions):
  269. query = ""
  270. if prompts and i < len(prompts):
  271. query = extract_query_from_prompt(prompts[i])
  272. score = score_expansion(query, completion)
  273. rewards.append(score)
  274. return rewards
  275. def main():
  276. parser = argparse.ArgumentParser()
  277. parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
  278. parser.add_argument("--dry-run", action="store_true")
  279. args = parser.parse_args()
  280. with open(args.config) as f:
  281. cfg = yaml.safe_load(f)
  282. if args.dry_run:
  283. print("GRPO Training Configuration:")
  284. print(yaml.dump(cfg, default_flow_style=False))
  285. print("\nTesting reward function...")
  286. # Test 1: Basic query
  287. test_good = "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."
  288. test_bad = "auth is important for security"
  289. print(f"\n Query: 'auth'")
  290. print(f" Good output score: {score_expansion('auth', test_good):.2f}")
  291. print(f" Bad output score: {score_expansion('auth', test_bad):.2f}")
  292. # Test 2: Named entity query (the critical case!)
  293. query_entity = "who is TDS motorsports"
  294. entities = extract_named_entities(query_entity)
  295. print(f"\n Query: '{query_entity}'")
  296. print(f" Extracted entities: {entities}")
  297. good_entity = "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"
  298. bad_entity = "lex: find information about\nlex: company details\nvec: who is this company"
  299. print(f" Good (preserves entity): {score_expansion(query_entity, good_entity):.2f}")
  300. print(f" Bad (generic phrases): {score_expansion(query_entity, bad_entity):.2f}")
  301. # Test 3: Technical term
  302. query_tech = "how to use React hooks"
  303. entities_tech = extract_named_entities(query_tech)
  304. print(f"\n Query: '{query_tech}'")
  305. print(f" Extracted entities: {entities_tech}")
  306. good_tech = "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"
  307. bad_tech = "lex: programming tutorial\nlex: how to code\nvec: learn web development"
  308. print(f" Good (preserves React): {score_expansion(query_tech, good_tech):.2f}")
  309. print(f" Bad (generic): {score_expansion(query_tech, bad_tech):.2f}")
  310. return
  311. # Login
  312. hf_token = os.environ.get("HF_TOKEN")
  313. if hf_token:
  314. print("Logging in to HuggingFace Hub...")
  315. login(token=hf_token)
  316. # Load dataset
  317. print("Loading dataset...")
  318. dataset = load_dataset(cfg["dataset"]["name"], split="train")
  319. def extract_prompt(example):
  320. return {"prompt": example[cfg["dataset"]["prompt_field"]][0]["content"]}
  321. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  322. max_samples = cfg["dataset"].get("max_samples", len(dataset))
  323. dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
  324. print(f"Using {len(dataset)} prompts for GRPO")
  325. # Load tokenizer and model
  326. print(f"Loading tokenizer from {cfg['model']['base']}...")
  327. tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["base"])
  328. if tokenizer.pad_token is None:
  329. tokenizer.pad_token = tokenizer.eos_token
  330. print(f"Loading SFT model from {cfg['model']['sft']}...")
  331. base_model = AutoModelForCausalLM.from_pretrained(
  332. cfg["model"]["base"],
  333. torch_dtype=torch.bfloat16,
  334. device_map="auto",
  335. )
  336. model = PeftModel.from_pretrained(base_model, cfg["model"]["sft"])
  337. model = model.merge_and_unload()
  338. print("Model loaded and LoRA merged.")
  339. # Add LoRA for GRPO
  340. grpo_lora_config = LoraConfig(
  341. r=cfg["lora"]["rank"],
  342. lora_alpha=cfg["lora"]["alpha"],
  343. lora_dropout=cfg["lora"]["dropout"],
  344. bias="none",
  345. task_type="CAUSAL_LM",
  346. target_modules=cfg["lora"]["target_modules"],
  347. )
  348. model = get_peft_model(model, grpo_lora_config)
  349. model.print_trainable_parameters()
  350. # Reward function
  351. reward_fn = QMDRewardFunction()
  352. # GRPO config
  353. config = GRPOConfig(
  354. output_dir=cfg["model"]["output"].split("/")[-1],
  355. push_to_hub=True,
  356. hub_model_id=cfg["model"]["output"],
  357. num_generations=cfg["grpo"]["num_generations"],
  358. max_completion_length=cfg["grpo"]["max_completion_length"],
  359. num_train_epochs=cfg["training"]["epochs"],
  360. per_device_train_batch_size=cfg["training"]["batch_size"],
  361. gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
  362. learning_rate=cfg["training"]["learning_rate"],
  363. max_grad_norm=cfg["training"]["max_grad_norm"],
  364. logging_steps=10,
  365. save_strategy="epoch",
  366. report_to="trackio",
  367. project=cfg["tracking"]["project"],
  368. run_name=cfg["tracking"]["run_name"],
  369. )
  370. # Train
  371. print("Initializing GRPO trainer...")
  372. trainer = GRPOTrainer(
  373. model=model,
  374. processing_class=tokenizer,
  375. args=config,
  376. train_dataset=dataset,
  377. reward_funcs=[reward_fn],
  378. )
  379. print("Starting GRPO training...")
  380. trainer.train()
  381. print("Pushing to Hub...")
  382. trainer.push_to_hub()
  383. trackio.finish()
  384. print(f"Done! Model at: https://huggingface.co/{cfg['model']['output']}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()