# /// script # requires-python = ">=3.10" # dependencies = [ # "transformers>=4.45.0", # "peft>=0.7.0", # "torch", # "huggingface_hub>=0.20.0", # "accelerate", # ] # /// """ Evaluate QMD query expansion models on HuggingFace Jobs. Self-contained script — inlines the reward function and test queries. hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval.py -- --sft-only """ import argparse import csv import io import json import os import re import sys from collections import Counter import torch from huggingface_hub import HfApi, login from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer # --- Config --- BASE_MODEL = "Qwen/Qwen3-1.7B" SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft" GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo" # --- Test queries (inlined from evals/queries.txt) --- QUERIES = [ # Technical documentation "how to configure authentication", "typescript async await", "docker compose networking", "git rebase vs merge", "react useEffect cleanup", # Short/ambiguous "auth", "config", "setup", "api", # Named entities "who is TDS motorsports", "React hooks tutorial", "Docker container networking", "Kubernetes pod deployment", "AWS Lambda functions", # Personal notes / journals "meeting notes project kickoff", "ideas for new feature", "todo list app architecture", # Research / learning "what is dependency injection", "difference between sql and nosql", "kubernetes vs docker swarm", # Error/debugging "connection timeout error", "memory leak debugging", "cors error fix", # Temporal / recency "recent news about Shopify", "latest AI developments", "best laptops right now", "what changed in kubernetes latest version", # Complex "how to implement caching with redis in nodejs", "best practices for api rate limiting", "setting up ci cd pipeline with github actions", ] # 
# =============================================================================
# Reward function (inlined from reward.py)
# =============================================================================

# Words ignored when counting repeated words in a hyde passage.
STOPWORDS = frozenset({
    'the', 'a', 'an', 'is', 'are', 'to', 'for', 'of', 'in', 'and', 'or',
    'it', 'this', 'that', 'be', 'with', 'as', 'on', 'by',
})

# Words removed from a query before deciding which terms a lex line must keep.
KEY_TERM_STOPWORDS = frozenset({
    'what', 'is', 'how', 'to', 'the', 'a', 'an', 'in', 'on', 'for', 'of',
    'and', 'or', 'with', 'my', 'your', 'do', 'does', 'can', 'i', 'me', 'we',
    'who', 'where', 'when', 'why', 'which', 'find', 'get', 'show', 'tell',
})

# Filler phrases; a lex line that is (almost) only one of these is "generic".
GENERIC_LEX_PHRASES = frozenset({
    'find information about', 'search for', 'look up', 'get information',
    'learn about', 'information on', 'details about', 'find out about',
    'what is', 'how to', 'guide to', 'help with',
})

# Any of these appearing in a cleaned expansion fails the whole output.
CHAT_TEMPLATE_TOKENS = frozenset({
    '<|im_start|>', '<|im_end|>', '<|endoftext|>',
    '\nassistant\n', '\nuser\n',
})

# Matches one complete reasoning block, including newlines inside it.
_THINK_BLOCK_RE = re.compile(r'<think>.*?</think>', re.DOTALL)


def parse_expansion(text):
    """Split an expansion into its ``lex:``, ``vec:`` and ``hyde:`` lines.

    Blank lines are skipped. Lines with none of the three prefixes are
    collected, stripped, under ``"invalid"``.

    Returns:
        Dict with keys ``lex``, ``vec``, ``hyde`` and ``invalid``, each a list
        of strings in input order (prefix removed for the three typed keys).
    """
    result = {"lex": [], "vec": [], "hyde": [], "invalid": []}
    for raw in text.strip().split("\n"):
        line = raw.strip()
        if not line:
            continue
        if line.startswith("lex:"):
            result["lex"].append(line[4:].strip())
        elif line.startswith("vec:"):
            result["vec"].append(line[4:].strip())
        elif line.startswith("hyde:"):
            result["hyde"].append(line[5:].strip())
        else:
            result["invalid"].append(line)
    return result


def clean_model_output(text):
    """Strip end-of-turn markers and reasoning blocks from raw model output.

    Returns:
        ``(cleaned_text, used_thinking)`` where ``used_thinking`` is True only
        when both ``<think>`` and ``</think>`` occur in the text.
    """
    text = text.replace('<|im_end|>', '').strip()
    # The tag literals must be non-empty: `'' in text` is always True, which
    # would flag every output as "thinking" and zero out the think bonus.
    used_thinking = '<think>' in text and '</think>' in text
    if used_thinking:
        text = _THINK_BLOCK_RE.sub('', text).strip()
    return text, used_thinking


