Ver código fonte

fix: quantization paths, missing imports, and hardcoded metadata

- Add missing subprocess import (NameError on any quantize path)
- Replace broken optimum-cli quantize calls with direct onnxruntime:
  Q4 uses MatMulNBitsQuantizer, Q8 uses quantize_dynamic
- Add onnxconverter-common to deps for FP16 (was silently swallowed)
- Make FP16 fail loudly on missing dep instead of silently uploading FP32
- README and transformers_js_config now reflect actual quantize_type
  instead of always hardcoding Q4
- Remove dead _convert_fp16_external function
Shreyas Karnik 2 meses atrás
pai
commit
b05d8863ca
1 arquivos alterados com 73 adições e 64 exclusões
  1. 73 64
      finetune/convert_onnx.py

+ 73 - 64
finetune/convert_onnx.py

@@ -13,6 +13,7 @@
 #     "optimum[onnxruntime]",
 #     "onnx>=1.15.0",
 #     "onnxruntime>=1.17.0",
+#     "onnxconverter-common>=1.14.0",
 # ]
 # ///
 """
@@ -40,6 +41,7 @@ import argparse
 import json
 import os
 import shutil
+import subprocess
 import sys
 from pathlib import Path
 
@@ -123,7 +125,7 @@ def _find_onnx_model(onnx_dir: str) -> Path:
 
 
 def quantize_onnx(onnx_dir: str, quantize_type: str):
-    """Quantize the exported ONNX model using optimum-cli for memory safety."""
+    """Quantize the exported ONNX model."""
     if quantize_type == "none":
         print("\nSkipping quantization (FP32).")
         return
@@ -131,56 +133,19 @@ def quantize_onnx(onnx_dir: str, quantize_type: str):
     model_path = _find_onnx_model(onnx_dir)
     print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")
 
-    q_dir = Path(onnx_dir) / f"quantized_{quantize_type}"
-    q_dir.mkdir(exist_ok=True)
-
     if quantize_type == "q4":
-        # Use optimum-cli which handles external data format and memory properly.
-        # Direct MatMulNBitsQuantizer can OOM on large FP32 models.
-        cmd = [
-            sys.executable, "-m", "optimum.commands.optimum_cli",
-            "onnxruntime", "quantize",
-            "--onnx_model", onnx_dir,
-            "--o", str(q_dir),
-            "--avx2",
-        ]
-        print(f"  Running: optimum-cli onnxruntime quantize ...")
-        result = subprocess.run(cmd, capture_output=True, text=True)
-        if result.returncode != 0:
-            print(f"  optimum-cli failed, falling back to MatMulNBitsQuantizer...")
-            _quantize_q4_direct(model_path, q_dir)
+        _quantize_q4(model_path)
     elif quantize_type == "q8":
-        cmd = [
-            sys.executable, "-m", "optimum.commands.optimum_cli",
-            "onnxruntime", "quantize",
-            "--onnx_model", onnx_dir,
-            "--o", str(q_dir),
-            "--avx2",
-        ]
-        result = subprocess.run(cmd, capture_output=True, text=True)
-        if result.returncode != 0:
-            print(f"  optimum-cli quantize failed: {result.stderr[:300]}")
+        _quantize_q8(model_path)
     elif quantize_type == "fp16":
-        _convert_fp16_external(model_path, q_dir)
-
-    # Report sizes
-    for f in sorted(q_dir.glob("*.onnx*")):
-        size_mb = f.stat().st_size / (1024 * 1024)
-        if size_mb > 1:
-            print(f"  {f.name}: {size_mb:.1f} MB")
+        _convert_fp16(model_path)
 
-    # Move quantized files back into main dir, replacing originals
-    for f in q_dir.iterdir():
-        dest = Path(onnx_dir) / f.name
-        if dest.exists():
-            dest.unlink()
-        shutil.move(str(f), str(dest))
-    q_dir.rmdir()
 
-
-def _quantize_q4_direct(model_path: Path, output_dir: Path):
-    """Direct Q4 quantization fallback (may need significant RAM)."""
+def _quantize_q4(model_path: Path):
+    """4-bit MatMulNBits quantization via onnxruntime. Needs ~16GB RAM for 1.7B models."""
     from onnxruntime.quantization import matmul_nbits_quantizer
+
+    q_path = model_path.with_name(model_path.stem + "_q4" + model_path.suffix)
     quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
         model=str(model_path),
         block_size=32,
@@ -188,22 +153,65 @@ def _quantize_q4_direct(model_path: Path, output_dir: Path):
         bits=4,
     )
     quant.process()
-    q_path = output_dir / model_path.name
     quant.model.save(str(q_path))
 
+    # Remove original FP32 files, keep only quantized
+    if q_path.exists():
+        _report_size(q_path)
+        model_path.unlink(missing_ok=True)
+        data_path = model_path.with_name(model_path.name + "_data")
+        data_path.unlink(missing_ok=True)
+        # Rename quantized to model.onnx for Transformers.js compatibility
+        q_path.rename(model_path)
+        print(f"  Renamed {q_path.name} -> {model_path.name}")
+
+
+def _quantize_q8(model_path: Path):
+    """8-bit dynamic quantization via onnxruntime."""
+    from onnxruntime.quantization import quantize_dynamic, QuantType
+
+    q_path = model_path.with_name(model_path.stem + "_q8" + model_path.suffix)
+    quantize_dynamic(
+        model_input=str(model_path),
+        model_output=str(q_path),
+        weight_type=QuantType.QUInt8,
+    )
+
+    if q_path.exists():
+        _report_size(q_path)
+        model_path.unlink(missing_ok=True)
+        data_path = model_path.with_name(model_path.name + "_data")
+        data_path.unlink(missing_ok=True)
+        q_path.rename(model_path)
+        print(f"  Renamed {q_path.name} -> {model_path.name}")
+
+
+def _convert_fp16(model_path: Path):
+    """Convert ONNX model weights to FP16."""
+    from onnxconverter_common import float16
+    import onnx
+
+    print("  Converting to FP16...")
+    model = onnx.load(str(model_path), load_external_data=True)
+    model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
+
+    fp16_path = model_path.with_name(model_path.stem + "_fp16" + model_path.suffix)
+    onnx.save(model_fp16, str(fp16_path))
+
+    if fp16_path.exists():
+        _report_size(fp16_path)
+        model_path.unlink(missing_ok=True)
+        data_path = model_path.with_name(model_path.name + "_data")
+        data_path.unlink(missing_ok=True)
+        fp16_path.rename(model_path)
+        print(f"  Renamed {fp16_path.name} -> {model_path.name}")
+
+
+def _report_size(path: Path):
+    """Print file size in MB."""
+    size_mb = path.stat().st_size / (1024 * 1024)
+    print(f"  {path.name}: {size_mb:.1f} MB")
 
-def _convert_fp16_external(model_path: Path, output_dir: Path):
-    """Convert ONNX model to FP16 using onnxconverter-common (handles external data)."""
-    print("  Converting to FP16 (external data format)...")
-    try:
-        from onnxconverter_common import float16
-        import onnx
-        model = onnx.load(str(model_path), load_external_data=True)
-        model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
-        fp16_path = output_dir / model_path.name
-        onnx.save(model_fp16, str(fp16_path))
-    except ImportError:
-        print("  onnxconverter-common not available; skipping FP16 conversion.")
 
 
 def validate_onnx(onnx_dir: str, base_model: str):
@@ -283,12 +291,12 @@ def validate_onnx(onnx_dir: str, base_model: str):
     print("  Validation PASSED")
 
 
-def write_transformers_js_config(onnx_dir: str):
+def write_transformers_js_config(onnx_dir: str, quantize_type: str = "q4"):
     """Write Transformers.js compatibility config."""
     config_path = Path(onnx_dir) / "transformers_js_config.json"
     config = {
         "model_type": "text-generation",
-        "quantized": True,
+        "quantized": quantize_type != "none",
     }
     config_path.write_text(json.dumps(config, indent=2) + "\n")
     print(f"  Wrote {config_path.name}")
@@ -300,6 +308,7 @@ def upload_to_hub(
     base_model: str,
     sft_model: str,
     grpo_model: str,
+    quantize_type: str = "q4",
 ):
     """Upload ONNX model to HuggingFace Hub."""
     print(f"\nStep 7: Uploading to {output_repo}...")
@@ -327,7 +336,7 @@ ONNX conversion of the QMD Query Expansion model for use with
 - **SFT:** {sft_model}
 - **GRPO:** {grpo_model}
 - **Task:** Query expansion (lex/vec/hyde format)
-- **Format:** ONNX with Q4 quantization
+- **Format:** ONNX with {quantize_type.upper()} quantization
 
 ## Usage with Transformers.js
 
@@ -336,7 +345,7 @@ import {{ AutoTokenizer, AutoModelForCausalLM }} from "@huggingface/transformers
 
 const tokenizer = await AutoTokenizer.from_pretrained("{output_repo}");
 const model = await AutoModelForCausalLM.from_pretrained("{output_repo}", {{
-  dtype: "q4",
+  dtype: "{quantize_type}",
   device: "webgpu",
 }});
 ```
@@ -427,7 +436,7 @@ def main():
     quantize_onnx(onnx_dir, args.quantize)
 
     # Write Transformers.js config
-    write_transformers_js_config(onnx_dir)
+    write_transformers_js_config(onnx_dir, args.quantize)
 
     # Validate
     if args.validate:
@@ -435,7 +444,7 @@ def main():
 
     # Upload
     if not args.no_upload:
-        upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model)
+        upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model, args.quantize)
 
     print(f"\nDone! ONNX files at: {onnx_dir}")
     if not args.no_upload: