convert_onnx.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "transformers>=4.36.0",
  6. # "peft>=0.7.0",
  7. # "torch>=2.0.0",
  8. # "accelerate>=0.24.0",
  9. # "huggingface_hub>=0.20.0",
  10. # "sentencepiece>=0.1.99",
  11. # "protobuf>=3.20.0",
  12. # "numpy",
  13. # "optimum[onnxruntime]",
  14. # "onnx>=1.15.0",
  15. # "onnxruntime>=1.17.0",
  16. # ]
  17. # ///
  18. """
  19. Convert QMD query expansion model to ONNX format for Transformers.js.
  20. Loads the base model, merges SFT and GRPO adapters, then exports to ONNX
  21. with quantization for browser deployment via Transformers.js + WebGPU.
  22. Usage:
  23. uv run convert_onnx.py --size 1.7B
  24. uv run convert_onnx.py --size 1.7B --no-upload
  25. uv run convert_onnx.py --base Qwen/Qwen3-1.7B \
  26. --sft tobil/qmd-query-expansion-1.7B-sft \
  27. --grpo tobil/qmd-query-expansion-1.7B-grpo \
  28. --output tobil/qmd-query-expansion-1.7B-ONNX
  29. Quantization options:
  30. --quantize q4 MatMulNBits 4-bit (default, smallest)
  31. --quantize q8 8-bit dynamic quantization
  32. --quantize fp16 FP16 via onnxconverter-common (not in the dependency list; install it separately)
  33. --quantize none No quantization (FP32, ~7GB for 1.7B)
  34. """
  35. import argparse
  36. import json
  37. import os
  38. import shutil
  39. import sys
  40. from pathlib import Path
  41. import torch
  42. from huggingface_hub import HfApi, login
  43. from peft import PeftModel
  44. from transformers import AutoModelForCausalLM, AutoTokenizer
  45. PRESETS = {
  46. "1.7B": {
  47. "base": "Qwen/Qwen3-1.7B",
  48. "sft": "tobil/qmd-query-expansion-1.7B-sft",
  49. "grpo": "tobil/qmd-query-expansion-1.7B-grpo",
  50. "output": "tobil/qmd-query-expansion-1.7B-ONNX",
  51. },
  52. "4B": {
  53. "base": "Qwen/Qwen3-4B",
  54. "sft": "tobil/qmd-query-expansion-4B-sft",
  55. "grpo": "tobil/qmd-query-expansion-4B-grpo",
  56. "output": "tobil/qmd-query-expansion-4B-ONNX",
  57. },
  58. }
  59. def merge_adapters(base_model: str, sft_model: str, grpo_model: str) -> tuple:
  60. """Load base model, merge SFT + GRPO adapters, return (model, tokenizer)."""
  61. print(f"\nStep 1: Loading base model {base_model}...")
  62. model = AutoModelForCausalLM.from_pretrained(
  63. base_model, dtype=torch.float32, trust_remote_code=True,
  64. )
  65. print(f"Step 2: Merging SFT adapter {sft_model}...")
  66. model = PeftModel.from_pretrained(model, sft_model)
  67. model = model.merge_and_unload()
  68. print(f"Step 3: Merging GRPO adapter {grpo_model}...")
  69. model = PeftModel.from_pretrained(model, grpo_model)
  70. model = model.merge_and_unload()
  71. tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
  72. return model, tokenizer
  73. def export_onnx(model, tokenizer, output_dir: str):
  74. """Export merged model to ONNX using Optimum."""
  75. from optimum.exporters.onnx import main_export
  76. # Save merged model to temp dir first (Optimum needs HF format on disk)
  77. merged_dir = "/tmp/merged_model_onnx"
  78. print(f"\nStep 4: Saving merged model to {merged_dir}...")
  79. model.save_pretrained(merged_dir, safe_serialization=True)
  80. tokenizer.save_pretrained(merged_dir)
  81. print(f"\nStep 5: Exporting to ONNX at {output_dir}...")
  82. # no_post_process=True avoids the 2GB protobuf serialization limit
  83. # that occurs during tied-weight deduplication on large FP32 models.
  84. # The exported model still works correctly — the tied weights just
  85. # aren't deduplicated in the graph, which is fine since we quantize next.
  86. main_export(
  87. model_name_or_path=merged_dir,
  88. output=output_dir,
  89. task="text-generation-with-past",
  90. device="cpu",
  91. fp16=False,
  92. no_post_process=True,
  93. )
  94. # Clean up temp merged dir
  95. shutil.rmtree(merged_dir, ignore_errors=True)
  96. def _find_onnx_model(onnx_dir: str) -> Path:
  97. """Find the main ONNX model file in the output directory."""
  98. model_path = Path(onnx_dir) / "model.onnx"
  99. if model_path.exists():
  100. return model_path
  101. candidates = list(Path(onnx_dir).glob("*.onnx"))
  102. if not candidates:
  103. raise FileNotFoundError(f"No .onnx files found in {onnx_dir}")
  104. return candidates[0]
  105. def quantize_onnx(onnx_dir: str, quantize_type: str):
  106. """Quantize the exported ONNX model using optimum-cli for memory safety."""
  107. if quantize_type == "none":
  108. print("\nSkipping quantization (FP32).")
  109. return
  110. model_path = _find_onnx_model(onnx_dir)
  111. print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")
  112. q_dir = Path(onnx_dir) / f"quantized_{quantize_type}"
  113. q_dir.mkdir(exist_ok=True)
  114. if quantize_type == "q4":
  115. # Use optimum-cli which handles external data format and memory properly.
  116. # Direct MatMulNBitsQuantizer can OOM on large FP32 models.
  117. cmd = [
  118. sys.executable, "-m", "optimum.commands.optimum_cli",
  119. "onnxruntime", "quantize",
  120. "--onnx_model", onnx_dir,
  121. "--o", str(q_dir),
  122. "--avx2",
  123. ]
  124. print(f" Running: optimum-cli onnxruntime quantize ...")
  125. result = subprocess.run(cmd, capture_output=True, text=True)
  126. if result.returncode != 0:
  127. print(f" optimum-cli failed, falling back to MatMulNBitsQuantizer...")
  128. _quantize_q4_direct(model_path, q_dir)
  129. elif quantize_type == "q8":
  130. cmd = [
  131. sys.executable, "-m", "optimum.commands.optimum_cli",
  132. "onnxruntime", "quantize",
  133. "--onnx_model", onnx_dir,
  134. "--o", str(q_dir),
  135. "--avx2",
  136. ]
  137. result = subprocess.run(cmd, capture_output=True, text=True)
  138. if result.returncode != 0:
  139. print(f" optimum-cli quantize failed: {result.stderr[:300]}")
  140. elif quantize_type == "fp16":
  141. _convert_fp16_external(model_path, q_dir)
  142. # Report sizes
  143. for f in sorted(q_dir.glob("*.onnx*")):
  144. size_mb = f.stat().st_size / (1024 * 1024)
  145. if size_mb > 1:
  146. print(f" {f.name}: {size_mb:.1f} MB")
  147. # Move quantized files back into main dir, replacing originals
  148. for f in q_dir.iterdir():
  149. dest = Path(onnx_dir) / f.name
  150. if dest.exists():
  151. dest.unlink()
  152. shutil.move(str(f), str(dest))
  153. q_dir.rmdir()
  154. def _quantize_q4_direct(model_path: Path, output_dir: Path):
  155. """Direct Q4 quantization fallback (may need significant RAM)."""
  156. from onnxruntime.quantization import matmul_nbits_quantizer
  157. quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
  158. model=str(model_path),
  159. block_size=32,
  160. is_symmetric=True,
  161. bits=4,
  162. )
  163. quant.process()
  164. q_path = output_dir / model_path.name
  165. quant.model.save(str(q_path))
  166. def _convert_fp16_external(model_path: Path, output_dir: Path):
  167. """Convert ONNX model to FP16 using onnxconverter-common (handles external data)."""
  168. print(" Converting to FP16 (external data format)...")
  169. try:
  170. from onnxconverter_common import float16
  171. import onnx
  172. model = onnx.load(str(model_path), load_external_data=True)
  173. model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
  174. fp16_path = output_dir / model_path.name
  175. onnx.save(model_fp16, str(fp16_path))
  176. except ImportError:
  177. print(" onnxconverter-common not available; skipping FP16 conversion.")
  178. def validate_onnx(onnx_dir: str, base_model: str):
  179. """Run a sample inference through the ONNX model to verify it works."""
  180. import onnxruntime as ort
  181. import numpy as np
  182. model_path = _find_onnx_model(onnx_dir)
  183. print(f"\nValidation: loading {model_path.name}...")
  184. tokenizer = AutoTokenizer.from_pretrained(onnx_dir, trust_remote_code=True)
  185. session = ort.InferenceSession(
  186. str(model_path),
  187. providers=["CPUExecutionProvider"],
  188. )
  189. # Tokenize a test prompt
  190. test_query = "/no_think Expand this search query: distributed consensus"
  191. chat_prompt = tokenizer.apply_chat_template(
  192. [{"role": "user", "content": test_query}],
  193. add_generation_prompt=True,
  194. tokenize=False,
  195. )
  196. inputs = tokenizer(chat_prompt, return_tensors="np")
  197. input_ids = inputs["input_ids"].astype(np.int64)
  198. attention_mask = inputs["attention_mask"].astype(np.int64)
  199. # Build feed dict with all required inputs
  200. seq_len = input_ids.shape[1]
  201. feed = {"input_ids": input_ids, "attention_mask": attention_mask}
  202. # Add position_ids if needed
  203. all_inputs = {inp.name: inp for inp in session.get_inputs()}
  204. if "position_ids" in all_inputs:
  205. feed["position_ids"] = np.arange(seq_len, dtype=np.int64).reshape(1, -1)
  206. # Initialize past_key_values to zeros if the model expects them
  207. for name, inp in sorted(all_inputs.items()):
  208. if name.startswith("past_key_values"):
  209. shape = []
  210. for dim in inp.shape:
  211. shape.append(dim if isinstance(dim, int) else 0)
  212. # batch dim = 1
  213. if shape and shape[0] == 0:
  214. shape[0] = 1
  215. feed[name] = np.zeros(shape, dtype=np.float32)
  216. # Run inference
  217. output_names = [o.name for o in session.get_outputs()]
  218. results = session.run(output_names, feed)
  219. # Check logits shape
  220. logits = results[0]
  221. print(f" Input tokens: {input_ids.shape[1]}")
  222. print(f" Output logits shape: {logits.shape}")
  223. print(f" Logits range: [{logits.min():.2f}, {logits.max():.2f}]")
  224. # Greedy decode next token
  225. next_token_id = int(np.argmax(logits[0, -1, :]))
  226. next_token = tokenizer.decode([next_token_id])
  227. print(f" Next token: {repr(next_token)} (id={next_token_id})")
  228. # Check KV cache outputs exist
  229. kv_outputs = [n for n in output_names if n.startswith("present")]
  230. if kv_outputs:
  231. print(f" KV cache outputs: {len(kv_outputs)} tensors (generation-ready)")
  232. else:
  233. print(" WARNING: No KV cache outputs — model may not support efficient generation")
  234. # Sanity checks
  235. assert logits.shape[0] == 1, "Batch size mismatch"
  236. assert logits.shape[1] == input_ids.shape[1], "Sequence length mismatch"
  237. assert logits.max() > logits.min(), "Logits are constant (broken model)"
  238. assert not np.isnan(logits).any(), "Logits contain NaN"
  239. assert not np.isinf(logits).any(), "Logits contain Inf"
  240. print(" Validation PASSED")
  241. def write_transformers_js_config(onnx_dir: str):
  242. """Write Transformers.js compatibility config."""
  243. config_path = Path(onnx_dir) / "transformers_js_config.json"
  244. config = {
  245. "model_type": "text-generation",
  246. "quantized": True,
  247. }
  248. config_path.write_text(json.dumps(config, indent=2) + "\n")
  249. print(f" Wrote {config_path.name}")
  250. def upload_to_hub(
  251. onnx_dir: str,
  252. output_repo: str,
  253. base_model: str,
  254. sft_model: str,
  255. grpo_model: str,
  256. ):
  257. """Upload ONNX model to HuggingFace Hub."""
  258. print(f"\nStep 7: Uploading to {output_repo}...")
  259. api = HfApi()
  260. api.create_repo(repo_id=output_repo, repo_type="model", exist_ok=True)
  261. api.upload_folder(
  262. folder_path=onnx_dir,
  263. repo_id=output_repo,
  264. commit_message="Upload ONNX model",
  265. )
  266. readme = f"""---
  267. base_model: {base_model}
  268. tags: [onnx, transformers.js, webgpu, query-expansion, qmd]
  269. library_name: transformers.js
  270. ---
  271. # {output_repo.split("/")[-1]}
  272. ONNX conversion of the QMD Query Expansion model for use with
  273. [Transformers.js](https://huggingface.co/docs/transformers.js) and WebGPU.
  274. ## Details
  275. - **Base:** {base_model}
  276. - **SFT:** {sft_model}
  277. - **GRPO:** {grpo_model}
  278. - **Task:** Query expansion (lex/vec/hyde format)
  279. - **Format:** ONNX with Q4 quantization
  280. ## Usage with Transformers.js
  281. ```javascript
  282. import {{ AutoTokenizer, AutoModelForCausalLM }} from "@huggingface/transformers";
  283. const tokenizer = await AutoTokenizer.from_pretrained("{output_repo}");
  284. const model = await AutoModelForCausalLM.from_pretrained("{output_repo}", {{
  285. dtype: "q4",
  286. device: "webgpu",
  287. }});
  288. ```
  289. ## Prompt Format
  290. ```
  291. <|im_start|>user
  292. /no_think Expand this search query: your query here<|im_end|>
  293. <|im_start|>assistant
  294. ```
  295. """
  296. api.upload_file(
  297. path_or_fileobj=readme.encode(),
  298. path_in_repo="README.md",
  299. repo_id=output_repo,
  300. )
  301. def main():
  302. parser = argparse.ArgumentParser(description="Convert QMD model to ONNX")
  303. parser.add_argument(
  304. "--size", choices=PRESETS.keys(), help="Use preset config for model size",
  305. )
  306. parser.add_argument("--base", help="Base model (overrides preset)")
  307. parser.add_argument("--sft", help="SFT adapter (overrides preset)")
  308. parser.add_argument("--grpo", help="GRPO adapter (overrides preset)")
  309. parser.add_argument("--output", help="Output HF repo (overrides preset)")
  310. parser.add_argument(
  311. "--quantize",
  312. choices=["q4", "q8", "fp16", "none"],
  313. default="q4",
  314. help="Quantization type (default: q4)",
  315. )
  316. parser.add_argument(
  317. "--no-upload", action="store_true", help="Don't upload to HF Hub",
  318. )
  319. parser.add_argument(
  320. "--validate", action="store_true",
  321. help="Run inference validation on exported model",
  322. )
  323. parser.add_argument(
  324. "--validate-only", metavar="DIR",
  325. help="Skip export, only validate an existing ONNX dir",
  326. )
  327. args = parser.parse_args()
  328. # Validate-only mode: skip export, just run validation
  329. if args.validate_only:
  330. validate_onnx(args.validate_only, "")
  331. return
  332. # Resolve config
  333. if args.size:
  334. preset = PRESETS[args.size]
  335. base_model = args.base or preset["base"]
  336. sft_model = args.sft or preset["sft"]
  337. grpo_model = args.grpo or preset["grpo"]
  338. output_repo = args.output or preset["output"]
  339. elif args.base and args.sft and args.grpo and args.output:
  340. base_model = args.base
  341. sft_model = args.sft
  342. grpo_model = args.grpo
  343. output_repo = args.output
  344. else:
  345. parser.error(
  346. "Either --size or all of --base/--sft/--grpo/--output are required",
  347. )
  348. model_name = output_repo.split("/")[-1]
  349. print(f"QMD ONNX Conversion: {model_name}")
  350. print("=" * 60)
  351. # Login
  352. hf_token = os.environ.get("HF_TOKEN")
  353. if hf_token:
  354. print("Logging in to HuggingFace...")
  355. login(token=hf_token)
  356. # Merge adapters
  357. model, tokenizer = merge_adapters(base_model, sft_model, grpo_model)
  358. # Export to ONNX
  359. onnx_dir = f"/tmp/onnx_output/{model_name}"
  360. os.makedirs(onnx_dir, exist_ok=True)
  361. export_onnx(model, tokenizer, onnx_dir)
  362. # Quantize
  363. quantize_onnx(onnx_dir, args.quantize)
  364. # Write Transformers.js config
  365. write_transformers_js_config(onnx_dir)
  366. # Validate
  367. if args.validate:
  368. validate_onnx(onnx_dir, base_model)
  369. # Upload
  370. if not args.no_upload:
  371. upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model)
  372. print(f"\nDone! ONNX files at: {onnx_dir}")
  373. if not args.no_upload:
  374. print(f"Repository: https://huggingface.co/{output_repo}")
  375. if __name__ == "__main__":
  376. main()