فهرست منبع

fix: handle 2GB protobuf limit, add validation, fix input feeds

- Use no_post_process=True for ONNX export to avoid protobuf serialize error
- Add --validate and --validate-only flags for inference verification
- Fix position_ids in validation feed (required by Qwen3 ONNX export)
- Use optimum-cli for quantization to handle external data format
- Fix optimum dependency to optimum[onnxruntime]

Tested: export + validation passes on CPU, KV cache present (56 tensors).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Shreyas Karnik 2 ماه پیش
والد
کامیت
e1ce37c989
1فایلهای تغییر یافته به همراه181 افزوده شده و 69 حذف شده
  1. 181 69
      finetune/convert_onnx.py

+ 181 - 69
finetune/convert_onnx.py

@@ -10,7 +10,7 @@
 #     "sentencepiece>=0.1.99",
 #     "protobuf>=3.20.0",
 #     "numpy",
-#     "optimum>=1.17.0",
+#     "optimum[onnxruntime]",
 #     "onnx>=1.15.0",
 #     "onnxruntime>=1.17.0",
 # ]
@@ -68,7 +68,7 @@ def merge_adapters(base_model: str, sft_model: str, grpo_model: str) -> tuple:
     """Load base model, merge SFT + GRPO adapters, return (model, tokenizer)."""
     print(f"\nStep 1: Loading base model {base_model}...")
     model = AutoModelForCausalLM.from_pretrained(
-        base_model, torch_dtype=torch.float32, trust_remote_code=True,
+        base_model, dtype=torch.float32, trust_remote_code=True,
     )
 
     print(f"Step 2: Merging SFT adapter {sft_model}...")
@@ -94,98 +94,193 @@ def export_onnx(model, tokenizer, output_dir: str):
     tokenizer.save_pretrained(merged_dir)
 
     print(f"\nStep 5: Exporting to ONNX at {output_dir}...")
+    # no_post_process=True avoids the 2GB protobuf serialization limit
+    # that occurs during tied-weight deduplication on large FP32 models.
+    # The exported model still works correctly — the tied weights just
+    # aren't deduplicated in the graph, which is fine since we quantize next.
     main_export(
         model_name_or_path=merged_dir,
         output=output_dir,
         task="text-generation-with-past",
         device="cpu",
         fp16=False,
+        no_post_process=True,
     )
 
     # Clean up temp merged dir
     shutil.rmtree(merged_dir, ignore_errors=True)
 
 
+def _find_onnx_model(onnx_dir: str) -> Path:
+    """Find the main ONNX model file in the output directory."""
+    model_path = Path(onnx_dir) / "model.onnx"
+    if model_path.exists():
+        return model_path
+    candidates = list(Path(onnx_dir).glob("*.onnx"))
+    if not candidates:
+        raise FileNotFoundError(f"No .onnx files found in {onnx_dir}")
+    return candidates[0]
+
+
 def quantize_onnx(onnx_dir: str, quantize_type: str):
-    """Quantize the exported ONNX model."""
+    """Quantize the exported ONNX model using optimum-cli for memory safety."""
     if quantize_type == "none":
         print("\nSkipping quantization (FP32).")
         return
 
-    model_path = Path(onnx_dir) / "model.onnx"
-    if not model_path.exists():
-        # Optimum may produce decoder_model.onnx for text-generation-with-past
-        candidates = list(Path(onnx_dir).glob("*.onnx"))
-        if not candidates:
-            print("  WARNING: No .onnx files found to quantize.")
-            return
-        model_path = candidates[0]
-
+    model_path = _find_onnx_model(onnx_dir)
     print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")
 
-    if quantize_type == "q4":
-        try:
-            from onnxruntime.quantization import matmul_nbits_quantizer
-            quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
-                model=str(model_path),
-                block_size=32,
-                is_symmetric=True,
-                bits=4,
-            )
-            quant.process()
-            q_path = model_path.with_name(
-                model_path.stem + "_q4" + model_path.suffix,
-            )
-            quant.model.save(str(q_path))
-            size_mb = q_path.stat().st_size / (1024 * 1024)
-            print(f"  Q4: {size_mb:.1f} MB -> {q_path.name}")
-        except ImportError:
-            print("  WARNING: onnxruntime quantization not available, trying alternative...")
-            _quantize_dynamic(model_path, quantize_type)
+    q_dir = Path(onnx_dir) / f"quantized_{quantize_type}"
+    q_dir.mkdir(exist_ok=True)
 
