# rl.py (10 KB)
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.45.0",
#     "accelerate>=0.24.0",
#     "huggingface_hub>=0.20.0",
#     "trackio",
#     "datasets",
#     "bitsandbytes",
#     "pyyaml",
# ]
# ///
"""
GRPO (Group Relative Policy Optimization) training for QMD query expansion.

Uses the scoring system from SCORING.md as the reward function.

Usage:
    uv run rl.py --config configs/grpo_v4.yaml
    uv run rl.py --config configs/grpo_v4.yaml --dry-run
"""
  22. import os
  23. import re
  24. import argparse
  25. import yaml
  26. import torch
  27. import trackio
  28. from collections import Counter
  29. from datasets import load_dataset
  30. from huggingface_hub import login
  31. from peft import LoraConfig, PeftModel, get_peft_model
  32. from transformers import AutoModelForCausalLM, AutoTokenizer
  33. from trl import GRPOTrainer, GRPOConfig
  34. STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  35. KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  36. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we'}
  37. def get_key_terms(query: str) -> set:
  38. words = set(query.lower().split())
  39. return words - KEY_TERM_STOPWORDS
  40. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  41. key_terms = get_key_terms(query)
  42. if not key_terms:
  43. return True
  44. lex_words = set(lex_line.lower().split())
  45. return bool(key_terms & lex_words)
  46. def parse_expansion(text: str) -> dict:
  47. lines = text.strip().split("\n")
  48. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  49. for line in lines:
  50. line = line.strip()
  51. if not line:
  52. continue
  53. if line.startswith("lex:"):
  54. result["lex"].append(line[4:].strip())
  55. elif line.startswith("vec:"):
  56. result["vec"].append(line[4:].strip())
  57. elif line.startswith("hyde:"):
  58. result["hyde"].append(line[5:].strip())
  59. else:
  60. result["invalid"].append(line)
  61. return result
  62. def edit_distance_simple(a: str, b: str) -> int:
  63. words_a = set(a.lower().split())
  64. words_b = set(b.lower().split())
  65. return len(words_a ^ words_b)
  66. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  67. a, b = a.lower().strip(), b.lower().strip()
  68. if a == b:
  69. return False
  70. if a in b or b in a:
  71. return False
  72. return edit_distance_simple(a, b) >= min_distance
  73. def echoes_query(expansion: str, query: str) -> bool:
  74. exp = expansion.lower().strip()
  75. q = query.lower().strip()
  76. if exp == q:
  77. return True
  78. if q in exp and len(exp) < len(q) + 10:
  79. return True
  80. return False
  81. def word_repetition_penalty(text: str) -> int:
  82. words = re.findall(r'\b\w+\b', text.lower())
  83. counts = Counter(words)
  84. penalty = 0
  85. for word, count in counts.items():
  86. if count >= 3 and word not in STOPWORDS and len(word) > 2:
  87. penalty += (count - 2) * 2
  88. return penalty
  89. def score_expansion(query: str, expansion: str) -> float:
  90. """Score expansion. Returns 0.0-1.0 for RL reward."""
  91. text = expansion.strip()
  92. # HARD FAIL: Must start with valid prefix (prevents verbose explanations)
  93. first_line = text.split("\n")[0].strip() if text else ""
  94. if not first_line.startswith(("lex:", "vec:", "hyde:")):
  95. return 0.0 # Zero reward for wrong format
  96. parsed = parse_expansion(expansion)
  97. # FORMAT (0-30)
  98. format_score = 0
  99. if parsed["lex"]:
  100. format_score += 10
  101. if parsed["vec"]:
  102. format_score += 10
  103. if not parsed["invalid"]:
  104. format_score += 10
  105. else:
  106. format_score += max(0, 10 - len(parsed["invalid"]) * 5)
  107. # DIVERSITY (0-30)
  108. diversity_score = 0
  109. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  110. if types_present >= 2:
  111. diversity_score += 10
  112. total_expansions = len(parsed["lex"]) + len(parsed["vec"])
  113. if total_expansions >= 2:
  114. diversity_score += 5
  115. lex_score = 5
  116. for i, a in enumerate(parsed["lex"]):
  117. for b in parsed["lex"][i+1:]:
  118. if not is_diverse(a, b, 2):
  119. lex_score -= 2
  120. diversity_score += max(0, lex_score)
  121. vec_score = 5
  122. for i, a in enumerate(parsed["vec"]):
  123. for b in parsed["vec"][i+1:]:
  124. if not is_diverse(a, b, 3):
  125. vec_score -= 2
  126. diversity_score += max(0, vec_score)
  127. echo_score = 5
  128. for exp in parsed["lex"] + parsed["vec"]:
  129. if echoes_query(exp, query):
  130. echo_score -= 3
  131. diversity_score += max(0, echo_score)
  132. # HYDE (0-20)
  133. hyde_score = 0
  134. if parsed["hyde"]:
  135. hyde_text = parsed["hyde"][0]
  136. hyde_score += 5
  137. hyde_len = len(hyde_text)
  138. if 50 <= hyde_len <= 200:
  139. hyde_score += 5
  140. elif hyde_len < 50:
  141. hyde_score += 2
  142. if "\n" not in hyde_text:
  143. hyde_score += 5
  144. rep_penalty = word_repetition_penalty(hyde_text)
  145. hyde_score += max(0, 5 - rep_penalty)
  146. # QUALITY (0-20)
  147. quality_score = 5
  148. if parsed["lex"] and parsed["vec"]:
  149. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  150. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  151. if avg_lex <= avg_vec:
  152. quality_score += 5
  153. if parsed["vec"]:
  154. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  155. if natural == len(parsed["vec"]):
  156. quality_score += 5
  157. else:
  158. quality_score += 2
  159. if parsed["lex"]:
  160. lex_with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  161. if lex_with_terms == len(parsed["lex"]):
  162. quality_score += 5
  163. elif lex_with_terms > 0:
  164. quality_score += 2
  165. total = format_score + diversity_score + hyde_score + quality_score
  166. max_possible = 100 if parsed["hyde"] else 80
  167. return total / max_possible
  168. def extract_query_from_prompt(prompt: str) -> str:
  169. if "Expand this search query:" in prompt:
  170. return prompt.split("Expand this search query:")[-1].strip()
  171. return prompt.strip()
  172. class QMDRewardFunction:
  173. __name__ = "qmd_scoring_reward"
  174. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  175. rewards = []
  176. for i, completion in enumerate(completions):
  177. query = ""
  178. if prompts and i < len(prompts):
  179. query = extract_query_from_prompt(prompts[i])
  180. score = score_expansion(query, completion)
  181. rewards.append(score)
  182. return rewards
  183. def main():
  184. parser = argparse.ArgumentParser()
  185. parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
  186. parser.add_argument("--dry-run", action="store_true")
  187. args = parser.parse_args()
  188. with open(args.config) as f:
  189. cfg = yaml.safe_load(f)
  190. if args.dry_run:
  191. print("GRPO Training Configuration:")
  192. print(yaml.dump(cfg, default_flow_style=False))
  193. print("\nTesting reward function...")
  194. test_good = "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."
  195. test_bad = "auth is important for security"
  196. print(f" Good output score: {score_expansion('auth', test_good):.2f}")
  197. print(f" Bad output score: {score_expansion('auth', test_bad):.2f}")
  198. return
  199. # Login
  200. hf_token = os.environ.get("HF_TOKEN")
  201. if hf_token:
  202. print("Logging in to HuggingFace Hub...")
  203. login(token=hf_token)
  204. # Load dataset
  205. print("Loading dataset...")
  206. dataset = load_dataset(cfg["dataset"]["name"], split="train")
  207. def extract_prompt(example):
  208. return {"prompt": example[cfg["dataset"]["prompt_field"]][0]["content"]}
  209. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  210. max_samples = cfg["dataset"].get("max_samples", len(dataset))
  211. dataset = dataset.shuffle(seed=42).select(range(min(max_samples, len(dataset))))
  212. print(f"Using {len(dataset)} prompts for GRPO")
  213. # Load tokenizer and model
  214. print(f"Loading tokenizer from {cfg['model']['base']}...")
  215. tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["base"])
  216. if tokenizer.pad_token is None:
  217. tokenizer.pad_token = tokenizer.eos_token
  218. print(f"Loading SFT model from {cfg['model']['sft']}...")
  219. base_model = AutoModelForCausalLM.from_pretrained(
  220. cfg["model"]["base"],
  221. torch_dtype=torch.bfloat16,
  222. device_map="auto",
  223. )
  224. model = PeftModel.from_pretrained(base_model, cfg["model"]["sft"])
  225. model = model.merge_and_unload()
  226. print("Model loaded and LoRA merged.")
  227. # Add LoRA for GRPO
  228. grpo_lora_config = LoraConfig(
  229. r=cfg["lora"]["rank"],
  230. lora_alpha=cfg["lora"]["alpha"],
  231. lora_dropout=cfg["lora"]["dropout"],
  232. bias="none",
  233. task_type="CAUSAL_LM",
  234. target_modules=cfg["lora"]["target_modules"],
  235. )
  236. model = get_peft_model(model, grpo_lora_config)
  237. model.print_trainable_parameters()
  238. # Reward function
  239. reward_fn = QMDRewardFunction()
  240. # GRPO config
  241. config = GRPOConfig(
  242. output_dir=cfg["model"]["output"].split("/")[-1],
  243. push_to_hub=True,
  244. hub_model_id=cfg["model"]["output"],
  245. num_generations=cfg["grpo"]["num_generations"],
  246. max_completion_length=cfg["grpo"]["max_completion_length"],
  247. num_train_epochs=cfg["training"]["epochs"],
  248. per_device_train_batch_size=cfg["training"]["batch_size"],
  249. gradient_accumulation_steps=cfg["training"]["gradient_accumulation_steps"],
  250. learning_rate=cfg["training"]["learning_rate"],
  251. max_grad_norm=cfg["training"]["max_grad_norm"],
  252. logging_steps=10,
  253. save_strategy="epoch",
  254. report_to="trackio",
  255. project=cfg["tracking"]["project"],
  256. run_name=cfg["tracking"]["run_name"],
  257. )
  258. # Train
  259. print("Initializing GRPO trainer...")
  260. trainer = GRPOTrainer(
  261. model=model,
  262. processing_class=tokenizer,
  263. args=config,
  264. train_dataset=dataset,
  265. reward_funcs=[reward_fn],
  266. )
  267. print("Starting GRPO training...")
  268. trainer.train()
  269. print("Pushing to Hub...")
  270. trainer.push_to_hub()
  271. trackio.finish()
  272. print(f"Done! Model at: https://huggingface.co/{cfg['model']['output']}")
  273. if __name__ == "__main__":
  274. main()