reward.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = []
  4. # ///
  5. """
  6. QMD Query Expansion Reward Function
  7. Single source of truth for scoring query expansions. Used by:
  8. - GRPO training (as the RL reward signal)
  9. - Evaluation scripts (for scoring model outputs)
  10. Scores expansions on five dimensions:
  11. Format (30) - Has lex/vec lines, no invalid lines
  12. Diversity (30) - Multiple types, diverse content, no echoes
  13. HyDE (20) - Optional bonus for hypothetical document passage
  14. Quality (20) - Lex shorter than vec, natural language, key terms
  15. Entity (20) - Named entity preservation in lex/vec lines
  16. Returns 0.0-1.0 for RL rewards, or a detailed breakdown dict for evaluation.
  17. """
  18. import re
  19. from collections import Counter
  20. # =============================================================================
  21. # Constants
  22. # =============================================================================
  23. # "only:" mode patterns - when query ends with these, expect only that type
  24. # Format: "query /only:lex" (slash prefix, no space after colon)
# Matches a trailing "/only:<type>" directive (case-insensitive). At least one
# whitespace character must precede the slash; group(1) is lex, vec or hyde.
ONLY_MODE_PATTERN = re.compile(r'\s+/only:(lex|vec|hyde)\s*$', re.IGNORECASE)

