run.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "huggingface_hub",
  8. # "accelerate",
  9. # ]
  10. # ///
  11. """
  12. Generate query expansions from a model and save to JSONL.
  13. Usage:
  14. uv run evals/run.py --model tobil/qmd-query-expansion-0.6B-v4
  15. uv run evals/run.py --model ./local-model --queries evals/queries.txt
  16. uv run evals/run.py --model tobil/qmd-query-expansion-0.6B-v4 --output results.jsonl
  17. """
  18. import argparse
  19. import json
  20. import sys
  21. from datetime import datetime
  22. from pathlib import Path
  23. import torch
  24. from peft import PeftModel
  25. from transformers import AutoModelForCausalLM, AutoTokenizer
  26. def load_queries(path: str) -> list[str]:
  27. """Load queries from file, one per line, ignoring comments."""
  28. queries = []
  29. with open(path) as f:
  30. for line in f:
  31. line = line.strip()
  32. if line and not line.startswith("#"):
  33. queries.append(line)
  34. return queries
  35. def load_model(model_path: str, base_model: str = None, sft_model: str = None):
  36. """
  37. Load the model (supports Hub URLs, local adapters, and merged models).
  38. Args:
  39. model_path: HF Hub ID (tobil/model), local adapter dir, or merged model dir
  40. base_model: Base model for tokenizer (auto-detected if not provided)
  41. sft_model: SFT adapter to load first (for GRPO models that need SFT base)
  42. """
  43. model_path_str = str(model_path)
  44. is_local = Path(model_path_str).exists()
  45. # Check if it's an adapter or full model
  46. is_adapter = False
  47. if is_local:
  48. adapter_config_path = Path(model_path_str) / "adapter_config.json"
  49. is_adapter = adapter_config_path.exists()
  50. if is_adapter and not base_model:
  51. with open(adapter_config_path) as f:
  52. config = json.load(f)
  53. base_model = config.get("base_model_name_or_path", "Qwen/Qwen3-0.6B")
  54. else:
  55. # For Hub models, assume adapter
  56. is_adapter = True
  57. # Default base model
  58. if not base_model:
  59. base_model = "Qwen/Qwen3-0.6B"
  60. print(f"Loading tokenizer from {base_model}...", file=sys.stderr)
  61. tokenizer = AutoTokenizer.from_pretrained(base_model)
  62. if tokenizer.pad_token is None:
  63. tokenizer.pad_token = tokenizer.eos_token
  64. print(f"Loading base model {base_model}...", file=sys.stderr)
  65. model = AutoModelForCausalLM.from_pretrained(
  66. base_model,
  67. torch_dtype=torch.bfloat16,
  68. device_map="auto",
  69. )
  70. # If SFT model specified, load and merge it first (for GRPO models)
  71. if sft_model:
  72. print(f"Loading SFT adapter {sft_model} and merging...", file=sys.stderr)
  73. model = PeftModel.from_pretrained(model, sft_model)
  74. model = model.merge_and_unload()
  75. if is_adapter:
  76. print(f"Loading adapter from {model_path_str}...", file=sys.stderr)
  77. model = PeftModel.from_pretrained(model, model_path_str)
  78. model.eval()
  79. return model, tokenizer
  80. def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200) -> str:
  81. """Generate query expansion using Qwen3 chat template with /no_think."""
  82. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  83. prompt = tokenizer.apply_chat_template(
  84. messages,
  85. tokenize=False,
  86. add_generation_prompt=True
  87. )
  88. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  89. with torch.no_grad():
  90. outputs = model.generate(
  91. **inputs,
  92. max_new_tokens=max_new_tokens,
  93. temperature=0.7,
  94. do_sample=True,
  95. pad_token_id=tokenizer.pad_token_id,
  96. eos_token_id=tokenizer.eos_token_id,
  97. )
  98. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  99. # Extract assistant response
  100. if "\nassistant\n" in full_output:
  101. expansion = full_output.split("\nassistant\n")[-1].strip()
  102. elif "assistant\n" in full_output:
  103. expansion = full_output.split("assistant\n")[-1].strip()
  104. else:
  105. expansion = full_output[len(prompt):].strip()
  106. # Remove any <think> tags
  107. if expansion.startswith("<think>"):
  108. think_end = expansion.find("</think>")
  109. if think_end != -1:
  110. expansion = expansion[think_end + 8:].strip()
  111. return expansion
  112. def main():
  113. parser = argparse.ArgumentParser(description="Generate query expansions")
  114. parser.add_argument("--model", required=True, help="Model path (Hub or local)")
  115. parser.add_argument("--base-model", default=None, help="Base model for tokenizer")
  116. parser.add_argument("--sft-model", default=None, help="SFT adapter to load first (for GRPO models)")
  117. parser.add_argument("--queries", default="evals/queries.txt", help="Queries file")
  118. parser.add_argument("--output", help="Output JSONL file (default: evals/results_{model_name}.jsonl)")
  119. parser.add_argument("--max-tokens", type=int, default=200, help="Max tokens to generate")
  120. args = parser.parse_args()
  121. # Determine output path
  122. if args.output:
  123. output_path = Path(args.output)
  124. else:
  125. model_name = args.model.replace("/", "_").replace(".", "_")
  126. output_path = Path(f"evals/results_{model_name}.jsonl")
  127. # Load queries
  128. queries = load_queries(args.queries)
  129. print(f"Loaded {len(queries)} queries from {args.queries}", file=sys.stderr)
  130. # Load model
  131. model, tokenizer = load_model(args.model, args.base_model, args.sft_model)
  132. # Generate expansions
  133. print(f"Generating expansions...", file=sys.stderr)
  134. results = []
  135. for i, query in enumerate(queries, 1):
  136. print(f"[{i}/{len(queries)}] {query}", file=sys.stderr)
  137. expansion = generate_expansion(model, tokenizer, query, args.max_tokens)
  138. results.append({
  139. "query": query,
  140. "expansion": expansion,
  141. })
  142. # Write results
  143. output_path.parent.mkdir(parents=True, exist_ok=True)
  144. with open(output_path, "w") as f:
  145. # Write metadata as first line
  146. metadata = {
  147. "_meta": True,
  148. "model": args.model,
  149. "base_model": args.base_model,
  150. "timestamp": datetime.now().isoformat(),
  151. "num_queries": len(queries),
  152. }
  153. f.write(json.dumps(metadata) + "\n")
  154. # Write results
  155. for result in results:
  156. f.write(json.dumps(result) + "\n")
  157. print(f"Results saved to: {output_path}", file=sys.stderr)
  158. print(str(output_path)) # Print path to stdout for piping
if __name__ == "__main__":
    # Run the CLI only when executed as a script (e.g. `uv run evals/run.py ...`),
    # not when this module is imported.
    main()