train_grpo.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "trl>=0.12.0",
  5. # "peft>=0.7.0",
  6. # "transformers>=4.45.0",
  7. # "accelerate>=0.24.0",
  8. # "trackio",
  9. # "datasets",
  10. # "bitsandbytes",
  11. # "sentence-transformers",
  12. # ]
  13. # ///
  14. """
  15. GRPO (Group Relative Policy Optimization) training for QMD query expansion.
  16. Reward Type 2: Format + Diversity
  17. - Rewards correct lex/vec/hyde format
  18. - Penalizes repetition between lines
  19. - Rewards semantic diversity of expansions
  20. Usage:
  21. uv run train_grpo.py --sft-model tobil/qmd-query-expansion-0.6B
  22. """
  23. import re
  24. import torch
  25. import trackio
  26. from datasets import load_dataset
  27. from peft import LoraConfig, PeftModel
  28. from transformers import AutoModelForCausalLM, AutoTokenizer
  29. from trl import GRPOTrainer, GRPOConfig
  30. from sentence_transformers import SentenceTransformer
  31. # ============================================================================
  32. # Reward Function: Format + Diversity
  33. # ============================================================================
  34. def parse_expansion(text: str) -> dict:
  35. """Parse expansion output into lex/vec/hyde components."""
  36. result = {"lex": [], "vec": [], "hyde": []}
  37. for line in text.strip().split("\n"):
  38. line = line.strip()
  39. if line.startswith("lex:"):
  40. result["lex"].append(line[4:].strip())
  41. elif line.startswith("vec:"):
  42. result["vec"].append(line[4:].strip())
  43. elif line.startswith("hyde:"):
  44. result["hyde"].append(line[5:].strip())
  45. return result
  46. def compute_format_reward(text: str) -> float:
  47. """
  48. Reward for correct format:
  49. - Has at least 1 lex line: +0.2
  50. - Has at least 1 vec line: +0.2
  51. - Has hyde line: +0.1
  52. - Correct line format (type: content): +0.1 per line (max 0.3)
  53. - No garbage/malformed lines: +0.2
  54. """
  55. reward = 0.0
  56. parsed = parse_expansion(text)
  57. # Check required components
  58. if parsed["lex"]:
  59. reward += 0.2
  60. if parsed["vec"]:
  61. reward += 0.2
  62. if parsed["hyde"]:
  63. reward += 0.1
  64. # Check line format
  65. lines = text.strip().split("\n")
  66. valid_lines = 0
  67. for line in lines:
  68. if re.match(r'^(lex|vec|hyde):\s*.+', line.strip()):
  69. valid_lines += 1
  70. reward += min(0.3, valid_lines * 0.1)
  71. # Penalize malformed lines
  72. malformed = len(lines) - valid_lines
  73. if malformed == 0:
  74. reward += 0.2
  75. else:
  76. reward -= malformed * 0.1
  77. return max(0.0, min(1.0, reward))
  78. def compute_diversity_reward(text: str, embedder) -> float:
  79. """
  80. Reward for diverse expansions:
  81. - Penalize exact duplicates
  82. - Reward semantic distance between expansions
  83. """
  84. parsed = parse_expansion(text)
  85. all_expansions = parsed["lex"] + parsed["vec"] + parsed["hyde"]
  86. if len(all_expansions) < 2:
  87. return 0.0
  88. # Penalize exact duplicates
  89. unique = set(e.lower() for e in all_expansions)
  90. duplicate_penalty = (len(all_expansions) - len(unique)) * 0.2
  91. # Compute semantic diversity
  92. if len(unique) >= 2:
  93. try:
  94. embeddings = embedder.encode(list(unique))
  95. # Compute pairwise cosine similarities
  96. from torch.nn.functional import cosine_similarity
  97. emb_tensor = torch.tensor(embeddings)
  98. similarities = []
  99. for i in range(len(emb_tensor)):
  100. for j in range(i + 1, len(emb_tensor)):
  101. sim = cosine_similarity(
  102. emb_tensor[i].unsqueeze(0),
  103. emb_tensor[j].unsqueeze(0)
  104. ).item()
  105. similarities.append(sim)
  106. # Lower similarity = higher diversity = higher reward
  107. avg_similarity = sum(similarities) / len(similarities) if similarities else 1.0
  108. diversity_reward = 1.0 - avg_similarity # 0 = identical, 1 = orthogonal
  109. except Exception:
  110. diversity_reward = 0.0
  111. else:
  112. diversity_reward = 0.0
  113. return max(0.0, diversity_reward - duplicate_penalty)
  114. def compute_length_reward(text: str) -> float:
  115. """Reward appropriate length (not too short, not too long)."""
  116. lines = [l for l in text.strip().split("\n") if l.strip()]
  117. # Ideal: 3-6 lines
  118. if 3 <= len(lines) <= 6:
  119. return 0.2
  120. elif 2 <= len(lines) <= 7:
  121. return 0.1
  122. else:
  123. return 0.0
  124. class QMDRewardFunction:
  125. """Combined reward function for QMD query expansion."""
  126. __name__ = "qmd_format_diversity_reward"
  127. def __init__(self):
  128. # Load a small embedding model for diversity computation
  129. print("Loading embedding model for diversity reward...")
  130. self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
  131. print("Embedding model loaded.")
  132. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  133. """Compute rewards for a batch of completions."""
  134. rewards = []
  135. for completion in completions:
  136. # Extract just the generated part (after prompt)
  137. text = completion
  138. # Compute component rewards
  139. format_r = compute_format_reward(text)
  140. diversity_r = compute_diversity_reward(text, self.embedder)
  141. length_r = compute_length_reward(text)
  142. # Weighted combination
  143. total = (
  144. 0.5 * format_r + # Format is most important
  145. 0.35 * diversity_r + # Diversity is second
  146. 0.15 * length_r # Length is minor
  147. )
  148. rewards.append(total)
  149. return rewards
  150. # ============================================================================
  151. # Main Training
  152. # ============================================================================
  153. def main():
  154. import argparse
  155. parser = argparse.ArgumentParser()
  156. parser.add_argument("--sft-model", default="tobil/qmd-query-expansion-0.6B",
  157. help="SFT model to use as starting point")
  158. parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B",
  159. help="Base model (for loading tokenizer)")
  160. parser.add_argument("--output", default="tobil/qmd-query-expansion-0.6B-grpo",
  161. help="Output model name on Hub")
  162. parser.add_argument("--epochs", type=int, default=1)
  163. parser.add_argument("--dry-run", action="store_true")
  164. args = parser.parse_args()
  165. if args.dry_run:
  166. print("GRPO Training Config:")
  167. print(f" SFT Model: {args.sft_model}")
  168. print(f" Base Model: {args.base_model}")
  169. print(f" Output: {args.output}")
  170. print(f" Epochs: {args.epochs}")
  171. return
  172. # Load dataset (just prompts needed for GRPO)
  173. print("Loading dataset...")
  174. dataset = load_dataset("tobil/qmd-query-expansion-train", split="train")
  175. # Extract just the queries as prompts
  176. def extract_prompt(example):
  177. return {"prompt": example["messages"][0]["content"]}
  178. dataset = dataset.map(extract_prompt, remove_columns=dataset.column_names)
  179. dataset = dataset.shuffle(seed=42).select(range(min(2000, len(dataset)))) # Use subset for GRPO
  180. print(f"Using {len(dataset)} prompts for GRPO")
  181. # Load tokenizer
  182. print(f"Loading tokenizer from {args.base_model}...")
  183. tokenizer = AutoTokenizer.from_pretrained(args.base_model)
  184. if tokenizer.pad_token is None:
  185. tokenizer.pad_token = tokenizer.eos_token
  186. # Load SFT model with LoRA adapter
  187. print(f"Loading SFT model from {args.sft_model}...")
  188. base_model = AutoModelForCausalLM.from_pretrained(
  189. args.base_model,
  190. torch_dtype=torch.bfloat16,
  191. device_map="auto",
  192. )
  193. model = PeftModel.from_pretrained(base_model, args.sft_model)
  194. model = model.merge_and_unload() # Merge LoRA weights
  195. print("Model loaded and LoRA merged.")
  196. # Add new LoRA adapter for GRPO training
  197. from peft import get_peft_model
  198. grpo_lora_config = LoraConfig(
  199. r=8,
  200. lora_alpha=16,
  201. lora_dropout=0.05,
  202. bias="none",
  203. task_type="CAUSAL_LM",
  204. target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
  205. )
  206. model = get_peft_model(model, grpo_lora_config)
  207. model.print_trainable_parameters()
  208. print("Added new LoRA adapter for GRPO.")
  209. # Initialize reward function
  210. reward_fn = QMDRewardFunction()
  211. # GRPO config
  212. config = GRPOConfig(
  213. output_dir="qmd-expansion-grpo",
  214. push_to_hub=True,
  215. hub_model_id=args.output,
  216. # GRPO specific
  217. num_generations=4, # Generate 4 completions per prompt
  218. max_completion_length=256,
  219. # Training
  220. num_train_epochs=args.epochs,
  221. per_device_train_batch_size=2,
  222. gradient_accumulation_steps=4,
  223. learning_rate=5e-6, # Lower LR for RL
  224. # Logging
  225. logging_steps=10,
  226. save_strategy="epoch",
  227. # Monitoring
  228. report_to="trackio",
  229. project="qmd-query-expansion-grpo",
  230. run_name="grpo-format-diversity",
  231. )
  232. # Create trainer
  233. print("Initializing GRPO trainer...")
  234. trainer = GRPOTrainer(
  235. model=model,
  236. processing_class=tokenizer,
  237. args=config,
  238. train_dataset=dataset,
  239. reward_funcs=[reward_fn],
  240. )
  241. # Train
  242. print("Starting GRPO training...")
  243. trainer.train()
  244. # Save
  245. print("Pushing to Hub...")
  246. trainer.push_to_hub()
  247. trackio.finish()
  248. print(f"Done! Model at: https://huggingface.co/{args.output}")
  249. if __name__ == "__main__":
  250. main()