reward.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = []
  4. # ///
  5. """
  6. QMD Query Expansion Reward Function
  7. Single source of truth for scoring query expansions. Used by:
  8. - GRPO training (as the RL reward signal)
  9. - Evaluation scripts (for scoring model outputs)
Scores expansions on five dimensions plus a think bonus:
Format (30) - Has lex/vec lines, no invalid lines
Diversity (30) - Multiple types, diverse content, no echoes
HyDE (20) - Optional hypothetical document passage; when present, the maximum rises from 120 to 140
Quality (20) - Lex shorter than vec, natural language, key terms
Entity (20) - Named entity preservation in lex/vec lines (can go well below zero)
Think bonus (20) - Awarded only when the model did NOT emit <think>...</think> blocks
Returns 0.0-1.0 for RL rewards, or a detailed breakdown dict for evaluation.
  17. """
  18. import re
  19. from collections import Counter
  20. # =============================================================================
  21. # Constants
  22. # =============================================================================
  23. # "only:" mode patterns - when query ends with these, expect only that type
  24. # Format: "query /only:lex" (slash prefix, no space after colon)
  25. ONLY_MODE_PATTERN = re.compile(r'\s+/only:(lex|vec|hyde)\s*$', re.IGNORECASE)
  26. STOPWORDS = frozenset({
  27. 'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in',
  28. 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
  29. })
  30. KEY_TERM_STOPWORDS = frozenset({
  31. 'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
  32. 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
  33. 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
  34. 'about', 'from', 'into', 'between', 'through', 'during', 'after',
  35. 'before', 'like', 'than', 'then', 'that', 'this', 'their', 'its',
  36. 'was', 'were', 'has', 'had', 'been', 'being', 'have', 'not', 'but',
  37. 'just', 'also', 'very', 'so', 'if', 'at', 'by', 'up', 'out', 'all',
  38. 'some', 'any', 'no', 'each', 'every', 'both', 'few', 'more', 'most',
  39. 'other', 'only', 'same', 'such', 'here', 'there', 'asked', 'said',
  40. 'notes', 'meeting', 'email', 'discussion', 'conversation', 'call',
  41. })
  42. # Words that commonly start queries but aren't named entities.
  43. # Used for position-0 entity detection to avoid false positives.
  44. QUERY_VERB_STOPWORDS = frozenset({
  45. 'configure', 'setup', 'install', 'build', 'create', 'make', 'run',
  46. 'start', 'stop', 'check', 'test', 'debug', 'fix', 'update', 'change',
  47. 'add', 'remove', 'delete', 'use', 'using', 'need', 'want', 'should',
  48. 'would', 'could', 'help', 'please', 'best', 'good', 'new', 'old',
  49. 'latest', 'recent', 'setting', 'settings', 'compare', 'comparing',
  50. 'implement', 'implementing', 'deploy', 'deploying', 'migrate',
  51. 'migrating', 'optimize', 'optimizing', 'understand', 'understanding',
  52. 'explain', 'list', 'describe', 'define', 'convert', 'connecting',
  53. 'performance', 'overview', 'introduction', 'tutorial', 'example',
  54. 'difference', 'between', 'about', 'review', 'resolve', 'resolving',
  55. 'troubleshoot', 'troubleshooting', 'monitor', 'monitoring', 'manage',
  56. 'managing', 'enable', 'disable', 'set', 'write', 'read', 'search',
  57. 'possible', 'common', 'typical', 'recommended', 'alternative',
  58. })
  59. GENERIC_LEX_PHRASES = frozenset({
  60. 'find information about', 'search for', 'look up', 'get information',
  61. 'learn about', 'information on', 'details about', 'find out about',
  62. 'what is', 'how to', 'guide to', 'help with',
  63. })
  64. # Chat template tokens that indicate a broken output
  65. CHAT_TEMPLATE_TOKENS = frozenset({
  66. '<|im_start|>', '<|im_end|>', '<|endoftext|>',
  67. '\nassistant\n', '\nuser\n',
  68. })
  69. # =============================================================================
  70. # Parsing
  71. # =============================================================================
  72. def parse_expansion(text: str) -> dict:
  73. """Parse a multi-line expansion into {lex, vec, hyde, invalid} lists."""
  74. result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
  75. for line in text.strip().split("\n"):
  76. line = line.strip()
  77. if not line:
  78. continue
  79. if line.startswith("lex:"):
  80. result["lex"].append(line[4:].strip())
  81. elif line.startswith("vec:"):
  82. result["vec"].append(line[4:].strip())
  83. elif line.startswith("hyde:"):
  84. result["hyde"].append(line[5:].strip())
  85. else:
  86. result["invalid"].append(line)
  87. return result
  88. def detect_only_mode(query: str) -> tuple[str | None, str]:
  89. """Detect if query ends with 'only: lex/vec/hyde'.
  90. Returns (only_type, base_query) where only_type is None for normal queries.
  91. """
  92. match = ONLY_MODE_PATTERN.search(query)
  93. if match:
  94. only_type = match.group(1).lower()
  95. base_query = query[:match.start()].strip()
  96. return only_type, base_query
  97. return None, query
  98. def clean_model_output(text: str) -> tuple[str, bool]:
  99. """Strip chat template artifacts from model output.
  100. Returns (cleaned_text, used_thinking) where used_thinking is True
  101. if the model emitted <think>...</think> blocks.
  102. """
  103. text = text.replace('<|im_end|>', '').strip()
  104. used_thinking = '<think>' in text and '</think>' in text
  105. if used_thinking:
  106. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
  107. return text, used_thinking
  108. # =============================================================================
  109. # Helpers
  110. # =============================================================================
