reward.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = []
  4. # ///
  5. """
QMD Query Expansion Reward Function

Single source of truth for scoring query expansions. Used by:
- GRPO training (as the RL reward signal)
- Evaluation scripts (for scoring model outputs)

Scores expansions on five dimensions, plus a bonus for not using thinking mode:
  Format (30)    - Has lex/vec lines, no invalid lines
  Diversity (30) - Multiple types, diverse content, no echoes
  HyDE (20)      - Optional bonus for hypothetical document passage
  Quality (20)   - Lex shorter than vec, natural language, key terms
  Entity (20)    - Named entity preservation in lex/vec lines
  Think (20)     - Bonus when the output contains no <think>...</think> block

The maximum is 120 points without a hyde line and 140 points with one.

Returns 0.0-1.0 for RL rewards, or a detailed breakdown dict for evaluation.
  17. """
  18. import re
  19. from collections import Counter
  20. # =============================================================================
  21. # Constants
  22. # =============================================================================
  23. # "only:" mode patterns - when query ends with these, expect only that type
  24. ONLY_MODE_PATTERN = re.compile(r'\s+only:\s*(lex|vec|hyde)\s*$', re.IGNORECASE)
  25. STOPWORDS = frozenset({
  26. 'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
  27. 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
  28. })
  29. KEY_TERM_STOPWORDS = frozenset({
  30. 'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  31. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
  32. 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
  33. })
  34. GENERIC_LEX_PHRASES = frozenset({
  35. 'find information about', 'search for', 'look up', 'get information',
  36. 'learn about', 'information on', 'details about', 'find out about',
  37. 'what is', 'how to', 'guide to', 'help with',
  38. })
  39. # Chat template tokens that indicate a broken output
  40. CHAT_TEMPLATE_TOKENS = frozenset({
  41. '<|im_start|>', '<|im_end|>', '<|endoftext|>',
  42. '\nassistant\n', '\nuser\n',
  43. })
  44. # =============================================================================
  45. # Parsing
  46. # =============================================================================
  47. def parse_expansion(text: str) -> dict:
  48. """Parse a multi-line expansion into {lex, vec, hyde, invalid} lists."""
  49. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  50. for line in text.strip().split("\n"):
  51. line = line.strip()
  52. if not line:
  53. continue
  54. if line.startswith("lex:"):
  55. result["lex"].append(line[4:].strip())
  56. elif line.startswith("vec:"):
  57. result["vec"].append(line[4:].strip())
  58. elif line.startswith("hyde:"):
  59. result["hyde"].append(line[5:].strip())
  60. else:
  61. result["invalid"].append(line)
  62. return result
  63. def detect_only_mode(query: str) -> tuple[str | None, str]:
  64. """Detect if query ends with 'only: lex/vec/hyde'.
  65. Returns (only_type, base_query) where only_type is None for normal queries.
  66. """
  67. match = ONLY_MODE_PATTERN.search(query)
  68. if match:
  69. only_type = match.group(1).lower()
  70. base_query = query[:match.start()].strip()
  71. return only_type, base_query
  72. return None, query
  73. def clean_model_output(text: str) -> tuple[str, bool]:
  74. """Strip chat template artifacts from model output.
  75. Returns (cleaned_text, used_thinking) where used_thinking is True
  76. if the model emitted <think>...</think> blocks.
  77. """
  78. text = text.replace('<|im_end|>', '').strip()
  79. used_thinking = '<think>' in text and '</think>' in text
  80. if used_thinking:
  81. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
  82. return text, used_thinking
  83. # =============================================================================
  84. # Helpers
  85. # =============================================================================
  86. def extract_named_entities(query: str) -> set:
  87. """Extract named entities using heuristics.
  88. Detects: ALL-CAPS acronyms (TDS, API), capitalized proper nouns (React),
  89. technical terms with special chars (node.js, C++), CamelCase (JavaScript),
  90. and compound names (TDS motorsports -> both words).
  91. """
  92. entities = set()
  93. words = query.split()
  94. prev_was_entity = False
  95. for i, word in enumerate(words):
  96. clean = word.strip('.,!?:;()[]"\'')
  97. if not clean:
  98. prev_was_entity = False
  99. continue
  100. is_entity = False
  101. if clean.isupper() and len(clean) >= 2:
  102. entities.add(clean.lower())
  103. is_entity = True
  104. elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
  105. entities.add(clean.lower())
  106. is_entity = True
  107. elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
  108. entities.add(clean.lower())
  109. is_entity = True
  110. elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
  111. entities.add(clean.lower())
  112. is_entity = True
  113. elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
  114. entities.add(clean.lower())
  115. is_entity = True
  116. prev_was_entity = is_entity
  117. return entities
  118. def get_key_terms(query: str) -> set:
  119. """Get non-stopword terms from a query."""
  120. return set(query.lower().split()) - KEY_TERM_STOPWORDS
  121. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  122. """Does the lex line contain at least one key term from the query?"""
  123. key_terms = get_key_terms(query)
  124. if not key_terms:
  125. return True
  126. return bool(key_terms & set(lex_line.lower().split()))
  127. def lex_preserves_entities(line: str, entities: set) -> bool:
  128. """Does the line contain at least one named entity?"""
  129. if not entities:
  130. return True
  131. lower = line.lower()
  132. return any(e in lower for e in entities)
  133. def lex_is_generic(lex_line: str) -> bool:
  134. """Is this lex line a useless generic filler phrase?"""
  135. lower = lex_line.lower().strip()
  136. for phrase in GENERIC_LEX_PHRASES:
  137. if phrase in lower or lower.startswith(phrase.split()[0]):
  138. remaining = lower
  139. for word in phrase.split():
  140. remaining = remaining.replace(word, '', 1).strip()
  141. if len(remaining) < 3:
  142. return True
  143. return False
  144. def word_set_distance(a: str, b: str) -> int:
  145. """Symmetric difference of word sets (how many words are unique to one)."""
  146. return len(set(a.lower().split()) ^ set(b.lower().split()))
  147. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  148. """Are two strings sufficiently different?"""
  149. a, b = a.lower().strip(), b.lower().strip()
  150. if a == b or a in b or b in a:
  151. return False
  152. return word_set_distance(a, b) >= min_distance
  153. def echoes_query(expansion: str, query: str) -> bool:
  154. """Is this expansion just echoing the original query?"""
  155. exp, q = expansion.lower().strip(), query.lower().strip()
  156. return exp == q or (q in exp and len(exp) < len(q) + 10)
  157. def word_repetition_penalty(text: str) -> int:
  158. """Penalty for words repeated 3+ times (excluding stopwords)."""
  159. counts = Counter(re.findall(r'\b\w+\b', text.lower()))
  160. return sum((c - 2) * 2 for w, c in counts.items()
  161. if c >= 3 and w not in STOPWORDS and len(w) > 2)
  162. # =============================================================================
  163. # Scoring
  164. # =============================================================================
  165. def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool, only_type: str) -> dict:
  166. """Score an 'only:' mode expansion. Expects ONLY the requested type."""
  167. parsed = parse_expansion(text)
  168. deductions = []
  169. # Expected type must be present
  170. expected_items = parsed.get(only_type, [])
  171. if not expected_items:
  172. return {
  173. "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
  174. "think_bonus": 0, "total": 0, "max_possible": 100,
  175. "percentage": 0.0, "rating": "Failed",
  176. "deductions": [f"missing expected {only_type}: output"],
  177. "parsed": parsed,
  178. "entities_detected": [],
  179. "only_mode": only_type,
  180. }
  181. # Penalize presence of OTHER types
  182. other_types = {"lex", "vec", "hyde"} - {only_type}
  183. unwanted_count = sum(len(parsed.get(t, [])) for t in other_types)
  184. if unwanted_count > 0:
  185. deductions.append(f"contains unwanted types (expected only {only_type})")
  186. # --- Format (0-30) ---
  187. format_score = 30 if unwanted_count == 0 else max(0, 30 - unwanted_count * 10)
  188. # --- Diversity (0-30) ---
  189. diversity_score = 0
  190. if len(expected_items) >= 2:
  191. diversity_score += 15
  192. # Check for diversity among items
  193. div_score = 15
  194. for i, a in enumerate(expected_items):
  195. for b in expected_items[i+1:]:
  196. if not is_diverse(a, b, 2):
  197. div_score -= 5
  198. deductions.append(f"{only_type} duplicate: {a[:20]}...")
  199. diversity_score += max(0, div_score)
  200. elif len(expected_items) == 1:
  201. diversity_score = 15 # One item is fine for single-type output
  202. # Check for echoes
  203. for exp in expected_items:
  204. if echoes_query(exp, base_query):
  205. diversity_score -= 5
  206. deductions.append(f"echoes query: {exp[:20]}...")
  207. diversity_score = max(0, diversity_score)
  208. # --- Type-specific quality (0-20) ---
  209. quality_score = 10 # base
  210. entities = extract_named_entities(base_query)
  211. if only_type == "lex":
  212. # Lex should be short keyword phrases with key terms
  213. with_terms = sum(1 for l in expected_items if lex_preserves_key_terms(l, base_query))
  214. if with_terms == len(expected_items):
  215. quality_score += 5
  216. # Check for generic phrases
  217. generic = sum(1 for l in expected_items if lex_is_generic(l))
  218. if generic == 0:
  219. quality_score += 5
  220. else:
  221. deductions.append(f"{generic} generic lex phrases")
  222. elif only_type == "vec":
  223. # Vec should be natural language sentences
  224. natural = sum(1 for v in expected_items if " " in v and len(v) > 15)
  225. if natural == len(expected_items):
  226. quality_score += 10
  227. else:
  228. quality_score += 5
  229. deductions.append("vec not all natural language")
  230. elif only_type == "hyde":
  231. # Hyde should be a document snippet (50-200 chars)
  232. hyde_text = expected_items[0]
  233. hyde_len = len(hyde_text)
  234. if 50 <= hyde_len <= 200:
  235. quality_score += 10
  236. elif 30 <= hyde_len <= 300:
  237. quality_score += 5
  238. deductions.append(f"hyde length {hyde_len} (ideal: 50-200)")
  239. else:
  240. deductions.append(f"hyde length {hyde_len} out of range")
  241. # --- Entity preservation (0-20) ---
  242. entity_score = 10 # base
  243. if entities:
  244. with_entities = sum(1 for item in expected_items if lex_preserves_entities(item, entities))
  245. if with_entities == len(expected_items):
  246. entity_score += 10
  247. elif with_entities > 0:
  248. entity_score += 5
  249. else:
  250. entity_score = 0
  251. deductions.append(f"missing entities: {entities}")
  252. # --- Think bonus (0-20) ---
  253. think_bonus = 0 if used_thinking else 20
  254. # --- Total ---
  255. total = format_score + diversity_score + quality_score + entity_score + think_bonus
  256. max_possible = 120
  257. percentage = max(0.0, min(100.0, total / max_possible * 100))
  258. if percentage >= 80:
  259. rating = "Excellent"
  260. elif percentage >= 60:
  261. rating = "Good"
  262. elif percentage >= 40:
  263. rating = "Acceptable"
  264. elif percentage >= 20:
  265. rating = "Poor"
  266. else:
  267. rating = "Failed"
  268. return {
  269. "format": format_score,
  270. "diversity": diversity_score,
  271. "hyde": 0, # not used in only mode (quality covers it)
  272. "quality": quality_score,
  273. "entity": entity_score,
  274. "think_bonus": think_bonus,
  275. "total": total,
  276. "max_possible": max_possible,
  277. "percentage": round(percentage, 1),
  278. "rating": rating,
  279. "deductions": deductions,
  280. "parsed": parsed,
  281. "entities_detected": list(entities) if entities else [],
  282. "only_mode": only_type,
  283. }
  284. def score_expansion_detailed(query: str, expansion: str) -> dict:
  285. """Score an expansion with full breakdown. Returns dict with all dimensions."""
  286. text, used_thinking = clean_model_output(expansion.strip())
  287. deductions = []
  288. # Detect "only:" mode
  289. only_type, base_query = detect_only_mode(query)
  290. def _fail(reason):
  291. return {
  292. "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
  293. "think_bonus": 0, "total": 0, "max_possible": 100,
  294. "percentage": 0.0, "rating": "Failed",
  295. "deductions": [reason],
  296. "parsed": parse_expansion(expansion),
  297. "entities_detected": [],
  298. "only_mode": only_type,
  299. }
  300. # Hard fail: remaining chat template tokens
  301. if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
  302. return _fail("CHAT TEMPLATE LEAKAGE")
  303. # Hard fail: every non-empty line must have a valid prefix
  304. for line in text.split("\n"):
  305. line = line.strip()
  306. if line and not line.startswith(("lex:", "vec:", "hyde:")):
  307. return _fail(f"INVALID LINE: {line[:50]}")
  308. # --- Handle "only:" mode separately ---
  309. if only_type:
  310. return _score_only_mode(query, base_query, text, used_thinking, only_type)
  311. parsed = parse_expansion(text)
  312. # --- Format (0-30) ---
  313. format_score = 10 # no invalid lines (guaranteed by hard fail)
  314. if parsed["lex"]:
  315. format_score += 10
  316. else:
  317. deductions.append("missing lex:")
  318. if parsed["vec"]:
  319. format_score += 10
  320. else:
  321. deductions.append("missing vec:")
  322. # --- Diversity (0-30) ---
  323. diversity_score = 0
  324. types_present = sum(1 for t in ("lex", "vec") if parsed[t])
  325. if types_present >= 2:
  326. diversity_score += 10
  327. else:
  328. deductions.append("only one type")
  329. if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
  330. diversity_score += 5
  331. lex_div = 5
  332. for i, a in enumerate(parsed["lex"]):
  333. for b in parsed["lex"][i+1:]:
  334. if not is_diverse(a, b, 2):
  335. lex_div -= 2
  336. deductions.append(f"lex duplicate: {a[:20]}...")
  337. diversity_score += max(0, lex_div)
  338. vec_div = 5
  339. for i, a in enumerate(parsed["vec"]):
  340. for b in parsed["vec"][i+1:]:
  341. if not is_diverse(a, b, 3):
  342. vec_div -= 2
  343. deductions.append(f"vec duplicate: {a[:20]}...")
  344. diversity_score += max(0, vec_div)
  345. echo = 5
  346. for exp in parsed["lex"] + parsed["vec"]:
  347. if echoes_query(exp, query):
  348. echo -= 3
  349. deductions.append(f"echoes query: {exp[:20]}...")
  350. diversity_score += max(0, echo)
  351. # --- HyDE (0-20, optional bonus) ---
  352. hyde_score = 0
  353. if parsed["hyde"]:
  354. hyde_text = parsed["hyde"][0]
  355. hyde_score += 5
  356. hyde_len = len(hyde_text)
  357. if 50 <= hyde_len <= 200:
  358. hyde_score += 5
  359. elif hyde_len < 50:
  360. hyde_score += 2
  361. deductions.append(f"hyde too short ({hyde_len})")
  362. else:
  363. deductions.append(f"hyde too long ({hyde_len})")
  364. if "\n" not in hyde_text:
  365. hyde_score += 5
  366. hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
  367. # --- Quality (0-20) ---
  368. quality_score = 5 # base relevance
  369. if parsed["lex"] and parsed["vec"]:
  370. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  371. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  372. if avg_lex <= avg_vec:
  373. quality_score += 5
  374. else:
  375. deductions.append("lex longer than vec")
  376. if parsed["vec"]:
  377. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  378. quality_score += 5 if natural == len(parsed["vec"]) else 2
  379. if parsed["lex"]:
  380. with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  381. if with_terms == len(parsed["lex"]):
  382. quality_score += 5
  383. elif with_terms > 0:
  384. quality_score += 2
  385. else:
  386. deductions.append("lex missing key terms")
  387. # --- Entity Preservation (-45 to +20) ---
  388. entity_score = 0
  389. entities = extract_named_entities(query)
  390. if entities and parsed["lex"]:
  391. with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
  392. if with_entities == len(parsed["lex"]):
  393. entity_score += 15
  394. elif with_entities > 0:
  395. entity_score += 5
  396. else:
  397. entity_score -= 30
  398. deductions.append(f"lex missing entities: {entities}")
  399. generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
  400. if generic_count:
  401. entity_score -= generic_count * 15
  402. deductions.append(f"{generic_count} generic lex phrases")
  403. if parsed["vec"]:
  404. vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
  405. if vec_with > 0:
  406. entity_score += 5
  407. elif not entities:
  408. entity_score = 10
  409. # --- Think bonus (0-20): reward NOT using thinking mode ---
  410. think_bonus = 0 if used_thinking else 20
  411. # --- Total ---
  412. total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
  413. max_possible = 140 if parsed["hyde"] else 120
  414. percentage = max(0.0, min(100.0, total / max_possible * 100))
  415. if percentage >= 80:
  416. rating = "Excellent"
  417. elif percentage >= 60:
  418. rating = "Good"
  419. elif percentage >= 40:
  420. rating = "Acceptable"
  421. elif percentage >= 20:
  422. rating = "Poor"
  423. else:
  424. rating = "Failed"
  425. return {
  426. "format": format_score,
  427. "diversity": diversity_score,
  428. "hyde": hyde_score,
  429. "quality": quality_score,
  430. "entity": max(0, entity_score),
  431. "think_bonus": think_bonus,
  432. "total": max(0, total),
  433. "max_possible": max_possible,
  434. "percentage": round(percentage, 1),
  435. "rating": rating,
  436. "deductions": deductions,
  437. "parsed": parsed,
  438. "entities_detected": list(entities) if entities else [],
  439. "only_mode": None,
  440. }
  441. def score_expansion(query: str, expansion: str) -> float:
  442. """Score expansion as a float in [0.0, 1.0] for use as RL reward."""
  443. result = score_expansion_detailed(query, expansion)
  444. return max(0.0, min(1.0, result["total"] / result["max_possible"]))
  445. def extract_query_from_prompt(prompt: str) -> str:
  446. """Extract the query string from a chat-formatted prompt."""
  447. if "Expand this search query:" in prompt:
  448. query = prompt.split("Expand this search query:")[-1].strip()
  449. if "<|im_end|>" in query:
  450. query = query.split("<|im_end|>")[0].strip()
  451. return query
  452. return prompt.strip()
  453. # =============================================================================
  454. # TRL-compatible reward class
  455. # =============================================================================
  456. class QMDRewardFunction:
  457. """Reward function compatible with TRL's GRPOTrainer."""
  458. __name__ = "qmd_scoring_reward"
  459. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  460. rewards = []
  461. for i, completion in enumerate(completions):
  462. query = ""
  463. if prompts and i < len(prompts):
  464. query = extract_query_from_prompt(prompts[i])
  465. rewards.append(score_expansion(query, completion))
  466. return rewards
  467. # =============================================================================
  468. # CLI: run standalone to test the reward function
  469. # =============================================================================
  470. if __name__ == "__main__":
  471. print("QMD Reward Function Self-Test")
  472. print("=" * 60)
  473. tests = [
  474. ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."),
  475. ("auth", "auth is important for security"),
  476. ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"),
  477. ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"),
  478. ("how to use React hooks", "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"),
  479. ("auth", "<think>Let me think...</think>\nlex: auth"),
  480. ("auth", "lex: auth\nThis is some explanation\nvec: more"),
  481. # "only:" mode tests
  482. ("auth only: lex", "lex: auth setup\nlex: authentication config\nlex: login credentials"),
  483. ("auth only: lex", "lex: auth setup\nvec: how to configure authentication"), # should fail - has vec
  484. ("React hooks only: vec", "vec: how to use React hooks in functional components\nvec: useState and useEffect patterns in React"),
  485. ("PostgreSQL indexing only: hyde", "hyde: PostgreSQL uses B-tree indexes by default. Create indexes with CREATE INDEX idx_name ON table(column). EXPLAIN ANALYZE shows whether queries use indexes efficiently."),
  486. ]
  487. for query, expansion in tests:
  488. score = score_expansion(query, expansion)
  489. detail = score_expansion_detailed(query, expansion)
  490. only_mode = detail.get("only_mode")
  491. mode_str = f" [only:{only_mode}]" if only_mode else ""
  492. print(f"\n Query: '{query}'{mode_str}")
  493. print(f" Score: {score:.2f} ({detail['rating']})")
  494. if detail["deductions"]:
  495. print(f" Issues: {', '.join(detail['deductions'][:3])}")