|
@@ -10,7 +10,7 @@
|
|
|
# "sentencepiece>=0.1.99",
|
|
# "sentencepiece>=0.1.99",
|
|
|
# "protobuf>=3.20.0",
|
|
# "protobuf>=3.20.0",
|
|
|
# "numpy",
|
|
# "numpy",
|
|
|
-# "optimum>=1.17.0",
|
|
|
|
|
|
|
+# "optimum[onnxruntime]",
|
|
|
# "onnx>=1.15.0",
|
|
# "onnx>=1.15.0",
|
|
|
# "onnxruntime>=1.17.0",
|
|
# "onnxruntime>=1.17.0",
|
|
|
# ]
|
|
# ]
|
|
@@ -68,7 +68,7 @@ def merge_adapters(base_model: str, sft_model: str, grpo_model: str) -> tuple:
|
|
|
"""Load base model, merge SFT + GRPO adapters, return (model, tokenizer)."""
|
|
"""Load base model, merge SFT + GRPO adapters, return (model, tokenizer)."""
|
|
|
print(f"\nStep 1: Loading base model {base_model}...")
|
|
print(f"\nStep 1: Loading base model {base_model}...")
|
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
|
- base_model, torch_dtype=torch.float32, trust_remote_code=True,
|
|
|
|
|
|
|
+ base_model, dtype=torch.float32, trust_remote_code=True,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
print(f"Step 2: Merging SFT adapter {sft_model}...")
|
|
print(f"Step 2: Merging SFT adapter {sft_model}...")
|
|
@@ -94,98 +94,193 @@ def export_onnx(model, tokenizer, output_dir: str):
|
|
|
tokenizer.save_pretrained(merged_dir)
|
|
tokenizer.save_pretrained(merged_dir)
|
|
|
|
|
|
|
|
print(f"\nStep 5: Exporting to ONNX at {output_dir}...")
|
|
print(f"\nStep 5: Exporting to ONNX at {output_dir}...")
|
|
|
|
|
+ # no_post_process=True avoids the 2GB protobuf serialization limit
|
|
|
|
|
+ # that occurs during tied-weight deduplication on large FP32 models.
|
|
|
|
|
+ # The exported model still works correctly — the tied weights just
|
|
|
|
|
+ # aren't deduplicated in the graph, which is fine since we quantize next.
|
|
|
main_export(
|
|
main_export(
|
|
|
model_name_or_path=merged_dir,
|
|
model_name_or_path=merged_dir,
|
|
|
output=output_dir,
|
|
output=output_dir,
|
|
|
task="text-generation-with-past",
|
|
task="text-generation-with-past",
|
|
|
device="cpu",
|
|
device="cpu",
|
|
|
fp16=False,
|
|
fp16=False,
|
|
|
|
|
+ no_post_process=True,
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
# Clean up temp merged dir
|
|
# Clean up temp merged dir
|
|
|
shutil.rmtree(merged_dir, ignore_errors=True)
|
|
shutil.rmtree(merged_dir, ignore_errors=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+def _find_onnx_model(onnx_dir: str) -> Path:
|
|
|
|
|
+ """Find the main ONNX model file in the output directory."""
|
|
|
|
|
+ model_path = Path(onnx_dir) / "model.onnx"
|
|
|
|
|
+ if model_path.exists():
|
|
|
|
|
+ return model_path
|
|
|
|
|
+ candidates = list(Path(onnx_dir).glob("*.onnx"))
|
|
|
|
|
+ if not candidates:
|
|
|
|
|
+ raise FileNotFoundError(f"No .onnx files found in {onnx_dir}")
|
|
|
|
|
+ return candidates[0]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def quantize_onnx(onnx_dir: str, quantize_type: str):
|
|
def quantize_onnx(onnx_dir: str, quantize_type: str):
|
|
|
- """Quantize the exported ONNX model."""
|
|
|
|
|
|
|
+ """Quantize the exported ONNX model using optimum-cli for memory safety."""
|
|
|
if quantize_type == "none":
|
|
if quantize_type == "none":
|
|
|
print("\nSkipping quantization (FP32).")
|
|
print("\nSkipping quantization (FP32).")
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
- model_path = Path(onnx_dir) / "model.onnx"
|
|
|
|
|
- if not model_path.exists():
|
|
|
|
|
- # Optimum may produce decoder_model.onnx for text-generation-with-past
|
|
|
|
|
- candidates = list(Path(onnx_dir).glob("*.onnx"))
|
|
|
|
|
- if not candidates:
|
|
|
|
|
- print(" WARNING: No .onnx files found to quantize.")
|
|
|
|
|
- return
|
|
|
|
|
- model_path = candidates[0]
|
|
|
|
|
-
|
|
|
|
|
|
|
+ model_path = _find_onnx_model(onnx_dir)
|
|
|
print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")
|
|
print(f"\nStep 6: Quantizing {model_path.name} ({quantize_type})...")
|
|
|
|
|
|
|
|
- if quantize_type == "q4":
|
|
|
|
|
- try:
|
|
|
|
|
- from onnxruntime.quantization import matmul_nbits_quantizer
|
|
|
|
|
- quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
|
|
|
|
|
- model=str(model_path),
|
|
|
|
|
- block_size=32,
|
|
|
|
|
- is_symmetric=True,
|
|
|
|
|
- bits=4,
|
|
|
|
|
- )
|
|
|
|
|
- quant.process()
|
|
|
|
|
- q_path = model_path.with_name(
|
|
|
|
|
- model_path.stem + "_q4" + model_path.suffix,
|
|
|
|
|
- )
|
|
|
|
|
- quant.model.save(str(q_path))
|
|
|
|
|
- size_mb = q_path.stat().st_size / (1024 * 1024)
|
|
|
|
|
- print(f" Q4: {size_mb:.1f} MB -> {q_path.name}")
|
|
|
|
|
- except ImportError:
|
|
|
|
|
- print(" WARNING: onnxruntime quantization not available, trying alternative...")
|
|
|
|
|
- _quantize_dynamic(model_path, quantize_type)
|
|
|
|
|
|
|
+ q_dir = Path(onnx_dir) / f"quantized_{quantize_type}"
|
|
|
|
|
+ q_dir.mkdir(exist_ok=True)
|
|
|
|
|
|
|
|
|
|
+ if quantize_type == "q4":
|
|
|
|
|
+ # Use optimum-cli which handles external data format and memory properly.
|
|
|
|
|
+ # Direct MatMulNBitsQuantizer can OOM on large FP32 models.
|
|
|
|
|
+ cmd = [
|
|
|
|
|
+ sys.executable, "-m", "optimum.commands.optimum_cli",
|
|
|
|
|
+ "onnxruntime", "quantize",
|
|
|
|
|
+ "--onnx_model", onnx_dir,
|
|
|
|
|
+ "--o", str(q_dir),
|
|
|
|
|
+ "--avx2",
|
|
|
|
|
+ ]
|
|
|
|
|
+ print(f" Running: optimum-cli onnxruntime quantize ...")
|
|
|
|
|
+ result = subprocess.run(cmd, capture_output=True, text=True)
|
|
|
|
|
+ if result.returncode != 0:
|
|
|
|
|
+ print(f" optimum-cli failed, falling back to MatMulNBitsQuantizer...")
|
|
|
|
|
+ _quantize_q4_direct(model_path, q_dir)
|
|
|
elif quantize_type == "q8":
|
|
elif quantize_type == "q8":
|
|
|
- _quantize_dynamic(model_path, quantize_type)
|
|
|
|
|
-
|
|
|
|
|
|
|
+ cmd = [
|
|
|
|
|
+ sys.executable, "-m", "optimum.commands.optimum_cli",
|
|
|
|
|
+ "onnxruntime", "quantize",
|
|
|
|
|
+ "--onnx_model", onnx_dir,
|
|
|
|
|
+ "--o", str(q_dir),
|
|
|
|
|
+ "--avx2",
|
|
|
|
|
+ ]
|
|
|
|
|
+ result = subprocess.run(cmd, capture_output=True, text=True)
|
|
|
|
|
+ if result.returncode != 0:
|
|
|
|
|
+ print(f" optimum-cli quantize failed: {result.stderr[:300]}")
|
|
|
elif quantize_type == "fp16":
|
|
elif quantize_type == "fp16":
|
|
|
- _convert_fp16(model_path)
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def _quantize_dynamic(model_path: Path, qtype: str):
|
|
|
|
|
- """Dynamic quantization fallback."""
|
|
|
|
|
- from onnxruntime.quantization import quantize_dynamic, QuantType
|
|
|
|
|
-
|
|
|
|
|
- weight_type = QuantType.QUInt8 if qtype == "q8" else QuantType.QInt8
|
|
|
|
|
- q_path = model_path.with_name(
|
|
|
|
|
- model_path.stem + f"_{qtype}" + model_path.suffix,
|
|
|
|
|
|
|
+ _convert_fp16_external(model_path, q_dir)
|
|
|
|
|
+
|
|
|
|
|
+ # Report sizes
|
|
|
|
|
+ for f in sorted(q_dir.glob("*.onnx*")):
|
|
|
|
|
+ size_mb = f.stat().st_size / (1024 * 1024)
|
|
|
|
|
+ if size_mb > 1:
|
|
|
|
|
+ print(f" {f.name}: {size_mb:.1f} MB")
|
|
|
|
|
+
|
|
|
|
|
+ # Move quantized files back into main dir, replacing originals
|
|
|
|
|
+ for f in q_dir.iterdir():
|
|
|
|
|
+ dest = Path(onnx_dir) / f.name
|
|
|
|
|
+ if dest.exists():
|
|
|
|
|
+ dest.unlink()
|
|
|
|
|
+ shutil.move(str(f), str(dest))
|
|
|
|
|
+ q_dir.rmdir()
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _quantize_q4_direct(model_path: Path, output_dir: Path):
|
|
|
|
|
+ """Direct Q4 quantization fallback (may need significant RAM)."""
|
|
|
|
|
+ from onnxruntime.quantization import matmul_nbits_quantizer
|
|
|
|
|
+ quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
|
|
|
|
|
+ model=str(model_path),
|
|
|
|
|
+ block_size=32,
|
|
|
|
|
+ is_symmetric=True,
|
|
|
|
|
+ bits=4,
|
|
|
)
|
|
)
|
|
|
- quantize_dynamic(
|
|
|
|
|
- model_input=str(model_path),
|
|
|
|
|
- model_output=str(q_path),
|
|
|
|
|
- weight_type=weight_type,
|
|
|
|
|
|
|
+ quant.process()
|
|
|
|
|
+ q_path = output_dir / model_path.name
|
|
|
|
|
+ quant.model.save(str(q_path))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def _convert_fp16_external(model_path: Path, output_dir: Path):
|
|
|
|
|
+ """Convert ONNX model to FP16 using onnxconverter-common (handles external data)."""
|
|
|
|
|
+ print(" Converting to FP16 (external data format)...")
|
|
|
|
|
+ try:
|
|
|
|
|
+ from onnxconverter_common import float16
|
|
|
|
|
+ import onnx
|
|
|
|
|
+ model = onnx.load(str(model_path), load_external_data=True)
|
|
|
|
|
+ model_fp16 = float16.convert_float_to_float16(model, keep_io_types=True)
|
|
|
|
|
+ fp16_path = output_dir / model_path.name
|
|
|
|
|
+ onnx.save(model_fp16, str(fp16_path))
|
|
|
|
|
+ except ImportError:
|
|
|
|
|
+ print(" onnxconverter-common not available; skipping FP16 conversion.")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def validate_onnx(onnx_dir: str, base_model: str):
|
|
|
|
|
+ """Run a sample inference through the ONNX model to verify it works."""
|
|
|
|
|
+ import onnxruntime as ort
|
|
|
|
|
+ import numpy as np
|
|
|
|
|
+
|
|
|
|
|
+ model_path = _find_onnx_model(onnx_dir)
|
|
|
|
|
+ print(f"\nValidation: loading {model_path.name}...")
|
|
|
|
|
+
|
|
|
|
|
+ tokenizer = AutoTokenizer.from_pretrained(onnx_dir, trust_remote_code=True)
|
|
|
|
|
+ session = ort.InferenceSession(
|
|
|
|
|
+ str(model_path),
|
|
|
|
|
+ providers=["CPUExecutionProvider"],
|
|
|
)
|
|
)
|
|
|
- size_mb = q_path.stat().st_size / (1024 * 1024)
|
|
|
|
|
- print(f" {qtype.upper()}: {size_mb:.1f} MB -> {q_path.name}")
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def _convert_fp16(model_path: Path):
|
|
|
|
|
- """Convert ONNX model to FP16."""
|
|
|
|
|
- import onnx
|
|
|
|
|
- from onnx import numpy_helper
|
|
|
|
|
-
|
|
|
|
|
- print(" Converting to FP16...")
|
|
|
|
|
- model = onnx.load(str(model_path))
|
|
|
|
|
- for initializer in model.graph.initializer:
|
|
|
|
|
- if initializer.data_type == onnx.TensorProto.FLOAT:
|
|
|
|
|
- np_data = numpy_helper.to_array(initializer)
|
|
|
|
|
- initializer.CopyFrom(
|
|
|
|
|
- numpy_helper.from_array(np_data.astype("float16"), initializer.name),
|
|
|
|
|
- )
|
|
|
|
|
- fp16_path = model_path.with_name(
|
|
|
|
|
- model_path.stem + "_fp16" + model_path.suffix,
|
|
|
|
|
|
|
+
|
|
|
|
|
+ # Tokenize a test prompt
|
|
|
|
|
+ test_query = "/no_think Expand this search query: distributed consensus"
|
|
|
|
|
+ chat_prompt = tokenizer.apply_chat_template(
|
|
|
|
|
+ [{"role": "user", "content": test_query}],
|
|
|
|
|
+ add_generation_prompt=True,
|
|
|
|
|
+ tokenize=False,
|
|
|
)
|
|
)
|
|
|
- onnx.save(model, str(fp16_path))
|
|
|
|
|
- size_mb = fp16_path.stat().st_size / (1024 * 1024)
|
|
|
|
|
- print(f" FP16: {size_mb:.1f} MB -> {fp16_path.name}")
|
|
|
|
|
|
|
+ inputs = tokenizer(chat_prompt, return_tensors="np")
|
|
|
|
|
+ input_ids = inputs["input_ids"].astype(np.int64)
|
|
|
|
|
+ attention_mask = inputs["attention_mask"].astype(np.int64)
|
|
|
|
|
+
|
|
|
|
|
+ # Build feed dict with all required inputs
|
|
|
|
|
+ seq_len = input_ids.shape[1]
|
|
|
|
|
+ feed = {"input_ids": input_ids, "attention_mask": attention_mask}
|
|
|
|
|
+
|
|
|
|
|
+ # Add position_ids if needed
|
|
|
|
|
+ all_inputs = {inp.name: inp for inp in session.get_inputs()}
|
|
|
|
|
+ if "position_ids" in all_inputs:
|
|
|
|
|
+ feed["position_ids"] = np.arange(seq_len, dtype=np.int64).reshape(1, -1)
|
|
|
|
|
+
|
|
|
|
|
+ # Initialize past_key_values to zeros if the model expects them
|
|
|
|
|
+ for name, inp in sorted(all_inputs.items()):
|
|
|
|
|
+ if name.startswith("past_key_values"):
|
|
|
|
|
+ shape = []
|
|
|
|
|
+ for dim in inp.shape:
|
|
|
|
|
+ shape.append(dim if isinstance(dim, int) else 0)
|
|
|
|
|
+ # batch dim = 1
|
|
|
|
|
+ if shape and shape[0] == 0:
|
|
|
|
|
+ shape[0] = 1
|
|
|
|
|
+ feed[name] = np.zeros(shape, dtype=np.float32)
|
|
|
|
|
+
|
|
|
|
|
+ # Run inference
|
|
|
|
|
+ output_names = [o.name for o in session.get_outputs()]
|
|
|
|
|
+ results = session.run(output_names, feed)
|
|
|
|
|
+
|
|
|
|
|
+ # Check logits shape
|
|
|
|
|
+ logits = results[0]
|
|
|
|
|
+ print(f" Input tokens: {input_ids.shape[1]}")
|
|
|
|
|
+ print(f" Output logits shape: {logits.shape}")
|
|
|
|
|
+ print(f" Logits range: [{logits.min():.2f}, {logits.max():.2f}]")
|
|
|
|
|
+
|
|
|
|
|
+ # Greedy decode next token
|
|
|
|
|
+ next_token_id = int(np.argmax(logits[0, -1, :]))
|
|
|
|
|
+ next_token = tokenizer.decode([next_token_id])
|
|
|
|
|
+ print(f" Next token: {repr(next_token)} (id={next_token_id})")
|
|
|
|
|
+
|
|
|
|
|
+ # Check KV cache outputs exist
|
|
|
|
|
+ kv_outputs = [n for n in output_names if n.startswith("present")]
|
|
|
|
|
+ if kv_outputs:
|
|
|
|
|
+ print(f" KV cache outputs: {len(kv_outputs)} tensors (generation-ready)")
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(" WARNING: No KV cache outputs — model may not support efficient generation")
|
|
|
|
|
+
|
|
|
|
|
+ # Sanity checks
|
|
|
|
|
+ assert logits.shape[0] == 1, "Batch size mismatch"
|
|
|
|
|
+ assert logits.shape[1] == input_ids.shape[1], "Sequence length mismatch"
|
|
|
|
|
+ assert logits.max() > logits.min(), "Logits are constant (broken model)"
|
|
|
|
|
+ assert not np.isnan(logits).any(), "Logits contain NaN"
|
|
|
|
|
+ assert not np.isinf(logits).any(), "Logits contain Inf"
|
|
|
|
|
+
|
|
|
|
|
+ print(" Validation PASSED")
|
|
|
|
|
|
|
|
|
|
|
|
|
def write_transformers_js_config(onnx_dir: str):
|
|
def write_transformers_js_config(onnx_dir: str):
|
|
@@ -278,8 +373,21 @@ def main():
|
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
|
"--no-upload", action="store_true", help="Don't upload to HF Hub",
|
|
"--no-upload", action="store_true", help="Don't upload to HF Hub",
|
|
|
)
|
|
)
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--validate", action="store_true",
|
|
|
|
|
+ help="Run inference validation on exported model",
|
|
|
|
|
+ )
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--validate-only", metavar="DIR",
|
|
|
|
|
+ help="Skip export, only validate an existing ONNX dir",
|
|
|
|
|
+ )
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
+ # Validate-only mode: skip export, just run validation
|
|
|
|
|
+ if args.validate_only:
|
|
|
|
|
+ validate_onnx(args.validate_only, "")
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
# Resolve config
|
|
# Resolve config
|
|
|
if args.size:
|
|
if args.size:
|
|
|
preset = PRESETS[args.size]
|
|
preset = PRESETS[args.size]
|
|
@@ -321,6 +429,10 @@ def main():
|
|
|
# Write Transformers.js config
|
|
# Write Transformers.js config
|
|
|
write_transformers_js_config(onnx_dir)
|
|
write_transformers_js_config(onnx_dir)
|
|
|
|
|
|
|
|
|
|
+ # Validate
|
|
|
|
|
+ if args.validate:
|
|
|
|
|
+ validate_onnx(onnx_dir, base_model)
|
|
|
|
|
+
|
|
|
# Upload
|
|
# Upload
|
|
|
if not args.no_upload:
|
|
if not args.no_upload:
|
|
|
upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model)
|
|
upload_to_hub(onnx_dir, output_repo, base_model, sft_model, grpo_model)
|