run.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # /// script
  2. # requires-python = ">=3.10"
  3. # dependencies = [
  4. # "transformers>=4.45.0",
  5. # "peft>=0.7.0",
  6. # "torch",
  7. # "huggingface_hub",
  8. # "accelerate",
  9. # ]
  10. # ///
"""
Generate query expansions from a model and save them to a JSONL file.

Usage:
    uv run evals/run.py --model tobil/qmd-query-expansion-0.6B-v4
    uv run evals/run.py --model ./local-model --queries evals/queries.txt
    uv run evals/run.py --model tobil/qmd-query-expansion-0.6B-v4 --output results.jsonl
"""
  18. import argparse
  19. import json
  20. import sys
  21. from datetime import datetime
  22. from pathlib import Path
  23. import torch
  24. from peft import PeftModel
  25. from transformers import AutoModelForCausalLM, AutoTokenizer
  26. def load_queries(path: str) -> list[str]:
  27. """Load queries from file, one per line, ignoring comments."""
  28. queries = []
  29. with open(path) as f:
  30. for line in f:
  31. line = line.strip()
  32. if line and not line.startswith("#"):
  33. queries.append(line)
  34. return queries
  35. def load_model(model_path: str, base_model: str = None):
  36. """
  37. Load the model (supports Hub URLs, local adapters, and merged models).
  38. Args:
  39. model_path: HF Hub ID (tobil/model), local adapter dir, or merged model dir
  40. base_model: Base model for tokenizer (auto-detected if not provided)
  41. """
  42. model_path = Path(model_path) if not model_path.startswith(("http", "hf://")) and "/" not in model_path[:20] else model_path
  43. is_local = isinstance(model_path, Path) and model_path.exists()
  44. # Check if it's an adapter or full model
  45. is_adapter = False
  46. if is_local:
  47. adapter_config = model_path / "adapter_config.json"
  48. is_adapter = adapter_config.exists()
  49. if is_adapter and not base_model:
  50. # Read base model from adapter config
  51. import json
  52. with open(adapter_config) as f:
  53. config = json.load(f)
  54. base_model = config.get("base_model_name_or_path", "Qwen/Qwen3-0.6B")
  55. else:
  56. # For Hub models, assume adapter unless --base-model not provided
  57. is_adapter = base_model is not None
  58. # Default base model
  59. if not base_model:
  60. base_model = "Qwen/Qwen3-0.6B"
  61. print(f"Loading tokenizer from {base_model}...", file=sys.stderr)
  62. tokenizer = AutoTokenizer.from_pretrained(base_model)
  63. if tokenizer.pad_token is None:
  64. tokenizer.pad_token = tokenizer.eos_token
  65. if is_adapter:
  66. print(f"Loading base model {base_model}...", file=sys.stderr)
  67. base = AutoModelForCausalLM.from_pretrained(
  68. base_model,
  69. torch_dtype=torch.bfloat16,
  70. device_map="auto",
  71. )
  72. print(f"Loading adapter from {model_path}...", file=sys.stderr)
  73. model = PeftModel.from_pretrained(base, str(model_path))
  74. else:
  75. print(f"Loading model from {model_path}...", file=sys.stderr)
  76. model = AutoModelForCausalLM.from_pretrained(
  77. str(model_path),
  78. torch_dtype=torch.bfloat16,
  79. device_map="auto",
  80. )
  81. model.eval()
  82. return model, tokenizer
  83. def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200) -> str:
  84. """Generate query expansion using Qwen3 chat template with /no_think."""
  85. messages = [{"role": "user", "content": f"/no_think Expand this search query: {query}"}]
  86. prompt = tokenizer.apply_chat_template(
  87. messages,
  88. tokenize=False,
  89. add_generation_prompt=True
  90. )
  91. inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
  92. with torch.no_grad():
  93. outputs = model.generate(
  94. **inputs,
  95. max_new_tokens=max_new_tokens,
  96. temperature=0.7,
  97. do_sample=True,
  98. pad_token_id=tokenizer.pad_token_id,
  99. eos_token_id=tokenizer.eos_token_id,
  100. )
  101. full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  102. # Extract assistant response
  103. if "\nassistant\n" in full_output:
  104. expansion = full_output.split("\nassistant\n")[-1].strip()
  105. elif "assistant\n" in full_output:
  106. expansion = full_output.split("assistant\n")[-1].strip()
  107. else:
  108. expansion = full_output[len(prompt):].strip()
  109. # Remove any <think> tags
  110. if expansion.startswith("<think>"):
  111. think_end = expansion.find("</think>")
  112. if think_end != -1:
  113. expansion = expansion[think_end + 8:].strip()
  114. return expansion
  115. def main():
  116. parser = argparse.ArgumentParser(description="Generate query expansions")
  117. parser.add_argument("--model", required=True, help="Model path (Hub or local)")
  118. parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B", help="Base model for tokenizer")
  119. parser.add_argument("--queries", default="evals/queries.txt", help="Queries file")
  120. parser.add_argument("--output", help="Output JSONL file (default: evals/results_{model_name}.jsonl)")
  121. parser.add_argument("--max-tokens", type=int, default=200, help="Max tokens to generate")
  122. args = parser.parse_args()
  123. # Determine output path
  124. if args.output:
  125. output_path = Path(args.output)
  126. else:
  127. model_name = args.model.replace("/", "_").replace(".", "_")
  128. output_path = Path(f"evals/results_{model_name}.jsonl")
  129. # Load queries
  130. queries = load_queries(args.queries)
  131. print(f"Loaded {len(queries)} queries from {args.queries}", file=sys.stderr)
  132. # Load model
  133. model, tokenizer = load_model(args.model, args.base_model)
  134. # Generate expansions
  135. print(f"Generating expansions...", file=sys.stderr)
  136. results = []
  137. for i, query in enumerate(queries, 1):
  138. print(f"[{i}/{len(queries)}] {query}", file=sys.stderr)
  139. expansion = generate_expansion(model, tokenizer, query, args.max_tokens)
  140. results.append({
  141. "query": query,
  142. "expansion": expansion,
  143. })
  144. # Write results
  145. output_path.parent.mkdir(parents=True, exist_ok=True)
  146. with open(output_path, "w") as f:
  147. # Write metadata as first line
  148. metadata = {
  149. "_meta": True,
  150. "model": args.model,
  151. "base_model": args.base_model,
  152. "timestamp": datetime.now().isoformat(),
  153. "num_queries": len(queries),
  154. }
  155. f.write(json.dumps(metadata) + "\n")
  156. # Write results
  157. for result in results:
  158. f.write(json.dumps(result) + "\n")
  159. print(f"Results saved to: {output_path}", file=sys.stderr)
  160. print(str(output_path)) # Print path to stdout for piping
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()