def extract_named_entities(query):
    """Heuristically pick entity-like words out of *query*, lowercased.

    A word (after stripping surrounding punctuation) counts as an entity when
    it is, in priority order: all-uppercase with length >= 2; capitalised and
    not the first word and not a key-term stopword; contains one of
    ``. + - # @``; capitalised with another uppercase letter later in the word;
    or directly follows an entity and is not a key-term stopword.
    """
    entities = set()
    prev_was_entity = False
    for i, word in enumerate(query.split()):
        clean = word.strip('.,!?:;()[]"\'')
        if not clean:
            prev_was_entity = False
            continue
        lowered = clean.lower()
        not_stopword = lowered not in KEY_TERM_STOPWORDS
        is_entity = (
            (clean.isupper() and len(clean) >= 2)
            or (i > 0 and clean[0].isupper() and not_stopword)
            or (any(c in clean for c in '.+-#@') and len(clean) >= 2)
            or (len(clean) > 1 and any(c.isupper() for c in clean[1:]) and clean[0].isupper())
            or (prev_was_entity and not_stopword)
        )
        if is_entity:
            entities.add(lowered)
        prev_was_entity = is_entity
    return entities


def get_key_terms(query):
    """Return the lowercased whitespace-split query words minus stopwords."""
    return set(query.lower().split()) - KEY_TERM_STOPWORDS


def lex_preserves_key_terms(lex_line, query):
    """Return True if *lex_line* shares a key term with *query*.

    Also True when the query has no key terms at all.
    """
    key_terms = get_key_terms(query)
    return not key_terms or bool(key_terms & set(lex_line.lower().split()))


def lex_preserves_entities(line, entities):
    """Return True if any entity is a substring of *line* (case-insensitive).

    Also True when *entities* is empty.
    """
    if not entities:
        return True
    lowered = line.lower()
    return any(e in lowered for e in entities)


def lex_is_generic(lex_line):
    """Return True if *lex_line* is little more than a generic filler phrase.

    For each phrase that occurs in the line, or whose first word the line
    starts with, the phrase's words are removed once each; the line is generic
    if fewer than 3 characters remain.
    """
    lower = lex_line.lower().strip()
    for phrase in GENERIC_LEX_PHRASES:
        phrase_words = phrase.split()
        if phrase in lower or lower.startswith(phrase_words[0]):
            remaining = lower
            for word in phrase_words:
                remaining = remaining.replace(word, '', 1).strip()
            if len(remaining) < 3:
                return True
    return False


def word_set_distance(a, b):
    """Return the size of the symmetric difference of the two word sets."""
    return len(set(a.lower().split()) ^ set(b.lower().split()))


def is_diverse(a, b, min_distance=2):
    """Return True if *a* and *b* differ by at least *min_distance* words.

    Strings that are equal, or where one contains the other (after lowercasing
    and stripping), are never diverse.
    """
    a, b = a.lower().strip(), b.lower().strip()
    if a == b or a in b or b in a:
        return False
    return word_set_distance(a, b) >= min_distance


def echoes_query(expansion, query):
    """Return True if *expansion* is the query, or the query plus < 10 chars."""
    exp, q = expansion.lower().strip(), query.lower().strip()
    return exp == q or (q in exp and len(exp) < len(q) + 10)


def word_repetition_penalty(text):
    """Return 2 points per occurrence beyond the second, summed over words.

    Only words longer than 2 characters that are not in STOPWORDS and occur
    at least 3 times contribute.
    """
    counts = Counter(re.findall(r'\b\w+\b', text.lower()))
    return sum(
        (c - 2) * 2
        for w, c in counts.items()
        if c >= 3 and w not in STOPWORDS and len(w) > 2
    )


def _failed_score(reason):
    """Build the all-zero result used when an output is rejected outright."""
    return {
        "format": 0, "diversity": 0, "hyde": 0, "quality": 0,
        "entity": 0, "think_bonus": 0, "total": 0, "max_possible": 100,
        "percentage": 0.0, "rating": "Failed", "deductions": [reason],
    }