# Function words ignored by word_repetition_penalty().
STOPWORDS = frozenset({
    'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
    'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})

# Words removed from a query by get_key_terms(); also words that are never
# treated as capitalized proper nouns or compound-name tails by
# extract_named_entities().
KEY_TERM_STOPWORDS = frozenset({
    'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
    'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
    'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
    'about', 'from', 'into', 'between', 'through', 'during', 'after',
    'before', 'like', 'than', 'then', 'that', 'this', 'their', 'its',
    'was', 'were', 'has', 'had', 'been', 'being', 'have', 'not', 'but',
    'just', 'also', 'very', 'so', 'if', 'at', 'by', 'up', 'out', 'all',
    'some', 'any', 'no', 'each', 'every', 'both', 'few', 'more', 'most',
    'other', 'only', 'same', 'such', 'here', 'there', 'asked', 'said',
    'notes', 'meeting', 'email', 'discussion', 'conversation', 'call',
})

# Words that commonly start queries but aren't named entities.
# Used for position-0 entity detection to avoid false positives.
QUERY_VERB_STOPWORDS = frozenset({
    'configure', 'setup', 'install', 'build', 'create', 'make', 'run',
    'start', 'stop', 'check', 'test', 'debug', 'fix', 'update', 'change',
    'add', 'remove', 'delete', 'use', 'using', 'need', 'want', 'should',
    'would', 'could', 'help', 'please', 'best', 'good', 'new', 'old',
    'latest', 'recent', 'setting', 'settings', 'compare', 'comparing',
    'implement', 'implementing', 'deploy', 'deploying', 'migrate',
    'migrating', 'optimize', 'optimizing', 'understand', 'understanding',
    'explain', 'list', 'describe', 'define', 'convert', 'connecting',
    'performance', 'overview', 'introduction', 'tutorial', 'example',
    'difference', 'between', 'about', 'review', 'resolve', 'resolving',
    'troubleshoot', 'troubleshooting', 'monitor', 'monitoring', 'manage',
    'managing', 'enable', 'disable', 'set', 'write', 'read', 'search',
    'possible', 'common', 'typical', 'recommended', 'alternative',
})

# Phrases checked by lex_is_generic() to spot lex lines that carry no
# search-specific content.
GENERIC_LEX_PHRASES = frozenset({
    'find information about', 'search for', 'look up', 'get information',
    'learn about', 'information on', 'details about', 'find out about',
    'what is', 'how to', 'guide to', 'help with',
})

# Words commonly injected as filler/noise into lex lines by template generators
# (e.g. "ancient overview rome timeline"). Penalized when absent from the query.
INTERIOR_FILLER_WORDS = frozenset({'overview', 'basics'})

# Chat template tokens that indicate a broken output.
# score_expansion_detailed() hard-fails if any of these remain after cleaning.
CHAT_TEMPLATE_TOKENS = frozenset({
    '<|im_start|>', '<|im_end|>', '<|endoftext|>',
    '\nassistant\n', '\nuser\n',
})
  72. # =============================================================================
  73. # Parsing
  74. # =============================================================================
  75. def parse_expansion(text: str) -> dict:
  76. """Parse a multi-line expansion into {lex, vec, hyde, invalid} lists."""
  77. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  78. for line in text.strip().split("\n"):
  79. line = line.strip()
  80. if not line:
  81. continue
  82. if line.startswith("lex:"):
  83. result["lex"].append(line[4:].strip())
  84. elif line.startswith("vec:"):
  85. result["vec"].append(line[4:].strip())
  86. elif line.startswith("hyde:"):
  87. result["hyde"].append(line[5:].strip())
  88. else:
  89. result["invalid"].append(line)
  90. return result
  91. def detect_only_mode(query: str) -> tuple[str | None, str]:
  92. """Detect if query ends with 'only: lex/vec/hyde'.
  93. Returns (only_type, base_query) where only_type is None for normal queries.
  94. """
  95. match = ONLY_MODE_PATTERN.search(query)
  96. if match:
  97. only_type = match.group(1).lower()
  98. base_query = query[:match.start()].strip()
  99. return only_type, base_query
  100. return None, query
  101. def clean_model_output(text: str) -> tuple[str, bool]:
  102. """Strip chat template artifacts from model output.
  103. Returns (cleaned_text, used_thinking) where used_thinking is True
  104. if the model emitted <think>...</think> blocks.
  105. """
  106. text = text.replace('<|im_end|>', '').strip()
  107. used_thinking = '<think>' in text and '</think>' in text
  108. if used_thinking:
  109. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
  110. return text, used_thinking
  111. # =============================================================================
  112. # Helpers
  113. # =============================================================================
  114. def extract_named_entities(query: str) -> set:
  115. """Extract named entities using heuristics.
  116. Detects: ALL-CAPS acronyms (TDS, API), capitalized proper nouns (React, Bob),
  117. technical terms with special chars (node.js, C++), CamelCase (JavaScript),
  118. and compound names (TDS motorsports -> both words).
  119. Position-0 words are also detected as entities if they are capitalized and
  120. not common query-starting verbs (e.g. "Bob asked about deploy" -> "bob").
  121. Compound chaining extends one level from a directly-detected entity:
  122. "TDS motorsports" -> {tds, motorsports}; "TDS motorsports team" -> {tds, motorsports}.
  123. """
  124. entities = set()
  125. words = query.split()
  126. prev_was_base_entity = False
  127. for i, word in enumerate(words):
  128. clean = word.strip('.,!?:;()[]"\'')
  129. if not clean:
  130. prev_was_base_entity = False
  131. continue
  132. is_base_entity = False
  133. # ALL-CAPS acronyms: TDS, API, GPU, AWS
  134. if clean.isupper() and len(clean) >= 2:
  135. entities.add(clean.lower())
  136. is_base_entity = True
  137. # Capitalized proper nouns (any position, including first word)
  138. elif clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
  139. if i > 0:
  140. # Non-first words: always treat as entity
  141. entities.add(clean.lower())
  142. is_base_entity = True
  143. elif clean.lower() not in QUERY_VERB_STOPWORDS:
  144. # First word: also entity if not a common query verb
  145. entities.add(clean.lower())
  146. is_base_entity = True
  147. # Technical terms with special chars: node.js, C++, .NET
  148. elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
  149. entities.add(clean.lower())
  150. is_base_entity = True
  151. # CamelCase: JavaScript, TypeScript
  152. elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
  153. entities.add(clean.lower())
  154. is_base_entity = True
  155. # Compound names: word following a BASE entity only (one level deep).
  156. elif prev_was_base_entity and clean.lower() not in KEY_TERM_STOPWORDS:
  157. entities.add(clean.lower())
  158. prev_was_base_entity = is_base_entity
  159. return entities
  160. def get_key_terms(query: str) -> set:
  161. """Get non-stopword terms from a query."""
  162. return set(query.lower().split()) - KEY_TERM_STOPWORDS
  163. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  164. """Does the lex line contain at least one key term from the query?"""
  165. key_terms = get_key_terms(query)
  166. if not key_terms:
  167. return True
  168. return bool(key_terms & set(lex_line.lower().split()))
  169. def lex_preserves_entities(line: str, entities: set) -> bool:
  170. """Does the line contain at least one named entity?"""
  171. if not entities:
  172. return True
  173. lower = line.lower()
  174. return any(e in lower for e in entities)
  175. def lex_has_filler(lex_line: str, query: str) -> bool:
  176. """Does the lex line contain an INTERIOR_FILLER_WORDS word absent from the query?"""
  177. query_words = set(query.lower().split())
  178. return any(w in INTERIOR_FILLER_WORDS and w not in query_words
  179. for w in lex_line.lower().split())
  180. def lex_is_generic(lex_line: str) -> bool:
  181. """Is this lex line a useless generic filler phrase?"""
  182. lower = lex_line.lower().strip()
  183. for phrase in GENERIC_LEX_PHRASES:
  184. if phrase in lower or lower.startswith(phrase.split()[0]):
  185. remaining = lower
  186. for word in phrase.split():
  187. remaining = remaining.replace(word, '', 1).strip()
  188. if len(remaining) < 3:
  189. return True
  190. return False
  191. def word_set_distance(a: str, b: str) -> int:
  192. """Symmetric difference of word sets (how many words are unique to one)."""
  193. return len(set(a.lower().split()) ^ set(b.lower().split()))
  194. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  195. """Are two strings sufficiently different?"""
  196. a, b = a.lower().strip(), b.lower().strip()
  197. if a == b or a in b or b in a:
  198. return False
  199. return word_set_distance(a, b) >= min_distance
  200. def echoes_query(expansion: str, query: str) -> bool:
  201. """Is this expansion just echoing the original query?"""
  202. exp, q = expansion.lower().strip(), query.lower().strip()
  203. return exp == q or (q in exp and len(exp) < len(q) + 10)
  204. def word_repetition_penalty(text: str) -> int:
  205. """Penalty for words repeated 3+ times (excluding stopwords)."""
  206. counts = Counter(re.findall(r'\b\w+\b', text.lower()))
  207. return sum((c - 2) * 2 for w, c in counts.items()
  208. if c >= 3 and w not in STOPWORDS and len(w) > 2)
  209. # =============================================================================
  210. # Scoring
  211. # =============================================================================
  212. def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool, only_type: str) -> dict:
  213. """Score an 'only:' mode expansion. Expects ONLY the requested type."""
  214. parsed = parse_expansion(text)
  215. deductions = []
  216. # Expected type must be present
  217. expected_items = parsed.get(only_type, [])
  218. if not expected_items:
  219. return {
  220. "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
  221. "think_bonus": 0, "total": 0, "max_possible": 100,
  222. "percentage": 0.0, "rating": "Failed",
  223. "deductions": [f"missing expected {only_type}: output"],
  224. "parsed": parsed,
  225. "entities_detected": [],
  226. "only_mode": only_type,
  227. }
  228. # Penalize presence of OTHER types
  229. other_types = {"lex", "vec", "hyde"} - {only_type}
  230. unwanted_count = sum(len(parsed.get(t, [])) for t in other_types)
  231. if unwanted_count > 0:
  232. deductions.append(f"contains unwanted types (expected only {only_type})")
  233. # --- Format (0-30) ---
  234. format_score = 30 if unwanted_count == 0 else max(0, 30 - unwanted_count * 10)
  235. # --- Diversity (0-30) ---
  236. diversity_score = 0
  237. div_threshold = 3 if len(base_query.split()) >= 5 else 2
  238. if len(expected_items) >= 2:
  239. diversity_score += 15
  240. # Check for diversity among items
  241. div_score = 15
  242. for i, a in enumerate(expected_items):
  243. for b in expected_items[i+1:]:
  244. if not is_diverse(a, b, div_threshold):
  245. div_score -= 5
  246. deductions.append(f"{only_type} duplicate: {a[:20]}...")
  247. diversity_score += max(0, div_score)
  248. elif len(expected_items) == 1:
  249. diversity_score = 15 # One item is fine for single-type output
  250. # Check for echoes
  251. for exp in expected_items:
  252. if echoes_query(exp, base_query):
  253. diversity_score -= 5
  254. deductions.append(f"echoes query: {exp[:20]}...")
  255. diversity_score = max(0, diversity_score)
  256. # --- Type-specific quality (0-20) ---
  257. quality_score = 10 # base
  258. entities = extract_named_entities(base_query)
  259. if only_type == "lex":
  260. # Lex should be short keyword phrases with key terms
  261. with_terms = sum(1 for l in expected_items if lex_preserves_key_terms(l, base_query))
  262. if with_terms == len(expected_items):
  263. quality_score += 5
  264. # Check for generic phrases
  265. generic = sum(1 for l in expected_items if lex_is_generic(l))
  266. if generic == 0:
  267. quality_score += 5
  268. else:
  269. deductions.append(f"{generic} generic lex phrases")
  270. # Penalty: lex lines containing filler words absent from the query
  271. filler_count = sum(1 for l in expected_items if lex_has_filler(l, base_query))
  272. if filler_count > 0:
  273. quality_score -= filler_count * 3
  274. deductions.append(f"{filler_count} lex line(s) with filler words")
  275. elif only_type == "vec":
  276. # Vec should be natural language sentences
  277. natural = sum(1 for v in expected_items if " " in v and len(v) > 15)
  278. if natural == len(expected_items):
  279. quality_score += 10
  280. else:
  281. quality_score += 5
  282. deductions.append("vec not all natural language")
  283. elif only_type == "hyde":
  284. # Hyde should be a document snippet (50-200 chars)
  285. hyde_text = expected_items[0]
  286. hyde_len = len(hyde_text)
  287. if 50 <= hyde_len <= 200:
  288. quality_score += 10
  289. elif 30 <= hyde_len <= 300:
  290. quality_score += 5
  291. deductions.append(f"hyde length {hyde_len} (ideal: 50-200)")
  292. else:
  293. deductions.append(f"hyde length {hyde_len} out of range")
  294. # --- Entity preservation (0-20) ---
  295. entity_score = 10 # base
  296. if entities:
  297. with_entities = sum(1 for item in expected_items if lex_preserves_entities(item, entities))
  298. if with_entities == len(expected_items):
  299. entity_score += 10
  300. elif with_entities > 0:
  301. entity_score += 5
  302. else:
  303. entity_score = 0
  304. deductions.append(f"missing entities: {entities}")
  305. # --- Think bonus (0-20) ---
  306. think_bonus = 0 if used_thinking else 20
  307. # --- Total ---
  308. total = format_score + diversity_score + quality_score + entity_score + think_bonus
  309. max_possible = 120
  310. percentage = max(0.0, min(100.0, total / max_possible * 100))
  311. if percentage >= 80:
  312. rating = "Excellent"
  313. elif percentage >= 60:
  314. rating = "Good"
  315. elif percentage >= 40:
  316. rating = "Acceptable"
  317. elif percentage >= 20:
  318. rating = "Poor"
  319. else:
  320. rating = "Failed"
  321. return {
  322. "format": format_score,
  323. "diversity": diversity_score,
  324. "hyde": 0, # not used in only mode (quality covers it)
  325. "quality": quality_score,
  326. "entity": entity_score,
  327. "think_bonus": think_bonus,
  328. "total": total,
  329. "max_possible": max_possible,
  330. "percentage": round(percentage, 1),
  331. "rating": rating,
  332. "deductions": deductions,
  333. "parsed": parsed,
  334. "entities_detected": list(entities) if entities else [],
  335. "only_mode": only_type,
  336. }
def score_expansion_detailed(query: str, expansion: str) -> dict:
    """Score an expansion with full breakdown. Returns dict with all dimensions.

    Args:
        query: The search query, optionally ending in an "/only:<type>"
            directive (handled by _score_only_mode).
        expansion: Raw model output; chat-template residue and thinking
            blocks are removed by clean_model_output() before scoring.

    Returns:
        A dict with the keys format, diversity, hyde, quality, entity,
        think_bonus, total, max_possible, percentage, rating, deductions,
        parsed, entities_detected and only_mode. Output that still contains
        chat-template tokens, or any line without a lex:/vec:/hyde: prefix,
        yields an all-zero "Failed" result.
    """
    text, used_thinking = clean_model_output(expansion.strip())
    deductions = []

    # Detect "only:" mode
    only_type, base_query = detect_only_mode(query)

    def _fail(reason: str) -> dict:
        # All-zero result carrying a single deduction that names the hard failure.
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed",
            "deductions": [reason],
            "parsed": parse_expansion(expansion),
            "entities_detected": [],
            "only_mode": only_type,
        }

    # Hard fail: remaining chat template tokens
    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return _fail("CHAT TEMPLATE LEAKAGE")

    # Hard fail: every non-empty line must have a valid prefix
    for line in text.split("\n"):
        line = line.strip()
        if line and not line.startswith(("lex:", "vec:", "hyde:")):
            return _fail(f"INVALID LINE: {line[:50]}")

    # --- Handle "only:" mode separately ---
    if only_type:
        return _score_only_mode(query, base_query, text, used_thinking, only_type)

    parsed = parse_expansion(text)

    # --- Format (0-30) ---
    format_score = 10  # no invalid lines (guaranteed by hard fail)
    if parsed["lex"]:
        format_score += 10
    else:
        deductions.append("missing lex:")
    if parsed["vec"]:
        format_score += 10
    else:
        deductions.append("missing vec:")

    # --- Diversity (0-30) ---
    diversity_score = 0
    types_present = sum(1 for t in ("lex", "vec") if parsed[t])
    if types_present >= 2:
        diversity_score += 10
    else:
        deductions.append("only one type")
    if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
        diversity_score += 5

    # Longer queries need a larger word-set distance to count as diverse.
    div_threshold = 3 if len(query.split()) >= 5 else 2

    # Up to 5 points for lex lines being pairwise distinct (-2 per near-duplicate pair).
    lex_div = 5
    for i, a in enumerate(parsed["lex"]):
        for b in parsed["lex"][i+1:]:
            if not is_diverse(a, b, div_threshold):
                lex_div -= 2
                deductions.append(f"lex duplicate: {a[:20]}...")
    diversity_score += max(0, lex_div)

    # Same check for vec lines.
    vec_div = 5
    for i, a in enumerate(parsed["vec"]):
        for b in parsed["vec"][i+1:]:
            if not is_diverse(a, b, div_threshold):
                vec_div -= 2
                deductions.append(f"vec duplicate: {a[:20]}...")
    diversity_score += max(0, vec_div)

    # Up to 5 points for not echoing the query.
    echo = 5
    lex_echo_count = 0
    for exp in parsed["lex"]:
        if echoes_query(exp, query):
            lex_echo_count += 1
            deductions.append(f"lex echoes query: {exp[:20]}...")
    # Harsh penalty for lex echoes - they're useless
    if lex_echo_count > 0:
        echo -= lex_echo_count * 10  # -10 per echo
    for exp in parsed["vec"]:
        if echoes_query(exp, query):
            echo -= 3  # vec echoes less severe (natural language overlap ok)
            deductions.append(f"vec echoes query: {exp[:20]}...")
    diversity_score += max(-10, echo)  # can go negative

    # --- HyDE (0-20, optional bonus) ---
    hyde_score = 0
    if parsed["hyde"]:
        # Only the first hyde line is scored.
        hyde_text = parsed["hyde"][0]
        hyde_score += 5
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            hyde_score += 5
        elif hyde_len < 50:
            hyde_score += 2
            deductions.append(f"hyde too short ({hyde_len})")
        else:
            deductions.append(f"hyde too long ({hyde_len})")
        # NOTE(review): parse_expansion splits on "\n", so a parsed hyde line
        # appears never to contain a newline and this bonus is always granted.
        if "\n" not in hyde_text:
            hyde_score += 5
        hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))

    # --- Extract entities (used by both quality and entity sections) ---
    entities = extract_named_entities(query)

    # --- Quality (0-20) ---
    # NOTE(review): the filler penalty and quoted-phrase bonus below are not
    # clamped, so the value can fall below 0 or reach 23.
    quality_score = 5  # base relevance
    if parsed["lex"] and parsed["vec"]:
        avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
        avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
        if avg_lex <= avg_vec:
            quality_score += 5
        else:
            deductions.append("lex longer than vec")
    if parsed["vec"]:
        # "Natural language" here means: contains a space and is longer than 15 chars.
        natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
        quality_score += 5 if natural == len(parsed["vec"]) else 2
    if parsed["lex"]:
        with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
        if with_terms == len(parsed["lex"]):
            quality_score += 5
        elif with_terms > 0:
            quality_score += 2
        else:
            deductions.append("lex missing key terms")
    # Penalty: lex lines containing filler words absent from the query
    if parsed["lex"]:
        filler_count = sum(1 for l in parsed["lex"] if lex_has_filler(l, query))
        if filler_count > 0:
            quality_score -= filler_count * 3
            deductions.append(f"{filler_count} lex line(s) with filler words")
    # Bonus: lex uses quoted phrases for multi-word queries (+3)
    if parsed["lex"] and len(query.split()) >= 2:
        lex_joined = " ".join(parsed["lex"])
        if '"' in lex_joined:
            quality_score += 3

    # --- Entity Preservation (-45 to +20) ---
    entity_score = 0
    if entities and parsed["lex"]:
        # Per-line check: do lex lines contain entities?
        with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
        if with_entities == len(parsed["lex"]):
            entity_score += 15
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score -= 30
            deductions.append(f"lex missing entities: {entities}")
        # Per-entity coverage: is each entity mentioned somewhere in lex+vec?
        all_output = " ".join(parsed["lex"] + parsed["vec"]).lower()
        missing_entities = {e for e in entities if e not in all_output}
        if missing_entities:
            penalty = len(missing_entities) * 20
            entity_score -= penalty
            deductions.append(f"entities dropped: {missing_entities}")
        generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
        if generic_count:
            entity_score -= generic_count * 15
            deductions.append(f"{generic_count} generic lex phrases")
        if parsed["vec"]:
            vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
            if vec_with > 0:
                entity_score += 5
    elif not entities:
        # Nothing to preserve: grant a neutral 10 points.
        entity_score = 10

    # --- Think bonus (0-20): reward NOT using thinking mode ---
    think_bonus = 0 if used_thinking else 20

    # --- Total ---
    # The raw (possibly negative) entity_score feeds the total; only the
    # reported "entity" and "total" values are floored at 0 below.
    total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
    max_possible = 140 if parsed["hyde"] else 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    # Hard cap: lex echoes are unacceptable - cap at 50%
    # (applies to percentage and rating only; "total" is left uncapped)
    if lex_echo_count > 0:
        percentage = min(percentage, 50.0)
        deductions.insert(0, f"CAPPED: {lex_echo_count} lex echo(es)")

    if percentage >= 80:
        rating = "Excellent"
    elif percentage >= 60:
        rating = "Good"
    elif percentage >= 40:
        rating = "Acceptable"
    elif percentage >= 20:
        rating = "Poor"
    else:
        rating = "Failed"

    return {
        "format": format_score,
        "diversity": diversity_score,
        "hyde": hyde_score,
        "quality": quality_score,
        "entity": max(0, entity_score),
        "think_bonus": think_bonus,
        "total": max(0, total),
        "max_possible": max_possible,
        "percentage": round(percentage, 1),
        "rating": rating,
        "deductions": deductions,
        "parsed": parsed,
        "entities_detected": list(entities) if entities else [],
        "only_mode": None,
    }
  527. def score_expansion(query: str, expansion: str) -> float:
  528. """Score expansion as a float in [0.0, 1.0] for use as RL reward."""
  529. result = score_expansion_detailed(query, expansion)
  530. return max(0.0, min(1.0, result["total"] / result["max_possible"]))
  531. def extract_query_from_prompt(prompt: str) -> str:
  532. """Extract the query string from a chat-formatted prompt."""
  533. if "Expand this search query:" in prompt:
  534. query = prompt.split("Expand this search query:")[-1].strip()
  535. if "<|im_end|>" in query:
  536. query = query.split("<|im_end|>")[0].strip()
  537. return query
  538. return prompt.strip()
  539. # =============================================================================
  540. # TRL-compatible reward class
  541. # =============================================================================
  542. class QMDRewardFunction:
  543. """Reward function compatible with TRL's GRPOTrainer."""
  544. __name__ = "qmd_scoring_reward"
  545. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  546. rewards = []
  547. for i, completion in enumerate(completions):
  548. query = ""
  549. if prompts and i < len(prompts):
  550. query = extract_query_from_prompt(prompts[i])
  551. rewards.append(score_expansion(query, completion))
  552. return rewards
  553. # =============================================================================
  554. # CLI: run standalone to test the reward function
  555. # =============================================================================
  556. if __name__ == "__main__":
  557. print("QMD Reward Function Self-Test")
  558. print("=" * 60)
  559. tests = [
  560. ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."),
  561. ("auth", "auth is important for security"),
  562. ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"),
  563. ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"),
  564. ("how to use React hooks", "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"),
  565. ("auth", "<think>Let me think...</think>\nlex: auth"),
  566. ("auth", "lex: auth\nThis is some explanation\nvec: more"),
  567. # Personal entity tests (issue #247: entity stripping)
  568. ("meeting with Bob about C++", 'lex: Bob "C++" meeting\nlex: Bob C++ discussion notes\nvec: meeting notes with Bob about C++ programming'),
  569. ("meeting with Bob about C++", "lex: c++ meetings\nvec: programming meeting notes"), # BAD: Bob is gone
  570. # Quoted phrases bonus
  571. ("python memory leak debugging", 'lex: "memory leak" python -java\nlex: tracemalloc profiler\nvec: how to find memory leaks in Python'),
  572. # "/only:" mode tests (slash prefix)
  573. ("auth /only:lex", "lex: auth setup\nlex: authentication config\nlex: login credentials"),
  574. ("auth /only:lex", "lex: auth setup\nvec: how to configure authentication"), # should fail - has vec
  575. ("React hooks /only:vec", "vec: how to use React hooks in functional components\nvec: useState and useEffect patterns in React"),
  576. ("PostgreSQL indexing /only:hyde", "hyde: PostgreSQL uses B-tree indexes by default. Create indexes with CREATE INDEX idx_name ON table(column). EXPLAIN ANALYZE shows whether queries use indexes efficiently."),
  577. ]
  578. for query, expansion in tests:
  579. score = score_expansion(query, expansion)
  580. detail = score_expansion_detailed(query, expansion)
  581. only_mode = detail.get("only_mode")
  582. mode_str = f" [only:{only_mode}]" if only_mode else ""
  583. print(f"\n Query: '{query}'{mode_str}")
  584. print(f" Score: {score:.2f} ({detail['rating']})")
  585. if detail["deductions"]:
  586. print(f" Issues: {', '.join(detail['deductions'][:3])}")