def extract_named_entities(query: str) -> set:
    """Extract named entities using heuristics.

    Detects: ALL-CAPS acronyms (TDS, API), capitalized proper nouns (React, Bob),
    technical terms with special chars (node.js, C++), CamelCase (JavaScript),
    and compound names (TDS motorsports -> both words).
    Position-0 words are also detected as entities if they are capitalized and
    not common query-starting verbs (e.g. "Bob asked about deploy" -> "bob").

    Args:
        query: The raw query text. It is split on whitespace and each word has
            surrounding punctuation ('.,!?:;()[]"\\'') stripped.

    Returns:
        A set of lowercased entity strings (single words; entities are never
        multi-word because detection is per whitespace-separated token).
    """
    entities = set()
    words = query.split()
    # Whether the previous token was an entity; enables the compound-name rule.
    prev_was_entity = False
    for i, word in enumerate(words):
        clean = word.strip('.,!?:;()[]"\'')
        if not clean:
            # Pure-punctuation tokens break a compound-name run.
            prev_was_entity = False
            continue
        is_entity = False
        # ALL-CAPS acronyms: TDS, API, GPU, AWS
        # NOTE(review): this branch does not consult the stopword lists, so an
        # all-caps query such as "WHAT IS" yields "what"/"is" as entities.
        if clean.isupper() and len(clean) >= 2:
            entities.add(clean.lower())
            is_entity = True
        # Capitalized proper nouns (any position, including first word)
        # NOTE: a capitalized first word that IS a query verb (e.g. "Configure")
        # still enters this branch, so none of the later elif branches are tried
        # for it, and it is not an entity.
        elif clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
            if i > 0:
                # Non-first words: always treat as entity
                entities.add(clean.lower())
                is_entity = True
            elif clean.lower() not in QUERY_VERB_STOPWORDS:
                # First word: also entity if not a common query verb
                entities.add(clean.lower())
                is_entity = True
        # Technical terms with special chars: node.js, C++, .NET
        elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
            entities.add(clean.lower())
            is_entity = True
        # CamelCase: JavaScript, TypeScript
        # NOTE(review): effectively unreachable for normal words. Any word with
        # an uppercase first letter whose lowercase form is not a stopword is
        # already caught by the capitalized branch above. Lowercase-first
        # camelCase ("useState") is not matched here either.
        elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
            entities.add(clean.lower())
            is_entity = True
        # Compound names: word following an entity (TDS motorsports)
        elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
            entities.add(clean.lower())
            is_entity = True
        prev_was_entity = is_entity
    return entities
  156. def get_key_terms(query: str) -> set:
  157. """Get non-stopword terms from a query."""
  158. return set(query.lower().split()) - KEY_TERM_STOPWORDS
  159. def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
  160. """Does the lex line contain at least one key term from the query?"""
  161. key_terms = get_key_terms(query)
  162. if not key_terms:
  163. return True
  164. return bool(key_terms & set(lex_line.lower().split()))
  165. def lex_preserves_entities(line: str, entities: set) -> bool:
  166. """Does the line contain at least one named entity?"""
  167. if not entities:
  168. return True
  169. lower = line.lower()
  170. return any(e in lower for e in entities)
  171. def lex_is_generic(lex_line: str) -> bool:
  172. """Is this lex line a useless generic filler phrase?"""
  173. lower = lex_line.lower().strip()
  174. for phrase in GENERIC_LEX_PHRASES:
  175. if phrase in lower or lower.startswith(phrase.split()[0]):
  176. remaining = lower
  177. for word in phrase.split():
  178. remaining = remaining.replace(word, '', 1).strip()
  179. if len(remaining) < 3:
  180. return True
  181. return False
  182. def word_set_distance(a: str, b: str) -> int:
  183. """Symmetric difference of word sets (how many words are unique to one)."""
  184. return len(set(a.lower().split()) ^ set(b.lower().split()))
  185. def is_diverse(a: str, b: str, min_distance: int = 2) -> bool:
  186. """Are two strings sufficiently different?"""
  187. a, b = a.lower().strip(), b.lower().strip()
  188. if a == b or a in b or b in a:
  189. return False
  190. return word_set_distance(a, b) >= min_distance
  191. def echoes_query(expansion: str, query: str) -> bool:
  192. """Is this expansion just echoing the original query?"""
  193. exp, q = expansion.lower().strip(), query.lower().strip()
  194. return exp == q or (q in exp and len(exp) < len(q) + 10)
  195. def word_repetition_penalty(text: str) -> int:
  196. """Penalty for words repeated 3+ times (excluding stopwords)."""
  197. counts = Counter(re.findall(r'\b\w+\b', text.lower()))
  198. return sum((c - 2) * 2 for w, c in counts.items()
  199. if c >= 3 and w not in STOPWORDS and len(w) > 2)
  200. # =============================================================================
  201. # Scoring
  202. # =============================================================================