def _format_score(parsed, deductions):
    """Score presence of lex/vec lines (max 30); record what is missing."""
    score = 10
    for kind in ("lex", "vec"):
        if parsed[kind]:
            score += 10
        else:
            deductions.append(f"missing {kind}:")
    return score


def _pair_diversity(items, min_distance):
    """Start at 5 and lose 2 per non-diverse pair, floored at 0."""
    score = 5
    for i, first in enumerate(items):
        for second in items[i + 1:]:
            if not is_diverse(first, second, min_distance):
                score -= 2
    return max(0, score)


def _diversity_score(parsed, query):
    """Score type coverage, pairwise diversity and query echoing (max 30)."""
    lex, vec = parsed["lex"], parsed["vec"]
    score = 0
    if lex and vec:
        score += 10
    if len(lex) + len(vec) >= 2:
        score += 5
    score += _pair_diversity(lex, 2)
    score += _pair_diversity(vec, 3)
    echo = 5 - 3 * sum(1 for exp in lex + vec if echoes_query(exp, query))
    return score + max(0, echo)


def _hyde_score(parsed):
    """Score the first hyde line on length and repetition (max 20)."""
    if not parsed["hyde"]:
        return 0
    hyde_text = parsed["hyde"][0]
    score = 5
    hyde_len = len(hyde_text)
    if 50 <= hyde_len <= 200:
        score += 5
    elif hyde_len < 50:
        score += 2
    if "\n" not in hyde_text:
        score += 5
    return score + max(0, 5 - word_repetition_penalty(hyde_text))


def _quality_score(parsed, query):
    """Score lex-vs-vec length, vec naturalness and key-term retention (max 20)."""
    lex, vec = parsed["lex"], parsed["vec"]
    score = 5
    if lex and vec:
        avg_lex = sum(len(line) for line in lex) / len(lex)
        avg_vec = sum(len(line) for line in vec) / len(vec)
        if avg_lex <= avg_vec:
            score += 5
    if vec:
        natural = sum(1 for v in vec if " " in v and len(v) > 15)
        score += 5 if natural == len(vec) else 2
    if lex:
        with_terms = sum(1 for line in lex if lex_preserves_key_terms(line, query))
        if with_terms == len(lex):
            score += 5
        elif with_terms > 0:
            score += 2
    return score


def _entity_score(parsed, entities):
    """Score entity retention; may be negative (max 20, 10 if no entities)."""
    lex, vec = parsed["lex"], parsed["vec"]
    if not entities:
        return 10
    if not lex:
        return 0
    score = 0
    kept = sum(1 for line in lex if lex_preserves_entities(line, entities))
    if kept == len(lex):
        score += 15
    elif kept > 0:
        score += 5
    else:
        score -= 30
    score -= 15 * sum(1 for line in lex if lex_is_generic(line))
    if vec and any(lex_preserves_entities(v, entities) for v in vec):
        score += 5
    return score


def _rating_for(percentage):
    """Map a 0-100 percentage to its rating label."""
    if percentage >= 80:
        return "Excellent"
    if percentage >= 60:
        return "Good"
    if percentage >= 40:
        return "Acceptable"
    if percentage >= 20:
        return "Poor"
    return "Failed"