+    if quantize_type == "q4":
+        # Use optimum-cli which handles external data format and memory properly.
+        # Direct MatMulNBitsQuantizer can OOM on large FP32 models.
+        cmd = [
+            sys.executable, "-m", "optimum.commands.optimum_cli",
+            "onnxruntime", "quantize",
+            "--onnx_model", onnx_dir,
+            "--o", str(q_dir),
+            "--avx2",
+        ]
+        print(f"  Running: optimum-cli onnxruntime quantize ...")
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        if result.returncode != 0:
+            print(f"  optimum-cli failed, falling back to MatMulNBitsQuantizer...")
+            _quantize_q4_direct(model_path, q_dir)
     elif quantize_type == "q8":
-        _quantize_dynamic(model_path, quantize_type)
-
+        cmd = [
+            sys.executable, "-m", "optimum.commands.optimum_cli",
+            "onnxruntime", "quantize",
+            "--onnx_model", onnx_dir,
+            "--o", str(q_dir),
+            "--avx2",
+        ]
+        result = subprocess.run(cmd, capture_output=True, text=True)
+        if result.returncode != 0:
+            print(f"  optimum-cli quantize failed: {result.stderr[:300]}")
     elif quantize_type == "fp16":
-        _convert_fp16(model_path)
-
-
-def _quantize_dynamic(model_path: Path, qtype: str):
-    """Dynamic quantization fallback."""
-    from onnxruntime.quantization import quantize_dynamic, QuantType
-
-    weight_type = QuantType.QUInt8 if qtype == "q8" else QuantType.QInt8
-    q_path = model_path.with_name(
-        model_path.stem + f"_{qtype}" + model_path.suffix,
+        _convert_fp16_external(model_path, q_dir)
+
+    # Report sizes
+    for f in sorted(q_dir.glob("*.onnx*")):
+        size_mb = f.stat().st_size / (1024 * 1024)
+        if size_mb > 1:
+            print(f"  {f.name}: {size_mb:.1f} MB")
+
+    # Move quantized files back into main dir, replacing originals
+    for f in q_dir.iterdir():
+        dest = Path(onnx_dir) / f.name
+        if dest.exists():
+            dest.unlink()
+        shutil.move(str(f), str(dest))
+    q_dir.rmdir()
+
+
+def _quantize_q4_direct(model_path: Path, output_dir: Path):
+    """Direct Q4 quantization fallback (may need significant RAM)."""
+    from onnxruntime.quantization import matmul_nbits_quantizer
+    quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
+        model=str(model_path),
+        block_size=32,
+        is_symmetric=True,
+        bits=4,
     )
