|
|
@@ -38,15 +38,77 @@ from trl import GRPOTrainer, GRPOConfig
|
|
|
|
|
|
STOPWORDS = {'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or', 'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by'}
|
|
|
KEY_TERM_STOPWORDS = {'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
|
|
|
- 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we'}
|
|
|
+ 'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
|
|
|
+ 'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell'}
|
|
|
+
|
|
|
+# Generic filler phrases that should never be in lex queries
|
|
|
+GENERIC_LEX_PHRASES = {
|
|
|
+ 'find information about', 'search for', 'look up', 'get information',
|
|
|
+ 'learn about', 'information on', 'details about', 'find out about',
|
|
|
+ 'what is', 'how to', 'guide to', 'help with'
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+def extract_named_entities(query: str) -> set:
|
|
|
+ """Extract named entities from query using simple heuristics.
|
|
|
+
|
|
|
+ Named entities are:
|
|
|
+ - Capitalized words (except first word which may just be sentence start)
|
|
|
+ - All-caps words/acronyms (TDS, API, GPU)
|
|
|
+ - Technical terms with special chars (node.js, C++, .NET)
|
|
|
+ - Words following acronyms/proper nouns (TDS motorsports -> both words)
|
|
|
+ """
|
|
|
+ entities = set()
|
|
|
+ words = query.split()
|
|
|
+ prev_was_entity = False
|
|
|
+
|
|
|
+ for i, word in enumerate(words):
|
|
|
+ # Clean punctuation but keep internal special chars
|
|
|
+ clean = word.strip('.,!?:;()[]"\'')
|
|
|
+ if not clean:
|
|
|
+ prev_was_entity = False
|
|
|
+ continue
|
|
|
+
|
|
|
+ is_entity = False
|
|
|
+
|
|
|
+ # All-caps words (acronyms): TDS, API, GPU, etc.
|
|
|
+ if clean.isupper() and len(clean) >= 2:
|
|
|
+ entities.add(clean.lower())
|
|
|
+ is_entity = True
|
|
|
+
|
|
|
+ # Capitalized words (not first word, not common words)
|
|
|
+ elif i > 0 and clean[0].isupper() and clean.lower() not in KEY_TERM_STOPWORDS:
|
|
|
+ entities.add(clean.lower())
|
|
|
+ is_entity = True
|
|
|
+
|
|
|
+ # Technical terms with special chars: node.js, C++, .NET
|
|
|
+ elif any(c in clean for c in '.+-#@') and len(clean) >= 2:
|
|
|
+ entities.add(clean.lower())
|
|
|
+ is_entity = True
|
|
|
+
|
|
|
+ # CamelCase: JavaScript, TypeScript, etc.
|
|
|
+ elif len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper():
|
|
|
+ entities.add(clean.lower())
|
|
|
+ is_entity = True
|
|
|
+
|
|
|
+ # Word following an entity is likely part of compound name (TDS motorsports)
|
|
|
+ elif prev_was_entity and clean.lower() not in KEY_TERM_STOPWORDS:
|
|
|
+ entities.add(clean.lower())
|
|
|
+ is_entity = True
|
|
|
+
|
|
|
+ prev_was_entity = is_entity
|
|
|
+
|
|
|
+ return entities
|
|
|
|
|
|
|
|
|
def get_key_terms(query: str) -> set:
|
|
|
+ """Get key terms (non-stopwords) from query."""
|
|
|
words = set(query.lower().split())
|
|
|
return words - KEY_TERM_STOPWORDS
|
|
|
|
|
|
|
|
|
def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
|
|
|
+ """Check if lex line preserves key terms from query."""
|
|
|
key_terms = get_key_terms(query)
|
|
|
if not key_terms:
|
|
|
return True
|
|
|
@@ -54,6 +116,28 @@ def lex_preserves_key_terms(lex_line: str, query: str) -> bool:
|
|
|
return bool(key_terms & lex_words)
|
|
|
|
|
|
|
|
|
+def lex_preserves_entities(lex_line: str, entities: set) -> bool:
|
|
|
+ """Check if lex line contains at least one named entity."""
|
|
|
+ if not entities:
|
|
|
+ return True # No entities to preserve
|
|
|
+ lex_lower = lex_line.lower()
|
|
|
+ return any(entity in lex_lower for entity in entities)
|
|
|
+
|
|
|
+
|
|
|
+def lex_is_generic(lex_line: str) -> bool:
|
|
|
+ """Check if lex line is a generic filler phrase."""
|
|
|
+ lex_lower = lex_line.lower().strip()
|
|
|
+ for phrase in GENERIC_LEX_PHRASES:
|
|
|
+ if phrase in lex_lower or lex_lower.startswith(phrase.split()[0]):
|
|
|
+ # Also check if it's ONLY the generic phrase with no specifics
|
|
|
+ remaining = lex_lower
|
|
|
+ for word in phrase.split():
|
|
|
+ remaining = remaining.replace(word, '', 1).strip()
|
|
|
+ if len(remaining) < 3: # Nothing specific left
|
|
|
+ return True
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
def parse_expansion(text: str) -> dict:
|
|
|
lines = text.strip().split("\n")
|
|
|
result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
|
|
|
@@ -193,9 +277,37 @@ def score_expansion(query: str, expansion: str) -> float:
|
|
|
elif lex_with_terms > 0:
|
|
|
quality_score += 2
|
|
|
|
|
|
- total = format_score + diversity_score + hyde_score + quality_score
|
|
|
- max_possible = 100 if parsed["hyde"] else 80
|
|
|
- return total / max_possible
|
|
|
+ # NAMED ENTITY PRESERVATION (critical for quality)
|
|
|
+ # This score can go heavily negative to punish missing entities
|
|
|
+ entity_score = 0
|
|
|
+ entities = extract_named_entities(query)
|
|
|
+ if entities and parsed["lex"]:
|
|
|
+ # Count lex lines that preserve at least one entity
|
|
|
+ lex_with_entities = sum(1 for l in parsed["lex"] if lex_preserves_entities(l, entities))
|
|
|
+ if lex_with_entities == len(parsed["lex"]):
|
|
|
+ entity_score += 15 # All lex lines have entities - great!
|
|
|
+ elif lex_with_entities > 0:
|
|
|
+ entity_score += 5 # Some have entities
|
|
|
+ else:
|
|
|
+ entity_score -= 30 # NO lex lines have entities - HEAVY penalty!
|
|
|
+
|
|
|
+ # Penalize generic filler phrases in lex (these are useless for BM25)
|
|
|
+ generic_count = sum(1 for l in parsed["lex"] if lex_is_generic(l))
|
|
|
+ entity_score -= generic_count * 15 # -15 per generic phrase
|
|
|
+
|
|
|
+ # Bonus for entities in vec too (less critical but nice)
|
|
|
+ if parsed["vec"]:
|
|
|
+ vec_with_entities = sum(1 for v in parsed["vec"] if lex_preserves_entities(v, entities))
|
|
|
+ if vec_with_entities > 0:
|
|
|
+ entity_score += 5
|
|
|
+ elif not entities:
|
|
|
+ # No entities in query - give base score
|
|
|
+ entity_score = 10
|
|
|
+
|
|
|
+ # Entity score CAN go negative to heavily penalize missing entities
|
|
|
+ total = format_score + diversity_score + hyde_score + quality_score + entity_score
|
|
|
+ max_possible = 120 if parsed["hyde"] else 100
|
|
|
+ return max(0.0, min(1.0, total / max_possible)) # Clamp to 0.0-1.0
|
|
|
|
|
|
|
|
|
def extract_query_from_prompt(prompt: str) -> str:
|
|
|
@@ -231,10 +343,36 @@ def main():
|
|
|
print("GRPO Training Configuration:")
|
|
|
print(yaml.dump(cfg, default_flow_style=False))
|
|
|
print("\nTesting reward function...")
|
|
|
+
|
|
|
+ # Test 1: Basic query
|
|
|
test_good = "lex: auth setup\nlex: authentication config\nvec: how to configure authentication\nhyde: Configure auth by setting AUTH_SECRET."
|
|
|
test_bad = "auth is important for security"
|
|
|
- print(f" Good output score: {score_expansion('auth', test_good):.2f}")
|
|
|
- print(f" Bad output score: {score_expansion('auth', test_bad):.2f}")
|
|
|
+ print(f"\n Query: 'auth'")
|
|
|
+ print(f" Good output score: {score_expansion('auth', test_good):.2f}")
|
|
|
+ print(f" Bad output score: {score_expansion('auth', test_bad):.2f}")
|
|
|
+
|
|
|
+ # Test 2: Named entity query (the critical case!)
|
|
|
+ query_entity = "who is TDS motorsports"
|
|
|
+ entities = extract_named_entities(query_entity)
|
|
|
+ print(f"\n Query: '{query_entity}'")
|
|
|
+ print(f" Extracted entities: {entities}")
|
|
|
+
|
|
|
+ good_entity = "lex: TDS motorsports history\nlex: TDS motorsports founders\nvec: information about TDS motorsports company"
|
|
|
+ bad_entity = "lex: find information about\nlex: company details\nvec: who is this company"
|
|
|
+ print(f" Good (preserves entity): {score_expansion(query_entity, good_entity):.2f}")
|
|
|
+ print(f" Bad (generic phrases): {score_expansion(query_entity, bad_entity):.2f}")
|
|
|
+
|
|
|
+ # Test 3: Technical term
|
|
|
+ query_tech = "how to use React hooks"
|
|
|
+ entities_tech = extract_named_entities(query_tech)
|
|
|
+ print(f"\n Query: '{query_tech}'")
|
|
|
+ print(f" Extracted entities: {entities_tech}")
|
|
|
+
|
|
|
+ good_tech = "lex: React hooks tutorial\nlex: useEffect useState\nvec: how to use React hooks in functional components"
|
|
|
+ bad_tech = "lex: programming tutorial\nlex: how to code\nvec: learn web development"
|
|
|
+ print(f" Good (preserves React): {score_expansion(query_tech, good_tech):.2f}")
|
|
|
+ print(f" Bad (generic): {score_expansion(query_tech, bad_tech):.2f}")
|
|
|
+
|
|
|
return
|
|
|
|
|
|
# Login
|