| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446 |
- #!/usr/bin/env python3
- # /// script
- # requires-python = ">=3.10"
- # dependencies = [
- # "transformers>=4.36.0",
- # "peft>=0.7.0",
- # "torch>=2.0.0",
- # "accelerate>=0.24.0",
- # "huggingface_hub>=0.20.0",
- # "sentencepiece>=0.1.99",
- # "protobuf>=3.20.0",
- # "numpy",
- # "optimum[onnxruntime]",
- # "onnx>=1.15.0",
- # "onnxruntime>=1.17.0",
- # ]
- # ///
- """
- Convert QMD query expansion model to ONNX format for Transformers.js.
- Loads the base model, merges SFT and GRPO adapters, then exports to ONNX
- with quantization for browser deployment via Transformers.js + WebGPU.
- Usage:
- uv run convert_onnx.py --size 1.7B
- uv run convert_onnx.py --size 1.7B --no-upload
- uv run convert_onnx.py --base Qwen/Qwen3-1.7B \
- --sft tobil/qmd-query-expansion-1.7B-sft \
- --grpo tobil/qmd-query-expansion-1.7B-grpo \
- --output tobil/qmd-query-expansion-1.7B-ONNX
- Quantization options:
- --quantize q4 MatMulNBits 4-bit (default, smallest)
- --quantize q8 8-bit dynamic quantization
- --quantize fp16 FP16 (requires GPU export)
- --quantize none No quantization (FP32, ~7GB)
- """
- import argparse
- import json
- import os
- import shutil
- import sys
- from pathlib import Path
- import torch
- from huggingface_hub import HfApi, login
- from peft import PeftModel
- from transformers import AutoModelForCausalLM, AutoTokenizer
# Preset configurations keyed by the value accepted for --size. Each preset
# names the base checkpoint, the SFT and GRPO adapter repos, and the Hub repo
# that receives the converted ONNX model. Both sizes share one naming scheme,
# so the table is generated from the size label.
PRESETS = {
    size: {
        "base": f"Qwen/Qwen3-{size}",
        "sft": f"tobil/qmd-query-expansion-{size}-sft",
        "grpo": f"tobil/qmd-query-expansion-{size}-grpo",
        "output": f"tobil/qmd-query-expansion-{size}-ONNX",
    }
    for size in ("1.7B", "4B")
}
def merge_adapters(base_model: str, sft_model: str, grpo_model: str) -> tuple:
    """Build the merged model from a base checkpoint and two PEFT adapters.

    The base model is loaded in FP32. The SFT adapter and then the GRPO
    adapter are attached in that order, and each is folded into the weights
    with ``merge_and_unload()`` before the next one is loaded.

    Args:
        base_model: Model id or path used for the base weights and tokenizer.
        sft_model: Adapter id or path merged first.
        grpo_model: Adapter id or path merged second.

    Returns:
        A ``(model, tokenizer)`` tuple; the tokenizer comes from ``base_model``.
    """
    print(f"\nStep 1: Loading base model {base_model}...")
    merged = AutoModelForCausalLM.from_pretrained(
        base_model,
        dtype=torch.float32,
        trust_remote_code=True,
    )

    # Steps 2 and 3: apply each adapter on top of the result of the previous
    # merge, so the order of this tuple is significant.
    adapters = (("SFT", sft_model), ("GRPO", grpo_model))
    for step, (label, adapter_id) in enumerate(adapters, start=2):
        print(f"Step {step}: Merging {label} adapter {adapter_id}...")
        merged = PeftModel.from_pretrained(merged, adapter_id).merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    return merged, tokenizer