def _score_only_mode(query: str, base_query: str, text: str, used_thinking: bool, only_type: str) -> dict:
    """Score an '/only:' mode expansion. Expects ONLY the requested type.

    Args:
        query: The full original query including the "/only:" suffix. Not read
            by this function; all checks use base_query.
        base_query: The query with the "/only:" suffix removed.
        text: Cleaned model output. As called from score_expansion_detailed,
            it has already passed the chat-template and invalid-line checks.
        used_thinking: Whether the raw output contained <think>...</think>.
        only_type: The requested type: "lex", "vec" or "hyde".

    Returns:
        A breakdown dict with the same keys as score_expansion_detailed.
        "hyde" is always 0 and "only_mode" is only_type. Scored out of 120:
        format 30 + diversity 30 + quality 20 + entity 20 + think bonus 20.
    """
    parsed = parse_expansion(text)
    deductions = []

    # Expected type must be present
    expected_items = parsed.get(only_type, [])
    if not expected_items:
        # NOTE(review): this failure dict reports max_possible 100 while the
        # success path below uses 120. Harmless for the reward (total is 0), but
        # inconsistent in the breakdown.
        return {
            "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
            "think_bonus": 0, "total": 0, "max_possible": 100,
            "percentage": 0.0, "rating": "Failed",
            "deductions": [f"missing expected {only_type}: output"],
            "parsed": parsed,
            "entities_detected": [],
            "only_mode": only_type,
        }

    # Penalize presence of OTHER types
    other_types = {"lex", "vec", "hyde"} - {only_type}
    unwanted_count = sum(len(parsed.get(t, [])) for t in other_types)
    if unwanted_count > 0:
        deductions.append(f"contains unwanted types (expected only {only_type})")

    # --- Format (0-30) ---
    # -10 per line of an unwanted type, floored at 0.
    format_score = 30 if unwanted_count == 0 else max(0, 30 - unwanted_count * 10)

    # --- Diversity (0-30) ---
    diversity_score = 0
    if len(expected_items) >= 2:
        diversity_score += 15
        # Check for diversity among items: -5 per near-duplicate pair.
        div_score = 15
        for i, a in enumerate(expected_items):
            for b in expected_items[i+1:]:
                if not is_diverse(a, b, 2):
                    div_score -= 5
                    deductions.append(f"{only_type} duplicate: {a[:20]}...")
        diversity_score += max(0, div_score)
    elif len(expected_items) == 1:
        diversity_score = 15  # One item is fine for single-type output

    # Check for echoes (-5 each, floored at 0 below)
    for exp in expected_items:
        if echoes_query(exp, base_query):
            diversity_score -= 5
            deductions.append(f"echoes query: {exp[:20]}...")
    diversity_score = max(0, diversity_score)

    # --- Type-specific quality (0-20) ---
    quality_score = 10  # base
    entities = extract_named_entities(base_query)
    if only_type == "lex":
        # Lex should be short keyword phrases with key terms
        with_terms = sum(1 for l in expected_items if lex_preserves_key_terms(l, base_query))
        if with_terms == len(expected_items):
            quality_score += 5
        # Check for generic phrases
        generic = sum(1 for l in expected_items if lex_is_generic(l))
        if generic == 0:
            quality_score += 5
        else:
            deductions.append(f"{generic} generic lex phrases")
    elif only_type == "vec":
        # Vec should be natural language sentences (contain a space, > 15 chars)
        natural = sum(1 for v in expected_items if " " in v and len(v) > 15)
        if natural == len(expected_items):
            quality_score += 10
        else:
            quality_score += 5
            deductions.append("vec not all natural language")
    elif only_type == "hyde":
        # Hyde should be a document snippet (50-200 chars). Only the first
        # hyde line is scored.
        hyde_text = expected_items[0]
        hyde_len = len(hyde_text)
        if 50 <= hyde_len <= 200:
            quality_score += 10
        elif 30 <= hyde_len <= 300:
            quality_score += 5
            deductions.append(f"hyde length {hyde_len} (ideal: 50-200)")
        else:
            deductions.append(f"hyde length {hyde_len} out of range")

    # --- Entity preservation (0-20) ---
    entity_score = 10  # base (also the score when the query has no entities)
    if entities:
        with_entities = sum(1 for item in expected_items if lex_preserves_entities(item, entities))
        if with_entities == len(expected_items):
            entity_score += 10
        elif with_entities > 0:
            entity_score += 5
        else:
            entity_score = 0
            deductions.append(f"missing entities: {entities}")

    # --- Think bonus (0-20): reward NOT using thinking mode ---
    think_bonus = 0 if used_thinking else 20

    # --- Total ---
    total = format_score + diversity_score + quality_score + entity_score + think_bonus
    max_possible = 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    if percentage >= 80:
        rating = "Excellent"
    elif percentage >= 60:
        rating = "Good"
    elif percentage >= 40:
        rating = "Acceptable"
    elif percentage >= 20:
        rating = "Poor"
    else:
        rating = "Failed"

    return {
        "format": format_score,
        "diversity": diversity_score,
        "hyde": 0,  # not used in only mode (quality covers it)
        "quality": quality_score,
        "entity": entity_score,
        "think_bonus": think_bonus,
        "total": total,
        "max_possible": max_possible,
        "percentage": round(percentage, 1),
        "rating": rating,
        "deductions": deductions,
        "parsed": parsed,
        "entities_detected": list(entities) if entities else [],
        "only_mode": only_type,
    }
  322. def score_expansion_detailed(query: str, expansion: str) -> dict:
  323. """Score an expansion with full breakdown. Returns dict with all dimensions."""
  324. text, used_thinking = clean_model_output(expansion.strip())
  325. deductions = []
  326. # Detect "only:" mode
  327. only_type, base_query = detect_only_mode(query)
  328. def _fail(reason):
  329. return {
  330. "format": 0, "diversity": 0, "hyde": 0, "quality": 0, "entity": 0,
  331. "think_bonus": 0, "total": 0, "max_possible": 100,
  332. "percentage": 0.0, "rating": "Failed",
  333. "deductions": [reason],
  334. "parsed": parse_expansion(expansion),
  335. "entities_detected": [],
  336. "only_mode": only_type,
  337. }
  338. # Hard fail: remaining chat template tokens
  339. if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
  340. return _fail("CHAT TEMPLATE LEAKAGE")
  341. # Hard fail: every non-empty line must have a valid prefix
  342. for line in text.split("\n"):
  343. line = line.strip()
  344. if line and not line.startswith(("lex:", "vec:", "hyde:")):
  345. return _fail(f"INVALID LINE: {line[:50]}")
  346. # --- Handle "only:" mode separately ---
  347. if only_type:
  348. return _score_only_mode(query, base_query, text, used_thinking, only_type)
  349. parsed = parse_expansion(text)
  350. # --- Format (0-30) ---
  351. format_score = 10 # no invalid lines (guaranteed by hard fail)
  352. if parsed["lex"]:
  353. format_score += 10
  354. else:
  355. deductions.append("missing lex:")
  356. if parsed["vec"]:
  357. format_score += 10
  358. else:
  359. deductions.append("missing vec:")
  360. # --- Diversity (0-30) ---
  361. diversity_score = 0
  362. types_present = sum(1 for t in ("lex", "vec") if parsed[t])
  363. if types_present >= 2:
  364. diversity_score += 10
  365. else:
  366. deductions.append("only one type")
  367. if len(parsed["lex"]) + len(parsed["vec"]) >= 2:
  368. diversity_score += 5
  369. lex_div = 5
  370. for i, a in enumerate(parsed["lex"]):
  371. for b in parsed["lex"][i+1:]:
  372. if not is_diverse(a, b, 2):
  373. lex_div -= 2
  374. deductions.append(f"lex duplicate: {a[:20]}...")
  375. diversity_score += max(0, lex_div)
  376. vec_div = 5
  377. for i, a in enumerate(parsed["vec"]):
  378. for b in parsed["vec"][i+1:]:
  379. if not is_diverse(a, b, 3):
  380. vec_div -= 2
  381. deductions.append(f"vec duplicate: {a[:20]}...")
  382. diversity_score += max(0, vec_div)
  383. echo = 5
  384. lex_echo_count = 0
  385. for exp in parsed["lex"]:
  386. if echoes_query(exp, query):
  387. lex_echo_count += 1
  388. deductions.append(f"lex echoes query: {exp[:20]}...")
  389. # Harsh penalty for lex echoes - they're useless
  390. if lex_echo_count > 0:
  391. echo -= lex_echo_count * 10 # -10 per echo
  392. for exp in parsed["vec"]:
  393. if echoes_query(exp, query):
  394. echo -= 3 # vec echoes less severe (natural language overlap ok)
  395. deductions.append(f"vec echoes query: {exp[:20]}...")
  396. diversity_score += max(-10, echo) # can go negative
  397. # --- HyDE (0-20, optional bonus) ---
  398. hyde_score = 0
  399. if parsed["hyde"]:
  400. hyde_text = parsed["hyde"][0]
  401. hyde_score += 5
  402. hyde_len = len(hyde_text)
  403. if 50 <= hyde_len <= 200:
  404. hyde_score += 5
  405. elif hyde_len < 50:
  406. hyde_score += 2
  407. deductions.append(f"hyde too short ({hyde_len})")
  408. else:
  409. deductions.append(f"hyde too long ({hyde_len})")
  410. if "\n" not in hyde_text:
  411. hyde_score += 5
  412. hyde_score += max(0, 5 - word_repetition_penalty(hyde_text))
  413. # --- Extract entities (used by both quality and entity sections) ---
  414. entities = extract_named_entities(query)
  415. # --- Quality (0-20) ---
  416. quality_score = 5 # base relevance
  417. if parsed["lex"] and parsed["vec"]:
  418. avg_lex = sum(len(l) for l in parsed["lex"]) / len(parsed["lex"])
  419. avg_vec = sum(len(v) for v in parsed["vec"]) / len(parsed["vec"])
  420. if avg_lex <= avg_vec:
  421. quality_score += 5
  422. else:
  423. deductions.append("lex longer than vec")
  424. if parsed["vec"]:
  425. natural = sum(1 for v in parsed["vec"] if " " in v and len(v) > 15)
  426. quality_score += 5 if natural == len(parsed["vec"]) else 2
  427. if parsed["lex"]:
  428. with_terms = sum(1 for l in parsed["lex"] if lex_preserves_key_terms(l, query))
  429. if with_terms == len(parsed["lex"]):
  430. quality_score += 5
  431. elif with_terms > 0:
  432. quality_score += 2
  433. else:
  434. deductions.append("lex missing key terms")
  435. # Bonus: lex uses quoted phrases for multi-word entities (+3)
  436. if entities and parsed["lex"]:
  437. multi_word_entities = [e for e in entities if " " in e or len(e) > 6]
  438. if multi_word_entities:
  439. lex_joined = " ".join(parsed["lex"])
  440. if '"' in lex_joined:
  441. quality_score += 3
  442. # --- Entity Preservation (-45 to +20) ---
  443. entity_score = 0
  444. if entities and parsed["lex"]:
  445. # Per-line check: do lex lines contain entities?
  446. with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
  447. if with_entities == len(parsed["lex"]):
  448. entity_score += 15
  449. elif with_entities > 0:
  450. entity_score += 5
  451. else:
  452. entity_score -= 30
  453. deductions.append(f"lex missing entities: {entities}")
  454. # Per-entity coverage: is each entity mentioned somewhere in lex+vec?
  455. all_output = " ".join(parsed["lex"] + parsed["vec"]).lower()
  456. missing_entities = {e for e in entities if e not in all_output}
  457. if missing_entities:
  458. penalty = len(missing_entities) * 20
  459. entity_score -= penalty
  460. deductions.append(f"entities dropped: {missing_entities}")
  461. generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
  462. if generic_count:
  463. entity_score -= generic_count * 15
  464. deductions.append(f"{generic_count} generic lex phrases")
  465. if parsed["vec"]:
  466. vec_with = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
  467. if vec_with > 0:
  468. entity_score += 5
  469. elif not entities:
  470. entity_score = 10
  471. # --- Think bonus (0-20): reward NOT using thinking mode ---
  472. think_bonus = 0 if used_thinking else 20
  473. # --- Total ---
  474. total = format_score + diversity_score + hyde_score + quality_score + entity_score + think_bonus
  475. max_possible = 140 if parsed["hyde"] else 120
  476. percentage = max(0.0, min(100.0, total / max_possible * 100))
  477. # Hard cap: lex echoes are unacceptable - cap at 50%
  478. if lex_echo_count > 0:
  479. percentage = min(percentage, 50.0)
  480. deductions.insert(0, f"CAPPED: {lex_echo_count} lex echo(es)")
  481. if percentage >= 80:
  482. rating = "Excellent"
  483. elif percentage >= 60:
  484. rating = "Good"
  485. elif percentage >= 40:
  486. rating = "Acceptable"
  487. elif percentage >= 20:
  488. rating = "Poor"
  489. else:
  490. rating = "Failed"
  491. return {
  492. "format": format_score,
  493. "diversity": diversity_score,
  494. "hyde": hyde_score,
  495. "quality": quality_score,
  496. "entity": max(0, entity_score),
  497. "think_bonus": think_bonus,
  498. "total": max(0, total),
  499. "max_possible": max_possible,
  500. "percentage": round(percentage, 1),
  501. "rating": rating,
  502. "deductions": deductions,
  503. "parsed": parsed,
  504. "entities_detected": list(entities) if entities else [],
  505. "only_mode": None,
  506. }
  507. def score_expansion(query: str, expansion: str) -> float:
  508. """Score expansion as a float in [0.0, 1.0] for use as RL reward."""
  509. result = score_expansion_detailed(query, expansion)
  510. return max(0.0, min(1.0, result["total"] / result["max_possible"]))
  511. def extract_query_from_prompt(prompt: str) -> str:
  512. """Extract the query string from a chat-formatted prompt."""
  513. if "Expand this search query:" in prompt:
  514. query = prompt.split("Expand this search query:")[-1].strip()
  515. if "<|im_end|>" in query:
  516. query = query.split("<|im_end|>")[0].strip()
  517. return query
  518. return prompt.strip()
  519. # =============================================================================
  520. # TRL-compatible reward class
  521. # =============================================================================
  522. class QMDRewardFunction:
  523. """Reward function compatible with TRL's GRPOTrainer."""
  524. __name__ = "qmd_scoring_reward"
  525. def __call__(self, completions: list[str], prompts: list[str] = None, **kwargs) -> list[float]:
  526. rewards = []
  527. for i, completion in enumerate(completions):
  528. query = ""
  529. if prompts and i < len(prompts):
  530. query = extract_query_from_prompt(prompts[i])
  531. rewards.append(score_expansion(query, completion))
  532. return rewards
  533. # =============================================================================
  534. # CLI: run standalone to test the reward function
  535. # =============================================================================