-    quantize_dynamic(
-        model_input=str(model_path),
-        model_output=str(q_path),
-        weight_type=weight_type,
+    quant.process()
+    q_path = output_dir / model_path.name
+    quant.model.save(str(q_path))
+
+
+def _convert_fp16_external(model_path: Path, output_dir: Path):
+    """Convert ONNX model to FP16 using onnxconverter-common (handles external data)."""
+    print("  Converting to FP16 (external data format)...")
+    try:
+        from onnxconverter_common import float16
+        import onnx
+        model = onnx.load(str(model_path), load_external_data=True)
+        model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
+        fp16_path = output_dir / model_path.name
+        onnx.save(model_fp16, str(fp16_path))
+    except ImportError:
+        print("  onnxconverter-common not available; skipping FP16 conversion.")
+
+
+def validate_onnx(onnx_dir: str, base_model: str):
+    """Run a sample inference through the ONNX model to verify it works."""
+    import onnxruntime as ort
+    import numpy as np
+
+    model_path = _find_onnx_model(onnx_dir)
+    print(f"\nValidation: loading {model_path.name}...")
+
+    tokenizer = AutoTokenizer.from_pretrained(onnx_dir, trust_remote_code=True)
+    session = ort.InferenceSession(
+        str(model_path),
+        providers=["CPUExecutionProvider"],
     )
-    size_mb = q_path.stat().st_size / (1024 * 1024)
-    print(f"  {qtype.upper()}: {size_mb:.1f} MB -> {q_path.name}")
-
-
-def _convert_fp16(model_path: Path):
-    """Convert ONNX model to FP16."""
-    import onnx
-    from onnx import numpy_helper
-
-    print("  Converting to FP16...")
-    model = onnx.load(str(model_path))
-    for initializer in model.graph.initializer:
-        if initializer.data_type == onnx.TensorProto.FLOAT:
-            np_data = numpy_helper.to_array(initializer)
-            initializer.CopyFrom(
-                numpy_helper.from_array(np_data.astype("float16"), initializer.name),
-            )
-    fp16_path = model_path.with_name(
-        model_path.stem + "_fp16" + model_path.suffix,
+
+    # Tokenize a test prompt
+    test_query = "/no_think Expand this search query: distributed consensus"
+    chat_prompt = tokenizer.apply_chat_template(
+        [{"role": "user", "content": test_query}],
+        add_generation_prompt=True,
+        tokenize=False,
     )
-    onnx.save(model, str(fp16_path))
-    size_mb = fp16_path.stat().st_size / (1024 * 1024)
-    print(f"  FP16: {size_mb:.1f} MB -> {fp16_path.name}")
+    inputs = tokenizer(chat_prompt, return_tensors="np")
+    input_ids = inputs["input_ids"].astype(np.int64)
+    attention_mask = inputs["attention_mask"].astype(np.int64)
+
+    # Build feed dict with all required inputs
+    seq_len = input_ids.shape[1]
+    feed = {"input_ids": input_ids, "attention_mask": attention_mask}
+
+    # Add position_ids if needed
+    all_inputs = {inp.name: inp for inp in session.get_inputs()}
+    if "position_ids" in all_inputs:
+        feed["position_ids"] = np.arange(seq_len, dtype=np.int64).reshape(1, -1)
+
+    # Initialize past_key_values to zeros if the model expects them
+    for name, inp in sorted(all_inputs.items()):
+        if name.startswith("past_key_values"):
+            shape = []
+            for dim in inp.shape:
+                shape.append(dim if isinstance(dim, int) else 0)
+            # batch dim = 1
+            if shape and shape[0] == 0:
+                shape[0] = 1
+            feed[name] = np.zeros(shape, dtype=np.float32)
+
+    # Run inference
+    output_names = [o.name for o in session.get_outputs()]
+    results = session.run(output_names, feed)
+
+    # Check logits shape
+    logits = results[0]
+    print(f"  Input tokens: {input_ids.shape[1]}")
+    print(f"  Output logits shape: {logits.shape}")
+    print(f"  Logits range: [{logits.min():.2f}, {logits.max():.2f}]")
+
+    # Greedy decode next token
+    next_token_id = int(np.argmax(logits[0, -1, :]))
+    next_token = tokenizer.decode([next_token_id])
+    print(f"  Next token: {repr(next_token)} (id={next_token_id})")
+
+    # Check KV cache outputs exist
+    kv_outputs = [n for n in output_names if n.startswith("present")]
+    if kv_outputs:
+        print(f"  KV cache outputs: {len(kv_outputs)} tensors (generation-ready)")
+    else:
+        print("  WARNING: No KV cache outputs — model may not support efficient generation")
+
+    # Sanity checks
+    assert logits.shape[0] == 1, "Batch size mismatch"
+    assert logits.shape[1] == input_ids.shape[1], "Sequence length mismatch"
+    assert logits.max() > logits.min(), "Logits are constant (broken model)"
+    assert not np.isnan(logits).any(), "Logits contain NaN"
+    assert not np.isinf(logits).any(), "Logits contain Inf"
+
+    print("  Validation PASSED")
 
 
 def write_transformers_js_config(onnx_dir: str):
@@ -278,8 +373,21 @@ def main():
     parser.add_argument(
         "--no-upload", action="store_true", help="Don't upload to HF Hub",
     )
+    parser.add_argument(
+        "--validate", action="store_true",
+        help="Run inference validation on exported model",
+    )
+    parser.add_argument(
+        "--validate-only", metavar="DIR",
+        help="Skip export, only validate an existing ONNX dir",
+    )
     args = parser.parse_args()
 
+    # Validate-only mode: skip export, just run validation
+    if args.validate_only:
+        validate_onnx(args.validate_only, "")
+        return
+
     # Resolve config
     if args.size:
         preset = PRESETS[args.size]
@@ -321,6 +429,10 @@ def main():
     # Write Transformers.js config
     write_transformers_js_config(onnx_dir)
 
+    # Validate
+    if args.validate:
+        validate_onnx(onnx_dir, base_model)
+
     # Upload
     if not args.no_upload:
         upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model)