reward.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = []
  4. # ///
  5. """
  6. QMD Query Expansion Reward Function
  7. Single source of truth for scoring query expansions. Used by:
  8. - GRPO training (as the RL reward signal)
  9. - Evaluation scripts (for scoring model outputs)
  10. Scores expansions on five dimensions, plus a 20-point bonus when no <think> block is used:
  11. Format (30) - Has lex/vec lines, no invalid lines
  12. Diversity (30) - Multiple types, diverse content, no echoes
  13. HyDE (20) - Optional bonus for hypothetical document passage
  14. Quality (20) - Lex shorter than vec, natural language, key terms
  15. Entity (20) - Named entity preservation in lex/vec lines
  16. Returns 0.0-1.0 for RL rewards, or a detailed breakdown dict for evaluation.
  17. """
  18. import re
  19. from collections import Counter
  20. # =============================================================================
  21. # Constants
  22. # =============================================================================
  23. STOPWORDS = frozenset({
  24. 'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
  25. 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
  26. })
  27. KEY_TERM_STOPWORDS = frozenset({
  28. 'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  29. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
  30. 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
  31. })
  32. GENERIC_LEX_PHRASES = frozenset({
  33. 'find information about', 'search for', 'look up', 'get information',
  34. 'learn about', 'information on', 'details about', 'find out about',
  35. 'what is', 'how to', 'guide to', 'help with',
  36. })
  37. # Chat template tokens that indicate a broken output
  38. CHAT_TEMPLATE_TOKENS = frozenset({
  39. '<|im_start|>', '<|im_end|>', '<|endoftext|>',
  40. '\nassistant\n', '\nuser\n',
  41. })
  42. # =============================================================================
  43. # Parsing
  44. # =============================================================================
  45. def parse_expansion(text: str) -> dict:
  46. """Parse a multi-line expansion into {lex, vec, hyde, invalid} lists."""
  47. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  48. for line in text.strip().split("\n"):
  49. line = line.strip()
  50. if not line:
  51. continue
  52. if line.startswith("lex:"):
  53. result["lex"].append(line[4:].strip())
  54. elif line.startswith("vec:"):
  55. result["vec"].append(line[4:].strip())
  56. elif line.startswith("hyde:"):
  57. result["hyde"].append(line[5:].strip())
  58. else:
  59. result["invalid"].append(line)
  60. return result
  61. def clean_model_output(text: str) -> tuple[str, bool]:
  62. """Strip chat template artifacts from model output.
  63. Returns (cleaned_text, used_thinking) where used_thinking is True
  64. if the model emitted <think>...</think> blocks.
  65. """
  66. text = text.replace('<|im_end|>', '').strip()
  67. used_thinking = '<think>' in text and '</think>' in text
  68. if used_thinking:
  69. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
  70. return text, used_thinking
  71. # =============================================================================
  72. # Helpers
  73. # =============================================================================
  74. def extract_named_entities(query: str) -> set:
  75. """Extract named entities using heuristics.
  76. Detects: ALL-CAPS acronyms (TDS, API), capitalized proper nouns (React),
  77. technical terms with special chars (node.js, C++), CamelCase (JavaScript),
  78. and compound names (TDS motorsports -> both words).
  79. """
  80. entities = set()
  81. words = query.split()
  82. prev_was_entity = False
  83. for i, word in enumerate(words):
  84. clean = word.strip('.,!?:;()[]"\'')
  85. if not clean:
  86. prev_was_entity = False
  87. continue
  88. is_entity = False
  89. if clean.isupper() and len(clean) >= 2:
  90. entities.add(clean.lower())
  91. is_entity = True
  92. elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
  93. entities.add(clean.lower())
  94. is_entity = True
  95. elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
  96. entities.add(clean.lower())
  97. is_entity = True
  98. elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
  99. entities.add(clean.lower())
  100. is_entity = True
  101. elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
  102. entities.add(clean.lower())
  103. is_entity = True
  104. prev_was_entity = is_entity
  105. return entities
  106. def get_key_terms(query: str) -> set:
  107. """Get non-stopword terms from a query."""
  108. return set(query.lower().split()) - KEY_TERM_STOPWORDS
  109. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  110. """Does the lex line contain at least one key term from the query?"""
  111. key_terms = get_key_terms(query)
  112. if not key_terms:
  113. return True
  114. return bool(key_terms & set(lex_line.lower().split()))
  115. def lex_preserves_entities(line: str, entities: set) -> bool:
  116. """Does the line contain at least one named entity?"""
  117. if not entities:
  118. return True
  119. lower = line.lower()
  120. return any(e in lower for e in entities)
  121. def lex_is_generic(lex_line: str) -> bool:
  122. """Is this lex line a useless generic filler phrase?"""
  123. lower = lex_line.lower().strip()
  124. for phrase in GENERIC_LEX_PHRASES:
  125. if phrase in lower or lower.startswith(phrase.split()[0]):
  126. remaining = lower
  127. for word in phrase.split():
  128. remaining = remaining.replace(word, '', 1).strip()
  129. if len(remaining) < 3:
  130. return True
  131. return False
  132. def word_set_distance(a: str, b: str) -> int:
  133. """Symmetric difference of word sets (how many words are unique to one)."""
  134. return len(set(a.lower().split()) ^ set(b.lower().split()))
  135. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  136. """Are two strings sufficiently different?"""
  137. a, b = a.lower().strip(), b.lower().strip()
  138. if a == b or a in b or b in a:
  139. return False
  140. return word_set_distance(a, b) >= min_distance
  141. def echoes_query(expansion: str, query: str) -> bool:
  142. """Is this expansion just echoing the original query?"""
  143. exp, q = expansion.lower().strip(), query.lower().strip()
  144. return exp == q or (q in exp and len(exp) < len(q) + 10)
  145. def word_repetition_penalty(text: str) -> int:
  146. """Penalty for words repeated 3+ times (excluding stopwords)."""
  147. counts = Counter(re.findall(r'\b\w+\b', text.lower()))
  148. return sum((c - 2) * 2 for w, c in counts.items()
  149. if c >= 3 and w not in STOPWORDS and len(w) > 2)
  150. # =============================================================================
  151. # Scoring
  152. # =============================================================================
  153. def score_expansion_detailed(query: str, expansion: str) -> dict:
  154. """Score an expansion with full breakdown. Returns dict with all dimensions."""
  155. text, used_thinking = clean_model_output(expansion.strip())
  156. deductions = []
  157. def _fail(reason):
  158. return {
  159. "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
  160. "think_bonus": 0, "total": 0, "max_possible": 100,
  161. "percentage": 0.0, "rating": "Failed",
  162. "deductions": [reason],
  163. "parsed": parse_expansion(expansion),
  164. "entities_detected": [],
  165. }
  166. # Hard fail: remaining chat template tokens
  167. if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
  168. return _fail("CHAT TEMPLATE LEAKAGE")
  169. # Hard fail: every non-empty line must have a valid prefix
  170. for line in text.split("\n"):
  171. line = line.strip()
  172. if line and not line.startswith(("lex:", "vec:", "hyde:")):
  173. return _fail(f"INVALID LINE: {line[:50]}")
  174. parsed = parse_expansion(text)
  175. # --- Format (0-30) ---
  176. format_score = 10 # no invalid lines (guaranteed by hard fail)
  177. if parsed["lex"]:
  178. format_score += 10
  179. else:
  180. deductions.append("missing lex:")
  181. if parsed["vec"]:
  182. format_score += 10
  183. else:
  184. deductions.append("missing vec:")
  185. # --- Diversity (0-30) ---
  186. diversity_score = 0
  187. types_present = sum(1 for t in ("lex", "vec") if parsed[t])
  188. if types_present >= 2:
  189. diversity_score += 10
  190. else:
  191. deductions.append("only one type")
  192. if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
  193. diversity_score += 5
  194. lex_div = 5
  195. for i, a in enumerate(parsed["lex"]):
  196. for b in parsed["lex"][i+1:]:
  197. if not is_diverse(a, b, 2):
  198. lex_div -= 2
  199. deductions.append(f"lex duplicate: {a[:20]}...")
  200. diversity_score += max(0, lex_div)
  201. vec_div = 5
  202. for i, a in enumerate(parsed["vec"]):
  203. for b in parsed["vec"][i+1:]:
  204. if not is_diverse(a, b, 3):
  205. vec_div -= 2
  206. deductions.append(f"vec duplicate: {a[:20]}...")
  207. diversity_score += max(0, vec_div)
  208. echo = 5
  209. for exp in parsed["lex"] + parsed["vec"]:
  210. if echoes_query(exp, query):
  211. echo -= 3
  212. deductions.append(f"echoes query: {exp[:20]}...")
  213. diversity_score += max(0, echo)
  214. # --- HyDE (0-20, optional bonus) ---
  215. hyde_score = 0
  216. if parsed["hyde"]:
  217. hyde_text = parsed["hyde"][0]
  218. hyde_score += 5
  219. hyde_len = len(hyde_text)
  220. if 50 <= hyde_len <= 200:
  221. hyde_score += 5
  222. elif hyde_len < 50:
  223. hyde_score += 2
  224. deductions.append(f"hyde too short ({hyde_len})")
  225. else:
  226. deductions.append(f"hyde too long ({hyde_len})")
  227. if "\n" not in hyde_text:
  228. hyde_score += 5
  229. hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
  230. # --- Quality (0-20) ---
  231. quality_score = 5 # base relevance
  232. if parsed["lex"] and parsed["vec"]:
  233. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  234. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  235. if avg_lex <= avg_vec:
  236. quality_score += 5
  237. else:
  238. deductions.append("lex longer than vec")
  239. if parsed["vec"]:
  240. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  241. quality_score += 5 if natural == len(parsed["vec"]) else 2
  242. if parsed["lex"]:
  243. with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  244. if with_terms == len(parsed["lex"]):
  245. quality_score += 5
  246. elif with_terms > 0:
  247. quality_score += 2
  248. else:
  249. deductions.append("lex missing key terms")
  250. # --- Entity Preservation (-45 to +20) ---
  251. entity_score = 0
  252. entities = extract_named_entities(query)
  253. if entities and parsed["lex"]:
  254. with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
  255. if with_entities == len(parsed["lex"]):
  256. entity_score += 15
  257. elif with_entities > 0:
  258. entity_score += 5
  259. else:
  260. entity_score -= 30
  261. deductions.append(f"lex missing entities: {entities}")
  262. generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
  263. if generic_count:
  264. entity_score -= generic_count * 15
  265. deductions.append(f"{generic_count} generic lex phrases")
  266. if parsed["vec"]:
  267. vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
  268. if vec_with > 0:
  269. entity_score += 5
  270. elif not entities:
  271. entity_score = 10
  272. # --- Think bonus (0-20): reward NOT using thinking mode ---
  273. think_bonus = 0 if used_thinking else 20
  274. # --- Total ---
  275. total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
  276. max_possible = 140 if parsed["hyde"] else 120
  277. percentage = max(0.0, min(100.0, total / max_possible * 100))
  278. if percentage >= 80:
  279. rating = "Excellent"
  280. elif percentage >= 60:
  281. rating = "Good"
  282. elif percentage >= 40:
  283. rating = "Acceptable"
  284. elif percentage >= 20:
  285. rating = "Poor"
  286. else:
  287. rating = "Failed"
  288. return {
  289. "format": format_score,
  290. "diversity": diversity_score,
  291. "hyde": hyde_score,
  292. "quality": quality_score,
  293. "entity": max(0, entity_score),
  294. "think_bonus": think_bonus,
  295. "total": max(0, total),
  296. "max_possible": max_possible,
  297. "percentage": round(percentage, 1),
  298. "rating": rating,
  299. "deductions": deductions,
  300. "parsed": parsed,
  301. "entities_detected": list(entities) if entities else [],
  302. }
  303. def score_expansion(query: str, expansion: str) -> float:
  304. """Score expansion as a float in [0.0, 1.0] for use as RL reward."""
  305. result = score_expansion_detailed(query, expansion)
  306. return max(0.0, min(1.0, result["total"] / result["max_possible"]))
  307. def extract_query_from_prompt(prompt: str) -> str:
  308. """Extract the query string from a chat-formatted prompt."""
  309. if "Expand this search query:" in prompt:
  310. query = prompt.split("Expand this search query:")[-1].strip()
  311. if "<|im_end|>" in query:
  312. query = query.split("<|im_end|>")[0].strip()
  313. return query
  314. return prompt.strip()
  315. # =============================================================================
  316. # TRL-compatible reward class
  317. # =============================================================================
  318. class QMDRewardFunction:
  319. """Reward function compatible with TRL's GRPOTrainer."""
  320. __name__ = "qmd_scoring_reward"
  321. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  322. rewards = []
  323. for i, completion in enumerate(completions):
  324. query = ""
  325. if prompts and i < len(prompts):
  326. query = extract_query_from_prompt(prompts[i])
  327. rewards.append(score_expansion(query, completion))
  328. return rewards
  329. # =============================================================================
  330. # CLI: run standalone to test the reward function
  331. # =============================================================================
  332. if __name__ == "__main__":
  333. print("QMD Reward Function Self-Test")
  334. print("=" * 60)
  335. tests = [
  336. ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."),
  337. ("auth", "auth is important for security"),
  338. ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"),
  339. ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"),
  340. ("how to use React hooks", "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"),
  341. ("auth", "<think>Let me think...</think>\nlex: auth"),
  342. ("auth", "lex: auth\nThis is some explanation\nvec: more"),
  343. ]
  344. for query, expansion in tests:
  345. score = score_expansion(query, expansion)
  346. detail = score_expansion_detailed(query, expansion)
  347. print(f"\n Query: '{query}'")
  348. print(f" Score: {score:.2f} ({detail['rating']})")
  349. if detail["deductions"]:
  350. print(f" Issues: {', '.join(detail['deductions'][:3])}")