def score_expansion_detailed(query, expansion):
    """Score one model expansion of *query* and return a per-category breakdown.

    The output is rejected (all zeros, rating ``"Failed"``) if, after cleaning,
    it contains a chat-template token or any line that does not start with
    ``lex:``, ``vec:`` or ``hyde:``.

    Returns:
        Dict with ``format``, ``diversity``, ``hyde``, ``quality``, ``entity``,
        ``think_bonus``, ``total``, ``max_possible`` (140 with a hyde line,
        otherwise 120), ``percentage`` (clamped to 0-100, one decimal),
        ``rating``, ``deductions`` and, for non-rejected outputs,
        ``entities_detected``.
    """
    text, used_thinking = clean_model_output(expansion.strip())

    if any(tok in text for tok in CHAT_TEMPLATE_TOKENS):
        return _failed_score("CHAT TEMPLATE LEAKAGE")

    parsed = parse_expansion(text)
    if parsed["invalid"]:
        return _failed_score(f"INVALID LINE: {parsed['invalid'][0][:50]}")

    deductions = []
    entities = extract_named_entities(query)

    format_score = _format_score(parsed, deductions)
    diversity_score = _diversity_score(parsed, query)
    hyde_score = _hyde_score(parsed)
    quality_score = _quality_score(parsed, query)
    entity_score = _entity_score(parsed, entities)
    # Outputs that skipped the reasoning block earn the bonus.
    think_bonus = 0 if used_thinking else 20

    total = (format_score + diversity_score + hyde_score + quality_score
             + entity_score + think_bonus)
    max_possible = 140 if parsed["hyde"] else 120
    percentage = max(0.0, min(100.0, total / max_possible * 100))

    return {
        "format": format_score,
        "diversity": diversity_score,
        "hyde": hyde_score,
        "quality": quality_score,
        "entity": max(0, entity_score),
        "think_bonus": think_bonus,
        "total": max(0, total),
        "max_possible": max_possible,
        "percentage": round(percentage, 1),
        "rating": _rating_for(percentage),
        "deductions": deductions,
        "entities_detected": list(entities) if entities else [],
    }


# =============================================================================
# Model loading and generation
# =============================================================================

def load_model(base, sft=None, grpo=None):
    """Load *base* in bfloat16 and optionally stack the SFT and GRPO adapters.

    The SFT adapter, when given, is merged into the base weights; the GRPO
    adapter, when given, is then loaded on top without merging.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode.
    """
    print(f"Loading tokenizer from {base}...")
    tokenizer = AutoTokenizer.from_pretrained(base)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print(f"Loading base model {base}...")
    model = AutoModelForCausalLM.from_pretrained(
        base,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )

    if sft:
        print(f"Loading and merging SFT adapter {sft}...")
        model = PeftModel.from_pretrained(model, sft)
        model = model.merge_and_unload()

    if grpo:
        print(f"Loading GRPO adapter {grpo}...")
        model = PeftModel.from_pretrained(model, grpo)

    model.eval()
    return model, tokenizer


def generate_expansion(model, tokenizer, query, max_new_tokens=200):
    """Sample one expansion for *query* and return only the assistant text.

    Sampling uses temperature 0.7, so repeated calls may differ. Any complete
    ``<think>...</think>`` block is removed from the returned text.
    """
    messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "\nassistant\n" in full_output:
        expansion = full_output.split("\nassistant\n")[-1].strip()
    elif "assistant\n" in full_output:
        expansion = full_output.split("assistant\n")[-1].strip()
    else:
        # `prompt` still contains special tokens that the decoded text lacks,
        # so a character offset would cut into the answer. Decode only the
        # token ids generated after the prompt instead.
        prompt_len = inputs["input_ids"].shape[1]
        expansion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    if "<think>" in expansion:
        expansion = _THINK_BLOCK_RE.sub('', expansion).strip()

    return expansion


# =============================================================================
# Main
# =============================================================================

def results_to_csv(results, label, include_header=True):
    """Convert eval results to CSV string.

    Args:
        results: List of dicts with ``query``, ``expansion`` and ``scores``.
        label: Value written to the ``model`` column of every row.
        include_header: Emit the column-name row first. Pass False when
            appending to CSV text that already has a header.
    """
    buf = io.StringIO()
    writer = csv.writer(buf)
    if include_header:
        writer.writerow([
            "model", "query", "expansion", "score_pct", "rating",
            "format", "diversity", "hyde", "quality", "entity",
            "think_bonus", "total", "max_possible", "deductions",
        ])
    for r in results:
        s = r["scores"]
        writer.writerow([
            label, r["query"], r["expansion"],
            s["percentage"], s["rating"],
            s["format"], s["diversity"], s["hyde"], s["quality"], s["entity"],
            s["think_bonus"], s["total"], s["max_possible"],
            "; ".join(s.get("deductions", [])),
        ])
    return buf.getvalue()


def upload_csv(results, label, repo_id, api):
    """Upload eval results CSV to HuggingFace Hub."""
    csv_data = results_to_csv(results, label)
    tag = label.split("/")[-1].replace(" ", "_").lower()
    filename = f"eval_{tag}.csv"
    print(f" Uploading {filename} to {repo_id}...")
    api.upload_file(
        path_or_fileobj=csv_data.encode("utf-8"),
        path_in_repo=filename,
        repo_id=repo_id,
        repo_type="model",
    )
    print(f" Uploaded: https://huggingface.co/{repo_id}/blob/main/{filename}")


