convert_onnx.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "transformers>=4.36.0",
  6. # "peft>=0.7.0",
  7. # "torch>=2.0.0",
  8. # "accelerate>=0.24.0",
  9. # "huggingface_hub>=0.20.0",
  10. # "sentencepiece>=0.1.99",
  11. # "protobuf>=3.20.0",
  12. # "numpy",
  13. # "optimum>=1.17.0",
  14. # "onnx>=1.15.0",
  15. # "onnxruntime>=1.17.0",
  16. # ]
  17. # ///
  18. """
  19. Convert QMD query expansion model to ONNX format for Transformers.js.
  20. Loads the base model, merges SFT and GRPO adapters, then exports to ONNX
  21. with quantization for browser deployment via Transformers.js + WebGPU.
  22. Usage:
  23. uv run convert_onnx.py --size 1.7B
  24. uv run convert_onnx.py --size 1.7B --no-upload
  25. uv run convert_onnx.py --base Qwen/Qwen3-1.7B \
  26. --sft tobil/qmd-query-expansion-1.7B-sft \
  27. --grpo tobil/qmd-query-expansion-1.7B-grpo \
  28. --output tobil/qmd-query-expansion-1.7B-ONNX
  29. Quantization options:
  30. --quantize q4 MatMulNBits 4-bit (default, smallest)
  31. --quantize q8 8-bit dynamic quantization
  32. --quantize fp16 FP16 (requires GPU export)
  33. --quantize none No quantization (FP32, ~7GB)
  34. """
  35. import argparse
  36. import json
  37. import os
  38. import shutil
  39. import sys
  40. from pathlib import Path
  41. import torch
  42. from huggingface_hub import HfApi, login
  43. from peft import PeftModel
  44. from transformers import AutoModelForCausalLM, AutoTokenizer
  45. PRESETS = {
  46. "1.7B": {
  47. "base": "Qwen/Qwen3-1.7B",
  48. "sft": "tobil/qmd-query-expansion-1.7B-sft",
  49. "grpo": "tobil/qmd-query-expansion-1.7B-grpo",
  50. "output": "tobil/qmd-query-expansion-1.7B-ONNX",
  51. },
  52. "4B": {
  53. "base": "Qwen/Qwen3-4B",
  54. "sft": "tobil/qmd-query-expansion-4B-sft",
  55. "grpo": "tobil/qmd-query-expansion-4B-grpo",
  56. "output": "tobil/qmd-query-expansion-4B-ONNX",
  57. },
  58. }
  59. def merge_adapters(base_model: str, sft_model: str, grpo_model: str) -> tuple:
  60. """Load base model, merge SFT + GRPO adapters, return (model, tokenizer)."""
  61. print(f"\nStep 1: Loading base model {base_model}...")
  62. model = AutoModelForCausalLM.from_pretrained(
  63. base_model, torch_dtype=torch.float32, trust_remote_code=True,
  64. )
  65. print(f"Step 2: Merging SFT adapter {sft_model}...")
  66. model = PeftModel.from_pretrained(model, sft_model)
  67. model = model.merge_and_unload()
  68. print(f"Step 3: Merging GRPO adapter {grpo_model}...")
  69. model = PeftModel.from_pretrained(model, grpo_model)
  70. model = model.merge_and_unload()
  71. tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
  72. return model, tokenizer
  73. def export_onnx(model, tokenizer, output_dir: str):
  74. """Export merged model to ONNX using Optimum."""
  75. from optimum.exporters.onnx import main_export
  76. # Save merged model to temp dir first (Optimum needs HF format on disk)
  77. merged_dir = "/tmp/merged_model_onnx"
  78. print(f"\nStep 4: Saving merged model to {merged_dir}...")
  79. model.save_pretrained(merged_dir, safe_serialization=True)
  80. tokenizer.save_pretrained(merged_dir)
  81. print(f"\nStep 5: Exporting to ONNX at {output_dir}...")
  82. main_export(
  83. model_name_or_path=merged_dir,
  84. output=output_dir,
  85. task="text-generation-with-past",
  86. device="cpu",
  87. fp16=False,
  88. )
  89. # Clean up temp merged dir
  90. shutil.rmtree(merged_dir, ignore_errors=True)
  91. def quantize_onnx(onnx_dir: str, quantize_type: str):
  92. """Quantize the exported ONNX model."""
  93. if quantize_type == "none":
  94. print("\nSkipping quantization (FP32).")
  95. return
  96. model_path = Path(onnx_dir) / "model.onnx"
  97. if not model_path.exists():
  98. # Optimum may produce decoder_model.onnx for text-generation-with-past
  99. candidates = list(Path(onnx_dir).glob("*.onnx"))
  100. if not candidates:
  101. print(" WARNING: No .onnx files found to quantize.")
  102. return
  103. model_path = candidates[0]
  104. print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")
  105. if quantize_type == "q4":
  106. try:
  107. from onnxruntime.quantization import matmul_nbits_quantizer
  108. quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
  109. model=str(model_path),
  110. block_size=32,
  111. is_symmetric=True,
  112. bits=4,
  113. )
  114. quant.process()
  115. q_path = model_path.with_name(
  116. model_path.stem + "_q4" + model_path.suffix,
  117. )
  118. quant.model.save(str(q_path))
  119. size_mb = q_path.stat().st_size / (1024 * 1024)
  120. print(f" Q4: {size_mb:.1f} MB -> {q_path.name}")
  121. except ImportError:
  122. print(" WARNING: onnxruntime quantization not available, trying alternative...")
  123. _quantize_dynamic(model_path, quantize_type)
  124. elif quantize_type == "q8":
  125. _quantize_dynamic(model_path, quantize_type)
  126. elif quantize_type == "fp16":
  127. _convert_fp16(model_path)
  128. def _quantize_dynamic(model_path: Path, qtype: str):
  129. """Dynamic quantization fallback."""
  130. from onnxruntime.quantization import quantize_dynamic, QuantType
  131. weight_type = QuantType.QUInt8 if qtype == "q8" else QuantType.QInt8
  132. q_path = model_path.with_name(
  133. model_path.stem + f"_{qtype}" + model_path.suffix,
  134. )
  135. quantize_dynamic(
  136. model_input=str(model_path),
  137. model_output=str(q_path),
  138. weight_type=weight_type,
  139. )
  140. size_mb = q_path.stat().st_size / (1024 * 1024)
  141. print(f" {qtype.upper()}: {size_mb:.1f} MB -> {q_path.name}")
  142. def _convert_fp16(model_path: Path):
  143. """Convert ONNX model to FP16."""
  144. import onnx
  145. from onnx import numpy_helper
  146. print(" Converting to FP16...")
  147. model = onnx.load(str(model_path))
  148. for initializer in model.graph.initializer:
  149. if initializer.data_type == onnx.TensorProto.FLOAT:
  150. np_data = numpy_helper.to_array(initializer)
  151. initializer.CopyFrom(
  152. numpy_helper.from_array(np_data.astype("float16"), initializer.name),
  153. )
  154. fp16_path = model_path.with_name(
  155. model_path.stem + "_fp16" + model_path.suffix,
  156. )
  157. onnx.save(model, str(fp16_path))
  158. size_mb = fp16_path.stat().st_size / (1024 * 1024)
  159. print(f" FP16: {size_mb:.1f} MB -> {fp16_path.name}")
  160. def write_transformers_js_config(onnx_dir: str):
  161. """Write Transformers.js compatibility config."""
  162. config_path = Path(onnx_dir) / "transformers_js_config.json"
  163. config = {
  164. "model_type": "text-generation",
  165. "quantized": True,
  166. }
  167. config_path.write_text(json.dumps(config, indent=2) + "\n")
  168. print(f" Wrote {config_path.name}")
  169. def upload_to_hub(
  170. onnx_dir: str,
  171. output_repo: str,
  172. base_model: str,
  173. sft_model: str,
  174. grpo_model: str,
  175. ):
  176. """Upload ONNX model to HuggingFace Hub."""
  177. print(f"\nStep 7: Uploading to {output_repo}...")
  178. api = HfApi()
  179. api.create_repo(repo_id=output_repo, repo_type="model", exist_ok=True)
  180. api.upload_folder(
  181. folder_path=onnx_dir,
  182. repo_id=output_repo,
  183. commit_message="Upload ONNX model",
  184. )
  185. readme = f"""---
  186. base_model: {base_model}
  187. tags: [onnx, transformers.js, webgpu, query-expansion, qmd]
  188. library_name: transformers.js
  189. ---
  190. # {output_repo.split("/")[-1]}
  191. ONNX conversion of the QMD Query Expansion model for use with
  192. [Transformers.js](https://huggingface.co/docs/transformers.js) and WebGPU.
  193. ## Details
  194. - **Base:** {base_model}
  195. - **SFT:** {sft_model}
  196. - **GRPO:** {grpo_model}
  197. - **Task:** Query expansion (lex/vec/hyde format)
  198. - **Format:** ONNX with Q4 quantization
  199. ## Usage with Transformers.js
  200. ```javascript
  201. import {{ AutoTokenizer, AutoModelForCausalLM }} from "@huggingface/transformers";
  202. const tokenizer = await AutoTokenizer.from_pretrained("{output_repo}");
  203. const model = await AutoModelForCausalLM.from_pretrained("{output_repo}", {{
  204. dtype: "q4",
  205. device: "webgpu",
  206. }});
  207. ```
  208. ## Prompt Format
  209. ```
  210. <|im_start|>user
  211. /no_think Expand this search query: your query here<|im_end|>
  212. <|im_start|>assistant
  213. ```
  214. """
  215. api.upload_file(
  216. path_or_fileobj=readme.encode(),
  217. path_in_repo="README.md",
  218. repo_id=output_repo,
  219. )
  220. def main():
  221. parser = argparse.ArgumentParser(description="Convert QMD model to ONNX")
  222. parser.add_argument(
  223. "--size", choices=PRESETS.keys(), help="Use preset config for model size",
  224. )
  225. parser.add_argument("--base", help="Base model (overrides preset)")
  226. parser.add_argument("--sft", help="SFT adapter (overrides preset)")
  227. parser.add_argument("--grpo", help="GRPO adapter (overrides preset)")
  228. parser.add_argument("--output", help="Output HF repo (overrides preset)")
  229. parser.add_argument(
  230. "--quantize",
  231. choices=["q4", "q8", "fp16", "none"],
  232. default="q4",
  233. help="Quantization type (default: q4)",
  234. )
  235. parser.add_argument(
  236. "--no-upload", action="store_true", help="Don't upload to HF Hub",
  237. )
  238. args = parser.parse_args()
  239. # Resolve config
  240. if args.size:
  241. preset = PRESETS[args.size]
  242. base_model = args.base or preset["base"]
  243. sft_model = args.sft or preset["sft"]
  244. grpo_model = args.grpo or preset["grpo"]
  245. output_repo = args.output or preset["output"]
  246. elif args.base and args.sft and args.grpo and args.output:
  247. base_model = args.base
  248. sft_model = args.sft
  249. grpo_model = args.grpo
  250. output_repo = args.output
  251. else:
  252. parser.error(
  253. "Either --size or all of --base/--sft/--grpo/--output are required",
  254. )
  255. model_name = output_repo.split("/")[-1]
  256. print(f"QMD ONNX Conversion: {model_name}")
  257. print("=" * 60)
  258. # Login
  259. hf_token = os.environ.get("HF_TOKEN")
  260. if hf_token:
  261. print("Logging in to HuggingFace...")
  262. login(token=hf_token)
  263. # Merge adapters
  264. model, tokenizer = merge_adapters(base_model, sft_model, grpo_model)
  265. # Export to ONNX
  266. onnx_dir = f"/tmp/onnx_output/{model_name}"
  267. os.makedirs(onnx_dir, exist_ok=True)
  268. export_onnx(model, tokenizer, onnx_dir)
  269. # Quantize
  270. quantize_onnx(onnx_dir, args.quantize)
  271. # Write Transformers.js config
  272. write_transformers_js_config(onnx_dir)
  273. # Upload
  274. if not args.no_upload:
  275. upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model)
  276. print(f"\nDone! ONNX files at: {onnx_dir}")
  277. if not args.no_upload:
  278. print(f"Repository: https://huggingface.co/{output_repo}")
  279. if __name__ == "__main__":
  280. main()