train_grpo.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "huggingface_hub>=0.20.0",
  9. # "trackio",
  10. # "datasets",
  11. # "bitsandbytes",
  12. # ]
  13. # ///
  14. """
  15. GRPO (Group Relative Policy Optimization) training for QMD query expansion.
  16. Uses the comprehensive scoring system from SCORING.md:
  17. - Format (30%): Must have lex: and vec: prefixes
  18. - Diversity (30%): No echoing query, diverse expansions
  19. - Hyde (20%): Concise, no newlines, no repetition
  20. - Quality (20%): lex=keywords, vec=natural language
  21. Usage:
  22. uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
  23. """
  24. import os
  25. import re
  26. import torch
  27. import trackio
  28. from collections import Counter
  29. from datasets import load_dataset
  30. from huggingface_hub import login
  31. from peft import LoraConfig, PeftModel, get_peft_model
  32. from transformers import AutoModelForCausalLM, AutoTokenizer
  33. from trl import GRPOTrainer, GRPOConfig
  34. STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
  35. # ============================================================================
  36. # Scoring Functions (from SCORING.md)
  37. # ============================================================================
  38. def parse_expansion(text: str) -> dict:
  39. """Parse expansion into structured format."""
  40. lines = text.strip().split("\n")
  41. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  42. for line in lines:
  43. line = line.strip()
  44. if not line:
  45. continue
  46. if line.startswith("lex:"):
  47. result["lex"].append(line[4:].strip())
  48. elif line.startswith("vec:"):
  49. result["vec"].append(line[4:].strip())
  50. elif line.startswith("hyde:"):
  51. result["hyde"].append(line[5:].strip())
  52. else:
  53. result["invalid"].append(line)
  54. return result
  55. def edit_distance_simple(a: str, b: str) -> int:
  56. """Simple word-level edit distance."""
  57. words_a = set(a.lower().split())
  58. words_b = set(b.lower().split())
  59. return len(words_a ^ words_b)
  60. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  61. """Check if two strings are sufficiently different."""
  62. a, b = a.lower().strip(), b.lower().strip()
  63. if a == b:
  64. return False
  65. if a in b or b in a:
  66. return False
  67. return edit_distance_simple(a, b) >= min_distance
  68. def echoes_query(expansion: str, query: str) -> bool:
  69. """Check if expansion is just echoing the query."""
  70. exp = expansion.lower().strip()
  71. q = query.lower().strip()
  72. if exp == q:
  73. return True
  74. if q in exp and len(exp) < len(q) + 10:
  75. return True
  76. return False
  77. def word_repetition_penalty(text: str) -> int:
  78. """Count penalty for repeated words (excluding stopwords)."""
  79. words = re.findall(r'\b\w+\b', text.lower())
  80. counts = Counter(words)
  81. penalty = 0
  82. for word, count in counts.items():
  83. if count >= 3 and word not in STOPWORDS and len(word) > 2:
  84. penalty += (count - 2) * 2
  85. return penalty
  86. def score_expansion(query: str, expansion: str) -> float:
  87. """
  88. Score an expansion based on SCORING.md criteria.
  89. Returns normalized score 0.0-1.0 for RL reward.
  90. """
  91. parsed = parse_expansion(expansion)
  92. # === FORMAT (0-30) ===
  93. format_score = 0
  94. if parsed["lex"]:
  95. format_score += 10
  96. if parsed["vec"]:
  97. format_score += 10
  98. if not parsed["invalid"]:
  99. format_score += 10
  100. else:
  101. format_score += max(0, 10 - len(parsed["invalid"]) * 5)
  102. # === DIVERSITY (0-30) ===
  103. diversity_score = 0
  104. # 2+ different types
  105. types_present = sum(1 for t in ["lex", "vec"] if parsed[t])
  106. if types_present >= 2:
  107. diversity_score += 10
  108. # 2+ total expansions
  109. total_expansions = len(parsed["lex"]) + len(parsed["vec"])
  110. if total_expansions >= 2:
  111. diversity_score += 5
  112. # Lex diversity
  113. lex_score = 5
  114. for i, a in enumerate(parsed["lex"]):
  115. for b in parsed["lex"][i+1:]:
  116. if not is_diverse(a, b, 2):
  117. lex_score -= 2
  118. diversity_score += max(0, lex_score)
  119. # Vec diversity
  120. vec_score = 5
  121. for i, a in enumerate(parsed["vec"]):
  122. for b in parsed["vec"][i+1:]:
  123. if not is_diverse(a, b, 3):
  124. vec_score -= 2
  125. diversity_score += max(0, vec_score)
  126. # Don't echo query
  127. echo_score = 5
  128. for exp in parsed["lex"] + parsed["vec"]:
  129. if echoes_query(exp, query):
  130. echo_score -= 3 # Heavier penalty for echoing
  131. diversity_score += max(0, echo_score)
  132. # === HYDE (0-20) ===
  133. hyde_score = 0
  134. if parsed["hyde"]:
  135. hyde_text = parsed["hyde"][0]
  136. hyde_score += 5 # Present
  137. # Length check (50-200 chars ideal)
  138. hyde_len = len(hyde_text)
  139. if 50 <= hyde_len <= 200:
  140. hyde_score += 5
  141. elif hyde_len < 50:
  142. hyde_score += 2
  143. # No newlines
  144. if "\n" not in hyde_text:
  145. hyde_score += 5
  146. # No repetition
  147. rep_penalty = word_repetition_penalty(hyde_text)
  148. hyde_score += max(0, 5 - rep_penalty)
  149. # === QUALITY (0-20) ===
  150. quality_score = 10 # Base
  151. # Lex should be shorter than vec
  152. if parsed["lex"] and parsed["vec"]:
  153. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  154. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  155. if avg_lex <= avg_vec:
  156. quality_score += 5
  157. # Vec should be natural language
  158. if parsed["vec"]:
  159. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  160. if natural == len(parsed["vec"]):
  161. quality_score += 5
  162. else:
  163. quality_score += 2
  164. # === TOTAL ===
  165. total = format_score + diversity_score + hyde_score + quality_score
  166. max_possible = 100 if parsed["hyde"] else 80
  167. # Normalize to 0-1
  168. return total / max_possible
  169. def extract_query_from_prompt(prompt: str) -> str:
  170. """Extract the query from the prompt template."""
  171. # Prompt format: "Expand this search query:\n\n{query}"
  172. if "Expand this search query:" in prompt:
  173. return prompt.split("Expand this search query:")[-1].strip()
  174. return prompt.strip()
  175. class QMDRewardFunction:
  176. """Reward function using comprehensive SCORING.md criteria."""
  177. __name__ = "qmd_scoring_reward"
  178. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  179. """Compute rewards for a batch of completions."""
  180. rewards = []
  181. for i, completion in enumerate(completions):
  182. # Get the query from prompt if available
  183. query = ""
  184. if prompts and i < len(prompts):
  185. query = extract_query_from_prompt(prompts[i])
  186. # Score using comprehensive system
  187. score = score_expansion(query, completion)
  188. rewards.append(score)
  189. return rewards
  190. # ============================================================================
  191. # Main Training
  192. # ============================================================================
  193. def main():
  194. import argparse
  195. parser = argparse.ArgumentParser()
  196. parser.add_argument("--sft-model", default="tobil/qmd-query-expansion-0.6B",
  197. help="SFT model to use as starting point")
  198. parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
  199. help="Base model (for loading tokenizer)")
  200. parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo-v2",
  201. help="Output model name on Hub")
  202. parser.add_argument("--epochs", type=int, default=1)
  203. parser.add_argument("--lr", type=float, default=1e-6,
  204. help="Learning rate (lower for stability)")
  205. parser.add_argument("--dry-run", action="store_true")
  206. args = parser.parse_args()
  207. if args.dry_run:
  208. print("GRPO Training Config:")
  209. print(f" SFT Model: {args.sft_model}")
  210. print(f" Base Model: {args.base_model}")
  211. print(f" Output: {args.output}")
  212. print(f" Epochs: {args.epochs}")
  213. print(f" LR: {args.lr}")
  214. return
  215. # Login to HuggingFace Hub
  216. hf_token = os.environ.get("HF_TOKEN")
  217. if hf_token:
  218. print("Logging in to HuggingFace Hub...")
  219. login(token=hf_token)
  220. else:
  221. print("Warning: HF_TOKEN not set, will try cached login")
  222. # Load dataset (just prompts needed for GRPO)
  223. print("Loading dataset...")
  224. dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
  225. # Extract just the queries as prompts
  226. def extract_prompt(example):
  227. return {"prompt": example["messages"][0]["content"]}
  228. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  229. dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset))))
  230. print(f"Using {len(dataset)} prompts for GRPO")
  231. # Load tokenizer
  232. print(f"Loading tokenizer from {args.base_model}...")
  233. tokenizer = AutoTokenizer.from_pretrained(args.base_model)
  234. if tokenizer.pad_token is None:
  235. tokenizer.pad_token = tokenizer.eos_token
  236. # Load SFT model with LoRA adapter
  237. print(f"Loading SFT model from {args.sft_model}...")
  238. base_model = AutoModelForCausalLM.from_pretrained(
  239. args.base_model,
  240. torch_dtype=torch.bfloat16,
  241. device_map="auto",
  242. )
  243. model = PeftModel.from_pretrained(base_model, args.sft_model)
  244. model = model.merge_and_unload()
  245. print("Model loaded and LoRA merged.")
  246. # Add new LoRA adapter for GRPO training (smaller rank for stability)
  247. grpo_lora_config = LoraConfig(
  248. r=4, # Smaller rank for more stable RL
  249. lora_alpha=8,
  250. lora_dropout=0.05,
  251. bias="none",
  252. task_type="CAUSAL_LM",
  253. target_modules=["q_proj", "v_proj"], # Fewer modules for stability
  254. )
  255. model = get_peft_model(model, grpo_lora_config)
  256. model.print_trainable_parameters()
  257. print("Added new LoRA adapter for GRPO.")
  258. # Initialize reward function
  259. reward_fn = QMDRewardFunction()
  260. # Test reward function
  261. print("\nTesting reward function...")
  262. test_good = "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."
  263. test_bad = "auth is important for security"
  264. print(f" Good output score: {score_expansion('auth', test_good):.2f}")
  265. print(f" Bad output score: {score_expansion('auth', test_bad):.2f}")
  266. # GRPO config with conservative settings
  267. config = GRPOConfig(
  268. output_dir="qmd-expansion-grpo-v2",
  269. push_to_hub=True,
  270. hub_model_id=args.output,
  271. # GRPO specific - conservative
  272. num_generations=4,
  273. max_completion_length=200, # Shorter to avoid rambling
  274. # Training - very conservative
  275. num_train_epochs=args.epochs,
  276. per_device_train_batch_size=2,
  277. gradient_accumulation_steps=8,
  278. learning_rate=args.lr,
  279. max_grad_norm=0.5, # Clip gradients more aggressively
  280. # Logging
  281. logging_steps=10,
  282. save_strategy="epoch",
  283. # Monitoring
  284. report_to="trackio",
  285. project="qmd-query-expansion-grpo-v2",
  286. run_name="grpo-scoring-v2",
  287. )
  288. # Create trainer
  289. print("Initializing GRPO trainer...")
  290. trainer = GRPOTrainer(
  291. model=model,
  292. processing_class=tokenizer,
  293. args=config,
  294. train_dataset=dataset,
  295. reward_funcs=[reward_fn],
  296. )
  297. # Train
  298. print("Starting GRPO training...")
  299. trainer.train()
  300. # Save
  301. print("Pushing to Hub...")
  302. trainer.push_to_hub()
  303. trackio.finish()
  304. print(f"Done! Model at: https://huggingface.co/{args.output}")
  305. if __name__ == "__main__":
  306. main()