def evaluate_model(model, tokenizer, label):
    """Generate and score an expansion for every query, printing a report.

    Returns:
        ``(results, avg)``: the per-query result dicts (``query``,
        ``expansion``, ``scores``) and the mean percentage across them.
    """
    print(f"\n{'='*70}")
    print(f" EVALUATING: {label}")
    print(f"{'='*70}")

    results = []
    for i, query in enumerate(QUERIES, 1):
        expansion = generate_expansion(model, tokenizer, query)
        scores = score_expansion_detailed(query, expansion)
        results.append({"query": query, "expansion": expansion, "scores": scores})

        pct = scores["percentage"]
        marker = "+" if pct >= 80 else "-" if pct < 60 else "~"
        print(f" [{marker}] {i:2d}/{len(QUERIES)} {scores['percentage']:5.1f}% {scores['rating']:10s} {query}")

    avg = sum(r["scores"]["percentage"] for r in results) / len(results)
    ratings = Counter(r["scores"]["rating"] for r in results)

    print(f"\n {'─'*50}")
    print(f" Average score: {avg:.1f}%")
    print(f" Ratings:")
    for rating in ["Excellent", "Good", "Acceptable", "Poor", "Failed"]:
        count = ratings.get(rating, 0)
        if count > 0:
            print(f" {rating:10s}: {count:2d} {'█' * count}")

    # Show worst queries
    worst = sorted(results, key=lambda r: r["scores"]["percentage"])[:5]
    print(f"\n Bottom 5:")
    for r in worst:
        print(f" {r['scores']['percentage']:5.1f}% {r['query']}")
        if r["scores"]["deductions"]:
            print(f" {', '.join(r['scores']['deductions'][:3])}")

    return results, avg


def main():
    """Evaluate the SFT model (and GRPO unless --sft-only) and upload CSVs."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--sft-only", action="store_true", help="Only evaluate SFT model")
    parser.add_argument("--upload-repo", default="tobil/qmd-query-expansion-evals",
                        help="HF repo to upload CSV results")
    args = parser.parse_args()

    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    api = HfApi()
    api.create_repo(repo_id=args.upload_repo, repo_type="model", exist_ok=True)

    # Evaluate SFT
    model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL)
    sft_results, sft_avg = evaluate_model(model, tokenizer, f"SFT: {SFT_MODEL}")
    upload_csv(sft_results, "sft", args.upload_repo, api)

    if not args.sft_only:
        # For GRPO: reload base, merge SFT, then load GRPO adapter
        del model
        torch.cuda.empty_cache()

        model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)
        grpo_results, grpo_avg = evaluate_model(model, tokenizer, f"GRPO: {GRPO_MODEL}")
        upload_csv(grpo_results, "grpo", args.upload_repo, api)

        # Upload combined comparison CSV (one header, SFT rows then GRPO rows)
        combined = (results_to_csv(sft_results, "sft")
                    + results_to_csv(grpo_results, "grpo", include_header=False))
        api.upload_file(
            path_or_fileobj=combined.encode("utf-8"),
            path_in_repo="eval_comparison.csv",
            repo_id=args.upload_repo,
            repo_type="model",
        )
        print(f" Uploaded: eval_comparison.csv")

        # Comparison
        print(f"\n{'='*70}")
        print(f" COMPARISON")
        print(f"{'='*70}")
        print(f" SFT average: {sft_avg:.1f}%")
        print(f" GRPO average: {grpo_avg:.1f}%")
        print(f" Delta: {grpo_avg - sft_avg:+.1f}%")

        pairs = list(zip(sft_results, grpo_results))
        improved = sum(1 for s, g in pairs
                       if g["scores"]["percentage"] > s["scores"]["percentage"])
        regressed = sum(1 for s, g in pairs
                        if g["scores"]["percentage"] < s["scores"]["percentage"])
        print(f" Improved: {improved}/{len(QUERIES)}, Regressed: {regressed}/{len(QUERIES)}")


if __name__ == "__main__":
    main()