if __name__ == "__main__":
    # Self-test: score a fixed set of (query, expansion) pairs and print each
    # reward, its rating and up to three deductions.
    print("QMD Reward Function Self-Test")
    print("=" * 60)
    tests = [
        # Well-formed expansion with lex, vec and hyde lines
        ("auth", "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."),
        # Unprefixed prose: hard-fails as an invalid line
        ("auth", "auth is important for security"),
        ("who is TDS motorsports", "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"),
        # Generic filler lex lines that drop the entity
        ("who is TDS motorsports", "lex: find information about\nlex: company details\nvec: who is this company"),
        ("how to use React hooks", "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"),
        # <think> block: stripped, but forfeits the think bonus
        ("auth", "<think>Let me think...</think>\nlex: auth"),
        # Mixed valid and invalid lines: hard-fails on the invalid line
        ("auth", "lex: auth\nThis is some explanation\nvec: more"),
        # Personal entity tests (issue #247: entity stripping)
        ("meeting with Bob about C++", 'lex: Bob "C++" meeting\nlex: Bob C++ discussion notes\nvec: meeting notes with Bob about C++ programming'),
        ("meeting with Bob about C++", "lex: c++ meetings\nvec: programming meeting notes"),  # BAD: Bob is gone
        # Quoted phrases bonus
        ("python memory leak debugging", 'lex: "memory leak" python -java\nlex: tracemalloc profiler\nvec: how to find memory leaks in Python'),
        # "/only:" mode tests (slash prefix)
        ("auth /only:lex", "lex: auth setup\nlex: authentication config\nlex: login credentials"),
        ("auth /only:lex", "lex: auth setup\nvec: how to configure authentication"),  # penalized (format -10): contains an unwanted vec line
        ("React hooks /only:vec", "vec: how to use React hooks in functional components\nvec: useState and useEffect patterns in React"),
        ("PostgreSQL indexing /only:hyde", "hyde: PostgreSQL uses B-tree indexes by default. Create indexes with CREATE INDEX idx_name ON table(column). EXPLAIN ANALYZE shows whether queries use indexes efficiently."),
    ]
    for query, expansion in tests:
        # The reward and the breakdown are computed separately; the rating and
        # deductions come from the detailed breakdown.
        score = score_expansion(query, expansion)
        detail = score_expansion_detailed(query, expansion)
        only_mode = detail.get("only_mode")
        mode_str = f" [only:{only_mode}]" if only_mode else ""
        print(f"\n Query: '{query}'{mode_str}")
        print(f" Score: {score:.2f} ({detail['rating']})")
        if detail["deductions"]:
            # Show at most the first three deductions to keep output short.
            print(f" Issues: {', '.join(detail['deductions'][:3])}")
  564. print(f" Score: {score:.2f} ({detail['rating']})")
  565. if detail["deductions"]:
  566. print(f" Issues: {', '.join(detail['deductions'][:3])}")