|
|
@@ -39,31 +39,30 @@ def load_queries(path: str) -> list[str]:
|
|
|
return queries
|
|
|
|
|
|
|
|
|
-def load_model(model_path: str, base_model: str = None):
|
|
|
+def load_model(model_path: str, base_model: str = None, sft_model: str = None):
|
|
|
"""
|
|
|
Load the model (supports Hub URLs, local adapters, and merged models).
|
|
|
|
|
|
Args:
|
|
|
model_path: HF Hub ID (tobil/model), local adapter dir, or merged model dir
|
|
|
base_model: Base model for tokenizer (auto-detected if not provided)
|
|
|
+ sft_model: SFT adapter to load first (for GRPO models that need SFT base)
|
|
|
"""
|
|
|
- model_path = Path(model_path) if not model_path.startswith(("http", "hf://")) and "/" not in model_path[:20] else model_path
|
|
|
- is_local = isinstance(model_path, Path) and model_path.exists()
|
|
|
+ model_path_str = str(model_path)
|
|
|
+ is_local = Path(model_path_str).exists()
|
|
|
|
|
|
# Check if it's an adapter or full model
|
|
|
is_adapter = False
|
|
|
if is_local:
|
|
|
- adapter_config = model_path / "adapter_config.json"
|
|
|
- is_adapter = adapter_config.exists()
|
|
|
+ adapter_config_path = Path(model_path_str) / "adapter_config.json"
|
|
|
+ is_adapter = adapter_config_path.exists()
|
|
|
if is_adapter and not base_model:
|
|
|
- # Read base model from adapter config
|
|
|
- import json
|
|
|
- with open(adapter_config) as f:
|
|
|
+ with open(adapter_config_path) as f:
|
|
|
config = json.load(f)
|
|
|
base_model = config.get("base_model_name_or_path", "Qwen/Qwen3-0.6B")
|
|
|
else:
|
|
|
- # For Hub models, assume adapter unless --base-model not provided
|
|
|
- is_adapter = base_model is not None
|
|
|
+ # For Hub models, assume adapter
|
|
|
+ is_adapter = True
|
|
|
|
|
|
# Default base model
|
|
|
if not base_model:
|
|
|
@@ -74,22 +73,22 @@ def load_model(model_path: str, base_model: str = None):
|
|
|
if tokenizer.pad_token is None:
|
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
|
+ print(f"Loading base model {base_model}...", file=sys.stderr)
|
|
|
+ model = AutoModelForCausalLM.from_pretrained(
|
|
|
+ base_model,
|
|
|
+ torch_dtype=torch.bfloat16,
|
|
|
+ device_map="auto",
|
|
|
+ )
|
|
|
+
|
|
|
+ # If SFT model specified, load and merge it first (for GRPO models)
|
|
|
+ if sft_model:
|
|
|
+ print(f"Loading SFT adapter {sft_model} and merging...", file=sys.stderr)
|
|
|
+ model = PeftModel.from_pretrained(model, sft_model)
|
|
|
+ model = model.merge_and_unload()
|
|
|
+
|
|
|
if is_adapter:
|
|
|
- print(f"Loading base model {base_model}...", file=sys.stderr)
|
|
|
- base = AutoModelForCausalLM.from_pretrained(
|
|
|
- base_model,
|
|
|
- torch_dtype=torch.bfloat16,
|
|
|
- device_map="auto",
|
|
|
- )
|
|
|
- print(f"Loading adapter from {model_path}...", file=sys.stderr)
|
|
|
- model = PeftModel.from_pretrained(base, str(model_path))
|
|
|
- else:
|
|
|
- print(f"Loading model from {model_path}...", file=sys.stderr)
|
|
|
- model = AutoModelForCausalLM.from_pretrained(
|
|
|
- str(model_path),
|
|
|
- torch_dtype=torch.bfloat16,
|
|
|
- device_map="auto",
|
|
|
- )
|
|
|
+ print(f"Loading adapter from {model_path_str}...", file=sys.stderr)
|
|
|
+ model = PeftModel.from_pretrained(model, model_path_str)
|
|
|
|
|
|
model.eval()
|
|
|
return model, tokenizer
|
|
|
@@ -138,7 +137,8 @@ def generate_expansion(model, tokenizer, query: str, max_new_tokens: int = 200)
|
|
|
def main():
|
|
|
parser = argparse.ArgumentParser(description="Generate query expansions")
|
|
|
parser.add_argument("--model", required=True, help="Model path (Hub or local)")
|
|
|
- parser.add_argument("--base-model", default="Qwen/Qwen3-0.6B", help="Base model for tokenizer")
|
|
|
+ parser.add_argument("--base-model", default=None, help="Base model for tokenizer")
|
|
|
+ parser.add_argument("--sft-model", default=None, help="SFT adapter to load first (for GRPO models)")
|
|
|
parser.add_argument("--queries", default="evals/queries.txt", help="Queries file")
|
|
|
parser.add_argument("--output", help="Output JSONL file (default: evals/results_{model_name}.jsonl)")
|
|
|
parser.add_argument("--max-tokens", type=int, default=200, help="Max tokens to generate")
|
|
|
@@ -156,7 +156,7 @@ def main():
|
|
|
print(f"Loaded {len(queries)} queries from {args.queries}", file=sys.stderr)
|
|
|
|
|
|
# Load model
|
|
|
- model, tokenizer = load_model(args.model, args.base_model)
|
|
|
+ model, tokenizer = load_model(args.model, args.base_model, args.sft_model)
|
|
|
|
|
|
# Generate expansions
|
|
|
print(f"Generating expansions...", file=sys.stderr)
|