|
|
@@ -13,6 +13,7 @@
|
|
|
# "optimum[onnxruntime]",
|
|
|
# "onnx>=1.15.0",
|
|
|
# "onnxruntime>=1.17.0",
|
|
|
+# "onnxconverter-common>=1.14.0",
|
|
|
# ]
|
|
|
# ///
|
|
|
"""
|
|
|
@@ -40,6 +41,7 @@ import argparse
|
|
|
import json
|
|
|
import os
|
|
|
import shutil
|
|
|
+import subprocess
|
|
|
import sys
|
|
|
from pathlib import Path
|
|
|
|
|
|
@@ -123,7 +125,7 @@ def _find_onnx_model(onnx_dir: str) -> Path:
|
|
|
|
|
|
|
|
|
def quantize_onnx(onnx_dir: str, quantize_type: str):
|
|
|
- """Quantize the exported ONNX model using optimum-cli for memory safety."""
|
|
|
+ """Quantize the exported ONNX model."""
|
|
|
if quantize_type == "none":
|
|
|
print("\nSkipping quantization (FP32).")
|
|
|
return
|
|
|
@@ -131,56 +133,19 @@ def quantize_onnx(onnx_dir: str, quantize_type: str):
|
|
|
model_path = _find_onnx_model(onnx_dir)
|
|
|
print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")
|
|
|
|
|
|
- q_dir = Path(onnx_dir) / f"quantized_{quantize_type}"
|
|
|
- q_dir.mkdir(exist_ok=True)
|
|
|
-
|
|
|
if quantize_type == "q4":
|
|
|
- # Use optimum-cli which handles external data format and memory properly.
|
|
|
- # Direct MatMulNBitsQuantizer can OOM on large FP32 models.
|
|
|
- cmd = [
|
|
|
- sys.executable, "-m", "optimum.commands.optimum_cli",
|
|
|
- "onnxruntime", "quantize",
|
|
|
- "--onnx_model", onnx_dir,
|
|
|
- "--o", str(q_dir),
|
|
|
- "--avx2",
|
|
|
- ]
|
|
|
- print(f" Running: optimum-cli onnxruntime quantize ...")
|
|
|
- result = subprocess.run(cmd, capture_output=True, text=True)
|
|
|
- if result.returncode != 0:
|
|
|
- print(f" optimum-cli failed, falling back to MatMulNBitsQuantizer...")
|
|
|
- _quantize_q4_direct(model_path, q_dir)
|
|
|
+ _quantize_q4(model_path)
|
|
|
elif quantize_type == "q8":
|
|
|
- cmd = [
|
|
|
- sys.executable, "-m", "optimum.commands.optimum_cli",
|
|
|
- "onnxruntime", "quantize",
|
|
|
- "--onnx_model", onnx_dir,
|
|
|
- "--o", str(q_dir),
|
|
|
- "--avx2",
|
|
|
- ]
|
|
|
- result = subprocess.run(cmd, capture_output=True, text=True)
|
|
|
- if result.returncode != 0:
|
|
|
- print(f" optimum-cli quantize failed: {result.stderr[:300]}")
|
|
|
+ _quantize_q8(model_path)
|
|
|
elif quantize_type == "fp16":
|
|
|
- _convert_fp16_external(model_path, q_dir)
|
|
|
-
|
|
|
- # Report sizes
|
|
|
- for f in sorted(q_dir.glob("*.onnx*")):
|
|
|
- size_mb = f.stat().st_size / (1024 * 1024)
|
|
|
- if size_mb > 1:
|
|
|
- print(f" {f.name}: {size_mb:.1f} MB")
|
|
|
+ _convert_fp16(model_path)
|
|
|
|
|
|
- # Move quantized files back into main dir, replacing originals
|
|
|
- for f in q_dir.iterdir():
|
|
|
- dest = Path(onnx_dir) / f.name
|
|
|
- if dest.exists():
|
|
|
- dest.unlink()
|
|
|
- shutil.move(str(f), str(dest))
|
|
|
- q_dir.rmdir()
|
|
|
|
|
|
-
|
|
|
-def _quantize_q4_direct(model_path: Path, output_dir: Path):
|
|
|
- """Direct Q4 quantization fallback (may need significant RAM)."""
|
|
|
+def _quantize_q4(model_path: Path):
|
|
|
+ """4-bit MatMulNBits quantization via onnxruntime. Needs ~16GB RAM for 1.7B models."""
|
|
|
from onnxruntime.quantization import matmul_nbits_quantizer
|
|
|
+
|
|
|
+ q_path = model_path.with_name(model_path.stem + "_q4" + model_path.suffix)
|
|
|
quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
|
|
|
model=str(model_path),
|
|
|
block_size=32,
|
|
|
@@ -188,22 +153,65 @@ def _quantize_q4_direct(model_path: Path, output_dir: Path):
|
|
|
bits=4,
|
|
|
)
|
|
|
quant.process()
|
|
|
- q_path = output_dir / model_path.name
|
|
|
quant.model.save(str(q_path))
|
|
|
|
|
|
+ # Remove original FP32 files, keep only quantized
|
|
|
+ if q_path.exists():
|
|
|
+ _report_size(q_path)
|
|
|
+ model_path.unlink(missing_ok=True)
|
|
|
+ data_path = model_path.with_name(model_path.name + "_data")
|
|
|
+ data_path.unlink(missing_ok=True)
|
|
|
+ # Rename quantized to model.onnx for Transformers.js compatibility
|
|
|
+ q_path.rename(model_path)
|
|
|
+ print(f" Renamed {q_path.name} -> {model_path.name}")
|
|
|
+
|
|
|
+
|
|
|
+def _quantize_q8(model_path: Path):
|
|
|
+ """8-bit dynamic quantization via onnxruntime."""
|
|
|
+ from onnxruntime.quantization import quantize_dynamic, QuantType
|
|
|
+
|
|
|
+ q_path = model_path.with_name(model_path.stem + "_q8" + model_path.suffix)
|
|
|
+ quantize_dynamic(
|
|
|
+ model_input=str(model_path),
|
|
|
+ model_output=str(q_path),
|
|
|
+ weight_type=QuantType.QUInt8,
|
|
|
+ )
|
|
|
+
|
|
|
+ if q_path.exists():
|
|
|
+ _report_size(q_path)
|
|
|
+ model_path.unlink(missing_ok=True)
|
|
|
+ data_path = model_path.with_name(model_path.name + "_data")
|
|
|
+ data_path.unlink(missing_ok=True)
|
|
|
+ q_path.rename(model_path)
|
|
|
+ print(f" Renamed {q_path.name} -> {model_path.name}")
|
|
|
+
|
|
|
+
|
|
|
+def _convert_fp16(model_path: Path):
|
|
|
+ """Convert ONNX model weights to FP16."""
|
|
|
+ from onnxconverter_common import float16
|
|
|
+ import onnx
|
|
|
+
|
|
|
+ print(" Converting to FP16...")
|
|
|
+ model = onnx.load(str(model_path), load_external_data=True)
|
|
|
+ model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
|
|
|
+
|
|
|
+ fp16_path = model_path.with_name(model_path.stem + "_fp16" + model_path.suffix)
|
|
|
+ onnx.save(model_fp16, str(fp16_path))
|
|
|
+
|
|
|
+ if fp16_path.exists():
|
|
|
+ _report_size(fp16_path)
|
|
|
+ model_path.unlink(missing_ok=True)
|
|
|
+ data_path = model_path.with_name(model_path.name + "_data")
|
|
|
+ data_path.unlink(missing_ok=True)
|
|
|
+ fp16_path.rename(model_path)
|
|
|
+ print(f" Renamed {fp16_path.name} -> {model_path.name}")
|
|
|
+
|
|
|
+
|
|
|
+def _report_size(path: Path):
|
|
|
+ """Print file size in MB."""
|
|
|
+ size_mb = path.stat().st_size / (1024 * 1024)
|
|
|
+ print(f" {path.name}: {size_mb:.1f} MB")
|
|
|
|
|
|
-def _convert_fp16_external(model_path: Path, output_dir: Path):
|
|
|
- """Convert ONNX model to FP16 using onnxconverter-common (handles external data)."""
|
|
|
- print(" Converting to FP16 (external data format)...")
|
|
|
- try:
|
|
|
- from onnxconverter_common import float16
|
|
|
- import onnx
|
|
|
- model = onnx.load(str(model_path), load_external_data=True)
|
|
|
- model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
|
|
|
- fp16_path = output_dir / model_path.name
|
|
|
- onnx.save(model_fp16, str(fp16_path))
|
|
|
- except ImportError:
|
|
|
- print(" onnxconverter-common not available; skipping FP16 conversion.")
|
|
|
|
|
|
|
|
|
def validate_onnx(onnx_dir: str, base_model: str):
|
|
|
@@ -283,12 +291,12 @@ def validate_onnx(onnx_dir: str, base_model: str):
|
|
|
print(" Validation PASSED")
|
|
|
|
|
|
|
|
|
-def write_transformers_js_config(onnx_dir: str):
|
|
|
+def write_transformers_js_config(onnx_dir: str, quantize_type: str = "q4"):
|
|
|
"""Write Transformers.js compatibility config."""
|
|
|
config_path = Path(onnx_dir) / "transformers_js_config.json"
|
|
|
config = {
|
|
|
"model_type": "text-generation",
|
|
|
- "quantized": True,
|
|
|
+ "quantized": quantize_type != "none",
|
|
|
}
|
|
|
config_path.write_text(json.dumps(config, indent=2) + "\n")
|
|
|
print(f" Wrote {config_path.name}")
|
|
|
@@ -300,6 +308,7 @@ def upload_to_hub(
|
|
|
base_model: str,
|
|
|
sft_model: str,
|
|
|
grpo_model: str,
|
|
|
+ quantize_type: str = "q4",
|
|
|
):
|
|
|
"""Upload ONNX model to HuggingFace Hub."""
|
|
|
print(f"\nStep 7: Uploading to {output_repo}...")
|
|
|
@@ -327,7 +336,7 @@ ONNX conversion of the QMD Query Expansion model for use with
|
|
|
- **SFT:** {sft_model}
|
|
|
- **GRPO:** {grpo_model}
|
|
|
- **Task:** Query expansion (lex/vec/hyde format)
|
|
|
-- **Format:** ONNX with Q4 quantization
|
|
|
+- **Format:** ONNX with {quantize_type.upper()} quantization
|
|
|
|
|
|
## Usage with Transformers.js
|
|
|
|
|
|
@@ -336,7 +345,7 @@ import {{ AutoTokenizer, AutoModelForCausalLM }} from "@huggingface/transformers
|
|
|
|
|
|
const tokenizer = await AutoTokenizer.from_pretrained("{output_repo}");
|
|
|
const model = await AutoModelForCausalLM.from_pretrained("{output_repo}", {{
|
|
|
- dtype: "q4",
|
|
|
+ dtype: "{quantize_type}",
|
|
|
device: "webgpu",
|
|
|
}});
|
|
|
```
|
|
|
@@ -427,7 +436,7 @@ def main():
|
|
|
quantize_onnx(onnx_dir, args.quantize)
|
|
|
|
|
|
# Write Transformers.js config
|
|
|
- write_transformers_js_config(onnx_dir)
|
|
|
+ write_transformers_js_config(onnx_dir, args.quantize)
|
|
|
|
|
|
# Validate
|
|
|
if args.validate:
|
|
|
@@ -435,7 +444,7 @@ def main():
|
|
|
|
|
|
# Upload
|
|
|
if not args.no_upload:
|
|
|
- upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model)
|
|
|
+ upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model, args.quantize)
|
|
|
|
|
|
print(f"\nDone! ONNX files at: {onnx_dir}")
|
|
|
if not args.no_upload:
|