def export_onnx(model, tokenizer, output_dir: str):
    """Export the merged model to ONNX using Optimum.

    The model and tokenizer are first saved in Hugging Face format to a
    private temporary directory, which ``main_export`` then reads. The
    temporary directory is always removed, including when saving or
    exporting raises.

    Args:
        model: Merged model object exposing ``save_pretrained``.
        tokenizer: Tokenizer object exposing ``save_pretrained``.
        output_dir: Directory that receives the exported ONNX files.
    """
    import tempfile

    from optimum.exporters.onnx import main_export

    # A unique directory per run avoids collisions between concurrent
    # conversions and avoids reusing stale files from an aborted run.
    merged_dir = tempfile.mkdtemp(prefix="merged_model_onnx_")
    try:
        # Save merged model to temp dir first (Optimum needs HF format on disk)
        print(f"\nStep 4: Saving merged model to {merged_dir}...")
        model.save_pretrained(merged_dir, safe_serialization=True)
        tokenizer.save_pretrained(merged_dir)

        print(f"\nStep 5: Exporting to ONNX at {output_dir}...")
        # no_post_process=True avoids the 2GB protobuf serialization limit
        # that occurs during tied-weight deduplication on large FP32 models.
        # The exported model still works correctly — the tied weights just
        # aren't deduplicated in the graph, which is fine since we quantize next.
        main_export(
            model_name_or_path=merged_dir,
            output=output_dir,
            task="text-generation-with-past",
            device="cpu",
            fp16=False,
            no_post_process=True,
        )
    finally:
        # The merged checkpoint is large; never leave it behind on failure.
        shutil.rmtree(merged_dir, ignore_errors=True)
def _find_onnx_model(onnx_dir: str) -> Path:
    """Locate the main ONNX model file in ``onnx_dir``.

    ``model.onnx`` is preferred when present. Otherwise the alphabetically
    first ``*.onnx`` file is returned, so the choice does not depend on
    filesystem listing order.

    Args:
        onnx_dir: Directory to search (not recursive).

    Returns:
        Path to the selected ``.onnx`` file.

    Raises:
        FileNotFoundError: If the directory contains no ``.onnx`` file.
    """
    root = Path(onnx_dir)
    preferred = root / "model.onnx"
    if preferred.exists():
        return preferred

    # glob() yields entries in arbitrary order; sort for a stable result.
    candidates = sorted(root.glob("*.onnx"))
    if not candidates:
        raise FileNotFoundError(f"No .onnx files found in {onnx_dir}")
    return candidates[0]
def quantize_onnx(onnx_dir: str, quantize_type: str):
    """Quantize the exported ONNX model in place.

    Output is produced in a ``quantized_<type>`` subdirectory of ``onnx_dir``
    and then moved up into ``onnx_dir``, overwriting files of the same name.

    Args:
        onnx_dir: Directory containing the exported ONNX model.
        quantize_type: One of ``"q4"``, ``"q8"``, ``"fp16"`` or ``"none"``.
            ``"none"`` returns immediately without touching the directory.

    Raises:
        FileNotFoundError: If no ``.onnx`` file exists in ``onnx_dir``.
    """
    # Imported here because subprocess is not among this file's
    # module-level imports; without it the q4/q8 paths raise NameError.
    import subprocess

    if quantize_type == "none":
        print("\nSkipping quantization (FP32).")
        return

    model_path = _find_onnx_model(onnx_dir)
    print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")

    q_dir = Path(onnx_dir) / f"quantized_{quantize_type}"
    q_dir.mkdir(exist_ok=True)

    # Use optimum-cli which handles external data format and memory properly.
    # Direct MatMulNBitsQuantizer can OOM on large FP32 models.
    # NOTE(review): q4 and q8 run this identical command; confirm that it
    # actually yields 4-bit weights in the q4 case.
    optimum_cmd = [
        sys.executable, "-m", "optimum.commands.optimum_cli",
        "onnxruntime", "quantize",
        "--onnx_model", onnx_dir,
        "--o", str(q_dir),
        "--avx2",
    ]

    if quantize_type == "q4":
        print(" Running: optimum-cli onnxruntime quantize ...")
        result = subprocess.run(optimum_cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(" optimum-cli failed, falling back to MatMulNBitsQuantizer...")
            _quantize_q4_direct(model_path, q_dir)
    elif quantize_type == "q8":
        result = subprocess.run(optimum_cmd, capture_output=True, text=True)
        if result.returncode != 0:
            # Best effort: report the failure and continue with whatever
            # (possibly nothing) was written to q_dir.
            print(f" optimum-cli quantize failed: {result.stderr[:300]}")
    elif quantize_type == "fp16":
        _convert_fp16_external(model_path, q_dir)

    # Report sizes (files of 1 MB or less are not listed)
    for produced in sorted(q_dir.glob("*.onnx*")):
        size_mb = produced.stat().st_size / (1024 * 1024)
        if size_mb > 1:
            print(f" {produced.name}: {size_mb:.1f} MB")

    # Move quantized files back into main dir, replacing originals
    for produced in q_dir.iterdir():
        dest = Path(onnx_dir) / produced.name
        if dest.exists():
            dest.unlink()
        shutil.move(str(produced), str(dest))
    q_dir.rmdir()
def _quantize_q4_direct(model_path: Path, output_dir: Path):
    """Direct Q4 quantization fallback (may need significant RAM).

    Runs ONNX Runtime's ``MatMulNBitsQuantizer`` (4-bit, symmetric, block
    size 32) on ``model_path`` and writes the result to ``output_dir`` under
    the same file name.

    Args:
        model_path: Path of the ONNX model to quantize.
        output_dir: Existing directory that receives the quantized model.
    """
    from onnxruntime.quantization import matmul_nbits_quantizer

    quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
        model=str(model_path),
        block_size=32,
        is_symmetric=True,
        bits=4,
    )
    quant.process()

    q_path = output_dir / model_path.name
    # The quantizer's model wrapper is saved with save_model_to_file();
    # external data keeps the serialized graph under the 2GB protobuf limit.
    quant.model.save_model_to_file(str(q_path), use_external_data_format=True)
def _convert_fp16_external(model_path: Path, output_dir: Path):
    """Convert an ONNX model to FP16 and save it with external data.

    Uses ``onnxconverter_common.float16`` with ``keep_io_types=True``. If
    ``onnx`` or ``onnxconverter_common`` cannot be imported, a message is
    printed and nothing is written.

    Args:
        model_path: Path of the ONNX model to convert.
        output_dir: Existing directory that receives the converted model,
            stored under the same file name as ``model_path``.
    """
    print(" Converting to FP16 (external data format)...")
    # Only the imports are guarded, so an ImportError raised for any other
    # reason during conversion is not mistaken for a missing dependency.
    try:
        import onnx
        from onnxconverter_common import float16
    except ImportError:
        print(" onnxconverter-common not available; skipping FP16 conversion.")
        return

    model = onnx.load(str(model_path), load_external_data=True)
    model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)

    fp16_path = output_dir / model_path.name
    # Weights go to a sidecar file next to the graph; a single protobuf file
    # cannot exceed 2GB.
    onnx.save(
        model_fp16,
        str(fp16_path),
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=f"{fp16_path.name}_data",
    )
def validate_onnx(onnx_dir: str, base_model: str):
    """Run one forward pass through the ONNX model and sanity-check the logits.

    The tokenizer is loaded from ``onnx_dir``, a fixed test prompt is pushed
    through the chat template, and the model is executed on the CPU provider
    with zero-filled ``past_key_values`` inputs.

    Args:
        onnx_dir: Directory holding the ONNX model and tokenizer files.
        base_model: Unused; kept so existing callers keep working.

    Raises:
        FileNotFoundError: If no ``.onnx`` file exists in ``onnx_dir``.
        AssertionError: If the logits have an unexpected batch or sequence
            size, contain NaN/Inf, or are constant.
    """
    import numpy as np
    import onnxruntime as ort

    model_path = _find_onnx_model(onnx_dir)
    print(f"\nValidation: loading {model_path.name}...")
    tokenizer = AutoTokenizer.from_pretrained(onnx_dir, trust_remote_code=True)
    session = ort.InferenceSession(
        str(model_path),
        providers=["CPUExecutionProvider"],
    )

    # Tokenize a test prompt
    test_query = "/no_think Expand this search query: distributed consensus"
    chat_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": test_query}],
        add_generation_prompt=True,
        tokenize=False,
    )
    inputs = tokenizer(chat_prompt, return_tensors="np")
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Build feed dict with all required inputs
    seq_len = input_ids.shape[1]
    feed = {"input_ids": input_ids, "attention_mask": attention_mask}

    # Add position_ids if needed
    session_inputs = {inp.name: inp for inp in session.get_inputs()}
    if "position_ids" in session_inputs:
        feed["position_ids"] = np.arange(seq_len, dtype=np.int64).reshape(1, -1)

    # Initialize past_key_values to zeros if the model expects them.
    # Non-integer (symbolic) dims become 0, except the leading dim, which is
    # set to 1 to act as the batch size.
    for name, inp in sorted(session_inputs.items()):
        if not name.startswith("past_key_values"):
            continue
        shape = [dim if isinstance(dim, int) else 0 for dim in inp.shape]
        if shape and shape[0] == 0:
            shape[0] = 1
        feed[name] = np.zeros(shape, dtype=np.float32)

    # Run inference
    output_names = [out.name for out in session.get_outputs()]
    results = session.run(output_names, feed)

    # Check logits shape
    logits = results[0]
    print(f" Input tokens: {seq_len}")
    print(f" Output logits shape: {logits.shape}")
    print(f" Logits range: [{logits.min():.2f}, {logits.max():.2f}]")

    # Greedy decode next token
    next_token_id = int(np.argmax(logits[0, -1, :]))
    next_token = tokenizer.decode([next_token_id])
    print(f" Next token: {repr(next_token)} (id={next_token_id})")

    # Check KV cache outputs exist
    kv_outputs = [n for n in output_names if n.startswith("present")]
    if kv_outputs:
        print(f" KV cache outputs: {len(kv_outputs)} tensors (generation-ready)")
    else:
        print(" WARNING: No KV cache outputs — model may not support efficient generation")

    # Sanity checks. Raised explicitly (not via `assert`) so they still run
    # under `python -O`. NaN/Inf are tested before the constant-logits check
    # because max() > min() is False for NaN data and would otherwise report
    # a misleading "constant" failure.
    checks = (
        (logits.shape[0] == 1, "Batch size mismatch"),
        (logits.shape[1] == seq_len, "Sequence length mismatch"),
        (not np.isnan(logits).any(), "Logits contain NaN"),
        (not np.isinf(logits).any(), "Logits contain Inf"),
        (logits.max() > logits.min(), "Logits are constant (broken model)"),
    )
    for passed, message in checks:
        if not passed:
            raise AssertionError(message)
    print(" Validation PASSED")
