convert_onnx.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "transformers>=4.36.0",
  6. # "peft>=0.7.0",
  7. # "torch>=2.0.0",
  8. # "accelerate>=0.24.0",
  9. # "huggingface_hub>=0.20.0",
  10. # "sentencepiece>=0.1.99",
  11. # "protobuf>=3.20.0",
  12. # "numpy",
  13. # "optimum[onnxruntime]",
  14. # "onnx>=1.15.0",
  15. # "onnxruntime>=1.17.0",
  16. # "onnxconverter-common>=1.14.0",
  17. # ]
  18. # ///
  19. """
  20. Convert QMD query expansion model to ONNX format for Transformers.js.
  21. Loads the base model, merges SFT and GRPO adapters, then exports to ONNX
  22. with quantization for browser deployment via Transformers.js + WebGPU.
  23. Usage:
  24. uv run convert_onnx.py --size 1.7B
  25. uv run convert_onnx.py --size 1.7B --no-upload
  26. uv run convert_onnx.py --base Qwen/Qwen3-1.7B \
  27. --sft tobil/qmd-query-expansion-1.7B-sft \
  28. --grpo tobil/qmd-query-expansion-1.7B-grpo \
  29. --output tobil/qmd-query-expansion-1.7B-ONNX
  30. Quantization options:
  31. --quantize q4 MatMulNBits 4-bit (default, smallest)
  32. --quantize q8 8-bit dynamic quantization
  33. --quantize fp16 FP16 (converted on CPU after export)
  34. --quantize none No quantization (FP32, ~7GB for 1.7B)
  35. """
  36. import argparse
  37. import json
  38. import os
  39. import shutil
  40. import subprocess
  41. import sys
  42. from pathlib import Path
  43. import torch
  44. from huggingface_hub import HfApi, login
  45. from peft import PeftModel
  46. from transformers import AutoModelForCausalLM, AutoTokenizer
  47. PRESETS = {
  48. "1.7B": {
  49. "base": "Qwen/Qwen3-1.7B",
  50. "sft": "tobil/qmd-query-expansion-1.7B-sft",
  51. "grpo": "tobil/qmd-query-expansion-1.7B-grpo",
  52. "output": "tobil/qmd-query-expansion-1.7B-ONNX",
  53. },
  54. "4B": {
  55. "base": "Qwen/Qwen3-4B",
  56. "sft": "tobil/qmd-query-expansion-4B-sft",
  57. "grpo": "tobil/qmd-query-expansion-4B-grpo",
  58. "output": "tobil/qmd-query-expansion-4B-ONNX",
  59. },
  60. }
  61. def merge_adapters(base_model: str, sft_model: str, grpo_model: str) -> tuple:
  62. """Load base model, merge SFT + GRPO adapters, return (model, tokenizer)."""
  63. print(f"\nStep 1: Loading base model {base_model}...")
  64. model = AutoModelForCausalLM.from_pretrained(
  65. base_model, dtype=torch.float32, trust_remote_code=True,
  66. )
  67. print(f"Step 2: Merging SFT adapter {sft_model}...")
  68. model = PeftModel.from_pretrained(model, sft_model)
  69. model = model.merge_and_unload()
  70. print(f"Step 3: Merging GRPO adapter {grpo_model}...")
  71. model = PeftModel.from_pretrained(model, grpo_model)
  72. model = model.merge_and_unload()
  73. tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
  74. return model, tokenizer
  75. def export_onnx(model, tokenizer, output_dir: str):
  76. """Export merged model to ONNX using Optimum."""
  77. from optimum.exporters.onnx import main_export
  78. # Save merged model to temp dir first (Optimum needs HF format on disk)
  79. merged_dir = "/tmp/merged_model_onnx"
  80. print(f"\nStep 4: Saving merged model to {merged_dir}...")
  81. model.save_pretrained(merged_dir, safe_serialization=True)
  82. tokenizer.save_pretrained(merged_dir)
  83. print(f"\nStep 5: Exporting to ONNX at {output_dir}...")
  84. # no_post_process=True avoids the 2GB protobuf serialization limit
  85. # that occurs during tied-weight deduplication on large FP32 models.
  86. # The exported model still works correctly — the tied weights just
  87. # aren't deduplicated in the graph, which is fine since we quantize next.
  88. main_export(
  89. model_name_or_path=merged_dir,
  90. output=output_dir,
  91. task="text-generation-with-past",
  92. device="cpu",
  93. fp16=False,
  94. no_post_process=True,
  95. )
  96. # Clean up temp merged dir
  97. shutil.rmtree(merged_dir, ignore_errors=True)
  98. def _find_onnx_model(onnx_dir: str) -> Path:
  99. """Find the main ONNX model file in the output directory."""
  100. model_path = Path(onnx_dir) / "model.onnx"
  101. if model_path.exists():
  102. return model_path
  103. candidates = list(Path(onnx_dir).glob("*.onnx"))
  104. if not candidates:
  105. raise FileNotFoundError(f"No .onnx files found in {onnx_dir}")
  106. return candidates[0]
  107. def quantize_onnx(onnx_dir: str, quantize_type: str):
  108. """Quantize the exported ONNX model."""
  109. if quantize_type == "none":
  110. print("\nSkipping quantization (FP32).")
  111. return
  112. model_path = _find_onnx_model(onnx_dir)
  113. print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")
  114. if quantize_type == "q4":
  115. _quantize_q4(model_path)
  116. elif quantize_type == "q8":
  117. _quantize_q8(model_path)
  118. elif quantize_type == "fp16":
  119. _convert_fp16(model_path)
  120. def _quantize_q4(model_path: Path):
  121. """4-bit MatMulNBits quantization via onnxruntime. Needs ~16GB RAM for 1.7B models."""
  122. from onnxruntime.quantization import matmul_nbits_quantizer
  123. q_path = model_path.with_name(model_path.stem + "_q4" + model_path.suffix)
  124. quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
  125. model=str(model_path),
  126. block_size=32,
  127. is_symmetric=True,
  128. bits=4,
  129. )
  130. quant.process()
  131. quant.model.save(str(q_path))
  132. # Remove original FP32 files, keep only quantized
  133. if q_path.exists():
  134. _report_size(q_path)
  135. model_path.unlink(missing_ok=True)
  136. data_path = model_path.with_name(model_path.name + "_data")
  137. data_path.unlink(missing_ok=True)
  138. # Rename quantized to model.onnx for Transformers.js compatibility
  139. q_path.rename(model_path)
  140. print(f" Renamed {q_path.name} -> {model_path.name}")
  141. def _quantize_q8(model_path: Path):
  142. """8-bit dynamic quantization via onnxruntime."""
  143. from onnxruntime.quantization import quantize_dynamic, QuantType
  144. q_path = model_path.with_name(model_path.stem + "_q8" + model_path.suffix)
  145. quantize_dynamic(
  146. model_input=str(model_path),
  147. model_output=str(q_path),
  148. weight_type=QuantType.QUInt8,
  149. )
  150. if q_path.exists():
  151. _report_size(q_path)
  152. model_path.unlink(missing_ok=True)
  153. data_path = model_path.with_name(model_path.name + "_data")
  154. data_path.unlink(missing_ok=True)
  155. q_path.rename(model_path)
  156. print(f" Renamed {q_path.name} -> {model_path.name}")
  157. def _convert_fp16(model_path: Path):
  158. """Convert ONNX model weights to FP16."""
  159. from onnxconverter_common import float16
  160. import onnx
  161. print(" Converting to FP16...")
  162. model = onnx.load(str(model_path), load_external_data=True)
  163. model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
  164. fp16_path = model_path.with_name(model_path.stem + "_fp16" + model_path.suffix)
  165. onnx.save(model_fp16, str(fp16_path))
  166. if fp16_path.exists():
  167. _report_size(fp16_path)
  168. model_path.unlink(missing_ok=True)
  169. data_path = model_path.with_name(model_path.name + "_data")
  170. data_path.unlink(missing_ok=True)
  171. fp16_path.rename(model_path)
  172. print(f" Renamed {fp16_path.name} -> {model_path.name}")
  173. def _report_size(path: Path):
  174. """Print file size in MB."""
  175. size_mb = path.stat().st_size / (1024 * 1024)
  176. print(f" {path.name}: {size_mb:.1f} MB")
  177. def validate_onnx(onnx_dir: str, base_model: str):
  178. """Run a sample inference through the ONNX model to verify it works."""
  179. import onnxruntime as ort
  180. import numpy as np
  181. model_path = _find_onnx_model(onnx_dir)
  182. print(f"\nValidation: loading {model_path.name}...")
  183. tokenizer = AutoTokenizer.from_pretrained(onnx_dir, trust_remote_code=True)
  184. session = ort.InferenceSession(
  185. str(model_path),
  186. providers=["CPUExecutionProvider"],
  187. )
  188. # Tokenize a test prompt
  189. test_query = "/no_think Expand this search query: distributed consensus"
  190. chat_prompt = tokenizer.apply_chat_template(
  191. [{"role": "user", "content": test_query}],
  192. add_generation_prompt=True,
  193. tokenize=False,
  194. )
  195. inputs = tokenizer(chat_prompt, return_tensors="np")
  196. input_ids = inputs["input_ids"].astype(np.int64)
  197. attention_mask = inputs["attention_mask"].astype(np.int64)
  198. # Build feed dict with all required inputs
  199. seq_len = input_ids.shape[1]
  200. feed = {"input_ids": input_ids, "attention_mask": attention_mask}
  201. # Add position_ids if needed
  202. all_inputs = {inp.name: inp for inp in session.get_inputs()}
  203. if "position_ids" in all_inputs:
  204. feed["position_ids"] = np.arange(seq_len, dtype=np.int64).reshape(1, -1)
  205. # Initialize past_key_values to zeros if the model expects them
  206. for name, inp in sorted(all_inputs.items()):
  207. if name.startswith("past_key_values"):
  208. shape = []
  209. for dim in inp.shape:
  210. shape.append(dim if isinstance(dim, int) else 0)
  211. # batch dim = 1
  212. if shape and shape[0] == 0:
  213. shape[0] = 1
  214. feed[name] = np.zeros(shape, dtype=np.float32)
  215. # Run inference
  216. output_names = [o.name for o in session.get_outputs()]
  217. results = session.run(output_names, feed)
  218. # Check logits shape
  219. logits = results[0]
  220. print(f" Input tokens: {input_ids.shape[1]}")
  221. print(f" Output logits shape: {logits.shape}")
  222. print(f" Logits range: [{logits.min():.2f}, {logits.max():.2f}]")
  223. # Greedy decode next token
  224. next_token_id = int(np.argmax(logits[0, -1, :]))
  225. next_token = tokenizer.decode([next_token_id])
  226. print(f" Next token: {repr(next_token)} (id={next_token_id})")
  227. # Check KV cache outputs exist
  228. kv_outputs = [n for n in output_names if n.startswith("present")]
  229. if kv_outputs:
  230. print(f" KV cache outputs: {len(kv_outputs)} tensors (generation-ready)")
  231. else:
  232. print(" WARNING: No KV cache outputs — model may not support efficient generation")
  233. # Sanity checks
  234. assert logits.shape[0] == 1, "Batch size mismatch"
  235. assert logits.shape[1] == input_ids.shape[1], "Sequence length mismatch"
  236. assert logits.max() > logits.min(), "Logits are constant (broken model)"
  237. assert not np.isnan(logits).any(), "Logits contain NaN"
  238. assert not np.isinf(logits).any(), "Logits contain Inf"
  239. print(" Validation PASSED")
  240. def write_transformers_js_config(onnx_dir: str, quantize_type: str = "q4"):
  241. """Write Transformers.js compatibility config."""
  242. config_path = Path(onnx_dir) / "transformers_js_config.json"
  243. config = {
  244. "model_type": "text-generation",
  245. "quantized": quantize_type != "none",
  246. }
  247. config_path.write_text(json.dumps(config, indent=2) + "\n")
  248. print(f" Wrote {config_path.name}")
  249. def upload_to_hub(
  250. onnx_dir: str,
  251. output_repo: str,
  252. base_model: str,
  253. sft_model: str,
  254. grpo_model: str,
  255. quantize_type: str = "q4",
  256. ):
  257. """Upload ONNX model to HuggingFace Hub."""
  258. print(f"\nStep 7: Uploading to {output_repo}...")
  259. api = HfApi()
  260. api.create_repo(repo_id=output_repo, repo_type="model", exist_ok=True)
  261. api.upload_folder(
  262. folder_path=onnx_dir,
  263. repo_id=output_repo,
  264. commit_message="Upload ONNX model",
  265. )
  266. # Map quantize_type to Transformers.js dtype values
  267. dtype_map = {"q4": "q4", "q8": "q8", "fp16": "fp16", "none": "fp32"}
  268. tj_dtype = dtype_map.get(quantize_type, "fp32")
  269. format_desc = "FP32 (no quantization)" if quantize_type == "none" else f"{quantize_type.upper()} quantization"
  270. repo_name = output_repo.split("/")[-1]
  271. readme = f"""---
  272. base_model: {base_model}
  273. tags: [onnx, transformers.js, webgpu, query-expansion, qmd]
  274. library_name: transformers.js
  275. ---
  276. # {repo_name}
  277. ONNX conversion of the QMD Query Expansion model for use with
  278. [Transformers.js](https://huggingface.co/docs/transformers.js) and WebGPU.
  279. ## Details
  280. - **Base:** {base_model}
  281. - **SFT:** {sft_model}
  282. - **GRPO:** {grpo_model}
  283. - **Task:** Query expansion (lex/vec/hyde format)
  284. - **Format:** ONNX with {format_desc}
  285. ## Usage with Transformers.js
  286. ```javascript
  287. import {{ AutoTokenizer, AutoModelForCausalLM }} from "@huggingface/transformers";
  288. const tokenizer = await AutoTokenizer.from_pretrained("{output_repo}");
  289. const model = await AutoModelForCausalLM.from_pretrained("{output_repo}", {{
  290. dtype: "{tj_dtype}",
  291. device: "webgpu",
  292. }});
  293. ```
  294. ## Prompt Format
  295. ```
  296. <|im_start|>user
  297. /no_think Expand this search query: your query here<|im_end|>
  298. <|im_start|>assistant
  299. ```
  300. """
  301. api.upload_file(
  302. path_or_fileobj=readme.encode(),
  303. path_in_repo="README.md",
  304. repo_id=output_repo,
  305. )
  306. def main():
  307. parser = argparse.ArgumentParser(description="Convert QMD model to ONNX")
  308. parser.add_argument(
  309. "--size", choices=PRESETS.keys(), help="Use preset config for model size",
  310. )
  311. parser.add_argument("--base", help="Base model (overrides preset)")
  312. parser.add_argument("--sft", help="SFT adapter (overrides preset)")
  313. parser.add_argument("--grpo", help="GRPO adapter (overrides preset)")
  314. parser.add_argument("--output", help="Output HF repo (overrides preset)")
  315. parser.add_argument(
  316. "--quantize",
  317. choices=["q4", "q8", "fp16", "none"],
  318. default="q4",
  319. help="Quantization type (default: q4)",
  320. )
  321. parser.add_argument(
  322. "--no-upload", action="store_true", help="Don't upload to HF Hub",
  323. )
  324. parser.add_argument(
  325. "--validate", action="store_true",
  326. help="Run inference validation on exported model",
  327. )
  328. parser.add_argument(
  329. "--validate-only", metavar="DIR",
  330. help="Skip export, only validate an existing ONNX dir",
  331. )
  332. args = parser.parse_args()
  333. # Validate-only mode: skip export, just run validation
  334. if args.validate_only:
  335. validate_onnx(args.validate_only, "")
  336. return
  337. # Resolve config
  338. if args.size:
  339. preset = PRESETS[args.size]
  340. base_model = args.base or preset["base"]
  341. sft_model = args.sft or preset["sft"]
  342. grpo_model = args.grpo or preset["grpo"]
  343. output_repo = args.output or preset["output"]
  344. elif args.base and args.sft and args.grpo and args.output:
  345. base_model = args.base
  346. sft_model = args.sft
  347. grpo_model = args.grpo
  348. output_repo = args.output
  349. else:
  350. parser.error(
  351. "Either --size or all of --base/--sft/--grpo/--output are required",
  352. )
  353. model_name = output_repo.split("/")[-1]
  354. print(f"QMD ONNX Conversion: {model_name}")
  355. print("=" * 60)
  356. # Login
  357. hf_token = os.environ.get("HF_TOKEN")
  358. if hf_token:
  359. print("Logging in to HuggingFace...")
  360. login(token=hf_token)
  361. # Merge adapters
  362. model, tokenizer = merge_adapters(base_model, sft_model, grpo_model)
  363. # Export to ONNX
  364. onnx_dir = f"/tmp/onnx_output/{model_name}"
  365. os.makedirs(onnx_dir, exist_ok=True)
  366. export_onnx(model, tokenizer, onnx_dir)
  367. # Quantize
  368. quantize_onnx(onnx_dir, args.quantize)
  369. # Write Transformers.js config
  370. write_transformers_js_config(onnx_dir, args.quantize)
  371. # Validate
  372. if args.validate:
  373. validate_onnx(onnx_dir, base_model)
  374. # Upload
  375. if not args.no_upload:
  376. upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model, args.quantize)
  377. print(f"\nDone! ONNX files at: {onnx_dir}")
  378. if not args.no_upload:
  379. print(f"Repository: https://huggingface.co/{output_repo}")
  380. if __name__ == "__main__":
  381. main()