train_grpo.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "trackio",
  9. # "datasets",
  10. # "bitsandbytes",
  11. # "sentence-transformers",
  12. # ]
  13. # ///
  14. """
  15. GRPO (Group Relative Policy Optimization) training for QMD query expansion.
  16. Reward Type 2: Format + Diversity
  17. - Rewards correct lex/vec/hyde format
  18. - Penalizes repetition between lines
  19. - Rewards semantic diversity of expansions
  20. Usage:
  21. uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
  22. """
  23. import re
  24. import torch
  25. import trackio
  26. from datasets import load_dataset
  27. from peft import LoraConfig, PeftModel
  28. from transformers import AutoModelForCausalLM, AutoTokenizer
  29. from trl import GRPOTrainer, GRPOConfig
  30. from sentence_transformers import SentenceTransformer
  31. # ============================================================================
  32. # Reward Function: Format + Diversity
  33. # ============================================================================
  34. def parse_expansion(text: str) -> dict:
  35. """Parse expansion output into lex/vec/hyde components."""
  36. result = {"lex": [], "vec": [], "hyde": []}
  37. for line in text.strip().split("\n"):
  38. line = line.strip()
  39. if line.startswith("lex:"):
  40. result["lex"].append(line[4:].strip())
  41. elif line.startswith("vec:"):
  42. result["vec"].append(line[4:].strip())
  43. elif line.startswith("hyde:"):
  44. result["hyde"].append(line[5:].strip())
  45. return result
  46. def compute_format_reward(text: str) -> float:
  47. """
  48. Reward for correct format:
  49. - Has at least 1 lex line: +0.2
  50. - Has at least 1 vec line: +0.2
  51. - Has hyde line: +0.1
  52. - Correct line format (type: content): +0.1 per line (max 0.3)
  53. - No garbage/malformed lines: +0.2
  54. """
  55. reward = 0.0
  56. parsed = parse_expansion(text)
  57. # Check required components
  58. if parsed["lex"]:
  59. reward += 0.2
  60. if parsed["vec"]:
  61. reward += 0.2
  62. if parsed["hyde"]:
  63. reward += 0.1
  64. # Check line format
  65. lines = text.strip().split("\n")
  66. valid_lines = 0
  67. for line in lines:
  68. if re.match(r'^(lex|vec|hyde):\s*.+', line.strip()):
  69. valid_lines += 1
  70. reward += min(0.3, valid_lines * 0.1)
  71. # Penalize malformed lines
  72. malformed = len(lines) - valid_lines
  73. if malformed == 0:
  74. reward += 0.2
  75. else:
  76. reward -= malformed * 0.1
  77. return max(0.0, min(1.0, reward))
  78. def compute_diversity_reward(text: str, embedder) -> float:
  79. """
  80. Reward for diverse expansions:
  81. - Penalize exact duplicates
  82. - Reward semantic distance between expansions
  83. """
  84. parsed = parse_expansion(text)
  85. all_expansions = parsed["lex"] + parsed["vec"] + parsed["hyde"]
  86. if len(all_expansions) < 2:
  87. return 0.0
  88. # Penalize exact duplicates
  89. unique = set(e.lower() for e in all_expansions)
  90. duplicate_penalty = (len(all_expansions) - len(unique)) * 0.2
  91. # Compute semantic diversity
  92. if len(unique) >= 2:
  93. try:
  94. embeddings = embedder.encode(list(unique))
  95. # Compute pairwise cosine similarities
  96. from torch.nn.functional import cosine_similarity
  97. emb_tensor = torch.tensor(embeddings)
  98. similarities = []
  99. for i in range(len(emb_tensor)):
  100. for j in range(i + 1, len(emb_tensor)):
  101. sim = cosine_similarity(
  102. emb_tensor[i].unsqueeze(0),
  103. emb_tensor[j].unsqueeze(0)
  104. ).item()
  105. similarities.append(sim)
  106. # Lower similarity = higher diversity = higher reward
  107. avg_similarity = sum(similarities) / len(similarities) if similarities else 1.0
  108. diversity_reward = 1.0 - avg_similarity # 0 = identical, 1 = orthogonal
  109. except Exception:
  110. diversity_reward = 0.0
  111. else:
  112. diversity_reward = 0.0
  113. return max(0.0, diversity_reward - duplicate_penalty)
  114. def compute_length_reward(text: str) -> float:
  115. """Reward appropriate length (not too short, not too long)."""
  116. lines = [l for l in text.strip().split("\n") if l.strip()]
  117. # Ideal: 3-6 lines
  118. if 3 <= len(lines) <= 6:
  119. return 0.2
  120. elif 2 <= len(lines) <= 7:
  121. return 0.1
  122. else:
  123. return 0.0
  124. class QMDRewardFunction:
  125. """Combined reward function for QMD query expansion."""
  126. def __init__(self):
  127. # Load a small embedding model for diversity computation
  128. print("Loading embedding model for diversity reward...")
  129. self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
  130. print("Embedding model loaded.")
  131. def __call__(self, completions: list[str], prompts: list[str] = None) -> list[float]:
  132. """Compute rewards for a batch of completions."""
  133. rewards = []
  134. for completion in completions:
  135. # Extract just the generated part (after prompt)
  136. text = completion
  137. # Compute component rewards
  138. format_r = compute_format_reward(text)
  139. diversity_r = compute_diversity_reward(text, self.embedder)
  140. length_r = compute_length_reward(text)
  141. # Weighted combination
  142. total = (
  143. 0.5 * format_r + # Format is most important
  144. 0.35 * diversity_r + # Diversity is second
  145. 0.15 * length_r # Length is minor
  146. )
  147. rewards.append(total)
  148. return rewards
  149. # ============================================================================
  150. # Main Training
  151. # ============================================================================
  152. def main():
  153. import argparse
  154. parser = argparse.ArgumentParser()
  155. parser.add_argument("--sft-model", default="tobil/qmd-query-expansion-0.6B",
  156. help="SFT model to use as starting point")
  157. parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
  158. help="Base model (for loading tokenizer)")
  159. parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo",
  160. help="Output model name on Hub")
  161. parser.add_argument("--epochs", type=int, default=1)
  162. parser.add_argument("--dry-run", action="store_true")
  163. args = parser.parse_args()
  164. if args.dry_run:
  165. print("GRPO Training Config:")
  166. print(f" SFT Model: {args.sft_model}")
  167. print(f" Base Model: {args.base_model}")
  168. print(f" Output: {args.output}")
  169. print(f" Epochs: {args.epochs}")
  170. return
  171. # Load dataset (just prompts needed for GRPO)
  172. print("Loading dataset...")
  173. dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
  174. # Extract just the queries as prompts
  175. def extract_prompt(example):
  176. return {"prompt": example["messages"][0]["content"]}
  177. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  178. dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset)))) # Use subset for GRPO
  179. print(f"Using {len(dataset)} prompts for GRPO")
  180. # Load tokenizer
  181. print(f"Loading tokenizer from {args.base_model}...")
  182. tokenizer = AutoTokenizer.from_pretrained(args.base_model)
  183. if tokenizer.pad_token is None:
  184. tokenizer.pad_token = tokenizer.eos_token
  185. # Load SFT model with LoRA adapter
  186. print(f"Loading SFT model from {args.sft_model}...")
  187. base_model = AutoModelForCausalLM.from_pretrained(
  188. args.base_model,
  189. torch_dtype=torch.bfloat16,
  190. device_map="auto",
  191. )
  192. model = PeftModel.from_pretrained(base_model, args.sft_model)
  193. model = model.merge_and_unload() # Merge LoRA weights
  194. print("Model loaded and LoRA merged.")
  195. # Initialize reward function
  196. reward_fn = QMDRewardFunction()
  197. # GRPO config
  198. config = GRPOConfig(
  199. output_dir="qmd-expansion-grpo",
  200. push_to_hub=True,
  201. hub_model_id=args.output,
  202. # GRPO specific
  203. num_generations=4, # Generate 4 completions per prompt
  204. max_new_tokens=256,
  205. temperature=0.8,
  206. # Training
  207. num_train_epochs=args.epochs,
  208. per_device_train_batch_size=2,
  209. gradient_accumulation_steps=4,
  210. learning_rate=5e-6, # Lower LR for RL
  211. # Logging
  212. logging_steps=10,
  213. save_strategy="epoch",
  214. # Monitoring
  215. report_to="trackio",
  216. project="qmd-query-expansion-grpo",
  217. run_name="grpo-format-diversity",
  218. )
  219. # Create trainer
  220. print("Initializing GRPO trainer...")
  221. trainer = GRPOTrainer(
  222. model=model,
  223. tokenizer=tokenizer,
  224. config=config,
  225. train_dataset=dataset,
  226. reward_funcs=reward_fn,
  227. )
  228. # Train
  229. print("Starting GRPO training...")
  230. trainer.train()
  231. # Save
  232. print("Pushing to Hub...")
  233. trainer.push_to_hub()
  234. trackio.finish()
  235. print(f"Done! Model at: https://huggingface.co/{args.output}")
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()