def write_transformers_js_config(onnx_dir: str, quantized: bool = True):
    """Write ``transformers_js_config.json`` into ``onnx_dir``.

    Args:
        onnx_dir: Existing directory that receives the config file.
        quantized: Value stored under the ``"quantized"`` key. Defaults to
            ``True``; pass ``False`` when the model was exported without
            quantization.
    """
    config_path = Path(onnx_dir) / "transformers_js_config.json"
    config = {
        "model_type": "text-generation",
        "quantized": quantized,
    }
    # Explicit encoding so the file does not depend on the platform locale.
    config_path.write_text(json.dumps(config, indent=2) + "\n", encoding="utf-8")
    print(f" Wrote {config_path.name}")
def upload_to_hub(
    onnx_dir: str,
    output_repo: str,
    base_model: str,
    sft_model: str,
    grpo_model: str,
):
    """Publish the ONNX directory and a generated README to the Hub.

    The repo is created if it does not exist, the contents of ``onnx_dir``
    are uploaded, and a model card is then uploaded as ``README.md``.

    Args:
        onnx_dir: Local directory whose contents are uploaded.
        output_repo: Target repo id; its last ``/`` segment is the card title.
        base_model: Base model id recorded in the card.
        sft_model: SFT adapter id recorded in the card.
        grpo_model: GRPO adapter id recorded in the card.
    """
    print(f"\nStep 7: Uploading to {output_repo}...")
    hub = HfApi()
    hub.create_repo(repo_id=output_repo, repo_type="model", exist_ok=True)
    hub.upload_folder(
        folder_path=onnx_dir,
        repo_id=output_repo,
        commit_message="Upload ONNX model",
    )

    # Card title is the repo name without the owner prefix.
    repo_name = output_repo.split("/")[-1]
    model_card = f"""---
base_model: {base_model}
tags: [onnx, transformers.js, webgpu, query-expansion, qmd]
library_name: transformers.js
---
# {repo_name}
ONNX conversion of the QMD Query Expansion model for use with
[Transformers.js](https://huggingface.co/docs/transformers.js) and WebGPU.
## Details
- **Base:** {base_model}
- **SFT:** {sft_model}
- **GRPO:** {grpo_model}
- **Task:** Query expansion (lex/vec/hyde format)
- **Format:** ONNX with Q4 quantization
## Usage with Transformers.js
```javascript
import {{ AutoTokenizer, AutoModelForCausalLM }} from "@huggingface/transformers";
const tokenizer = await AutoTokenizer.from_pretrained("{output_repo}");
const model = await AutoModelForCausalLM.from_pretrained("{output_repo}", {{
dtype: "q4",
device: "webgpu",
}});
```
## Prompt Format
```
<|im_start|>user
/no_think Expand this search query: your query here<|im_end|>
<|im_start|>assistant
```
"""
    hub.upload_file(
        path_or_fileobj=model_card.encode(),
        path_in_repo="README.md",
        repo_id=output_repo,
    )
