eval_verbose.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "huggingface_hub>=0.20.0",
  8. # "accelerate",
  9. # ]
  10. # ///
  11. """
  12. Verbose eval: prints the actual expansions for every query.
  13. hf jobs uv run --flavor a10g-small --secrets HF_TOKEN --timeout 30m jobs/eval_verbose.py
  14. """
  15. import os
  16. import re
  17. import sys
  18. from collections import Counter
  19. import torch
  20. from huggingface_hub import login
  21. from peft import PeftModel
  22. from transformers import AutoModelForCausalLM, AutoTokenizer
  23. BASE_MODEL = "Qwen/Qwen3-1.7B"
  24. SFT_MODEL = "tobil/qmd-query-expansion-1.7B-sft"
  25. GRPO_MODEL = "tobil/qmd-query-expansion-1.7B-grpo"
  26. QUERIES = [
  27. "how to configure authentication",
  28. "typescript async await",
  29. "docker compose networking",
  30. "git rebase vs merge",
  31. "react useEffect cleanup",
  32. "auth",
  33. "config",
  34. "setup",
  35. "api",
  36. "who is TDS motorsports",
  37. "React hooks tutorial",
  38. "Docker container networking",
  39. "Kubernetes pod deployment",
  40. "AWS Lambda functions",
  41. "meeting notes project kickoff",
  42. "ideas for new feature",
  43. "todo list app architecture",
  44. "what is dependency injection",
  45. "difference between sql and nosql",
  46. "kubernetes vs docker swarm",
  47. "connection timeout error",
  48. "memory leak debugging",
  49. "cors error fix",
  50. "recent news about Shopify",
  51. "latest AI developments",
  52. "best laptops right now",
  53. "what changed in kubernetes latest version",
  54. "how to implement caching with redis in nodejs",
  55. "best practices for api rate limiting",
  56. "setting up ci cd pipeline with github actions",
  57. ]
  58. def load_model(base, sft=None, grpo=None):
  59. tokenizer = AutoTokenizer.from_pretrained(base)
  60. if tokenizer.pad_token is None:
  61. tokenizer.pad_token = tokenizer.eos_token
  62. model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16, device_map="auto")
  63. if sft:
  64. model = PeftModel.from_pretrained(model, sft)
  65. model = model.merge_and_unload()
  66. if grpo:
  67. model = PeftModel.from_pretrained(model, grpo)
  68. model.eval()
  69. return model, tokenizer
  70. def generate(model, tokenizer, query):
  71. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  72. prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  73. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  74. with torch.no_grad():
  75. out = model.generate(**inputs, max_new_tokens=200, temperature=0.7, do_sample=True,
  76. pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
  77. text = tokenizer.decode(out[0], skip_special_tokens=True)
  78. if "\nassistant\n" in text:
  79. text = text.split("\nassistant\n")[-1].strip()
  80. elif "assistant\n" in text:
  81. text = text.split("assistant\n")[-1].strip()
  82. if "<think>" in text:
  83. text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
  84. return text
  85. def main():
  86. hf_token = os.environ.get("HF_TOKEN")
  87. if hf_token:
  88. login(token=hf_token)
  89. print("Loading GRPO model...", file=sys.stderr)
  90. model, tokenizer = load_model(BASE_MODEL, sft=SFT_MODEL, grpo=GRPO_MODEL)
  91. for i, query in enumerate(QUERIES, 1):
  92. expansion = generate(model, tokenizer, query)
  93. print(f"\n{'='*60}")
  94. print(f"[{i}/{len(QUERIES)}] {query}")
  95. print(f"{'─'*60}")
  96. print(expansion)
  97. if __name__ == "__main__":
  98. main()