def main():
    """Command-line entry point: merge, export, quantize, validate, upload.

    With ``--validate-only DIR`` only the validation step runs. Otherwise
    the model configuration comes from ``--size`` (individual flags override
    preset values) or from all of ``--base/--sft/--grpo/--output``.
    """
    parser = argparse.ArgumentParser(description="Convert QMD model to ONNX")
    parser.add_argument(
        "--size", choices=PRESETS.keys(), help="Use preset config for model size",
    )
    for flag, help_text in (
        ("--base", "Base model (overrides preset)"),
        ("--sft", "SFT adapter (overrides preset)"),
        ("--grpo", "GRPO adapter (overrides preset)"),
        ("--output", "Output HF repo (overrides preset)"),
    ):
        parser.add_argument(flag, help=help_text)
    parser.add_argument(
        "--quantize",
        choices=["q4", "q8", "fp16", "none"],
        default="q4",
        help="Quantization type (default: q4)",
    )
    parser.add_argument(
        "--no-upload", action="store_true", help="Don't upload to HF Hub",
    )
    parser.add_argument(
        "--validate", action="store_true",
        help="Run inference validation on exported model",
    )
    parser.add_argument(
        "--validate-only", metavar="DIR",
        help="Skip export, only validate an existing ONNX dir",
    )
    args = parser.parse_args()

    # Validate-only mode: skip export, just run validation
    if args.validate_only:
        validate_onnx(args.validate_only, "")
        return

    # Resolve config: an explicit flag wins, then the preset (if any).
    # Without --size every one of the four flags must be supplied.
    preset = PRESETS[args.size] if args.size else {}
    resolved = {
        key: getattr(args, key) or preset.get(key)
        for key in ("base", "sft", "grpo", "output")
    }
    if not all(resolved.values()):
        parser.error(
            "Either --size or all of --base/--sft/--grpo/--output are required",
        )
    base_model = resolved["base"]
    sft_model = resolved["sft"]
    grpo_model = resolved["grpo"]
    output_repo = resolved["output"]

    model_name = output_repo.split("/")[-1]
    print(f"QMD ONNX Conversion: {model_name}")
    print("=" * 60)

    # Login only when a token is provided through the environment
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        print("Logging in to HuggingFace...")
        login(token=hf_token)

    # Merge adapters
    model, tokenizer = merge_adapters(base_model, sft_model, grpo_model)

    # Export to ONNX
    onnx_dir = f"/tmp/onnx_output/{model_name}"
    os.makedirs(onnx_dir, exist_ok=True)
    export_onnx(model, tokenizer, onnx_dir)

    # Quantize, then write the Transformers.js config
    quantize_onnx(onnx_dir, args.quantize)
    write_transformers_js_config(onnx_dir)

    # Validate
    if args.validate:
        validate_onnx(onnx_dir, base_model)

    # Upload
    should_upload = not args.no_upload
    if should_upload:
        upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model)

    print(f"\nDone! ONNX files at: {onnx_dir}")
    if should_upload:
        print(f"Repository: https://huggingface.co/{output_repo}")
# Run the conversion pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|