Add QuantizeEmbeddingInt8 and ShareEmbeddingLmHead graph surgeries for INT8 embedding quantization#2464
Add QuantizeEmbeddingInt8 and ShareEmbeddingLmHead graph surgeries for INT8 embedding quantization#2464apsonawane wants to merge 8 commits into
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds two new ONNX graph surgeries to enable post-hoc INT8 embedding quantization and embedding/lm_head weight sharing (to reduce model size for large-vocab LLMs), and updates the lm-eval ORT evaluator + IO utilities to better support hybrid attention architectures and pruned/non-contiguous KV-cache indices.
Changes:
- Add
QuantizeEmbeddingInt8(FP16/FP32Gather→ INT8GatherBlockQuantized) andShareEmbeddingLmHead(reuse embedding quantization params/weights for INT8MatMulNBits) graph surgeries. - Improve
lmeval_ortruntime IO binding to support 3Dposition_ids(mRoPE) and hybrid state tensors (conv_state/recurrent_state). - Fix KV-cache layer index detection for non-contiguous layer indices and make LM-eval metric parsing more robust to varied key formats/values.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
olive/passes/onnx/graph_surgeries.py |
Adds two new embedding-focused graph surgeries and helper functions. |
olive/passes/onnx/model_builder.py |
Removes a debug message about ignored tied-embedding flags in embedding construction. |
olive/evaluator/lmeval_ort.py |
Adds support for mRoPE position_ids rank detection and hybrid state IO binding/buffers. |
olive/evaluator/olive_evaluator.py |
Tightens parsing of lm-eval metric outputs (skip aliases/non-numeric, handle comma keys). |
olive/common/onnx_io.py |
Detects actual KV-cache layer indices from input names (supports non-contiguous indices). |
test/passes/onnx/test_quantize_embedding.py |
Adds unit tests covering the new embedding surgeries. |
Comments suppressed due to low confidence (1)
test/passes/onnx/test_quantize_embedding.py:176
old_init_namesis assigned but never used, which will fail linting (ruff F841). Remove the variable or assert on it (e.g., compare old vs new initializers) so the assignment is meaningful.
old_init_names = {init.name for init in model.graph.initializer}
| # find the actual layer indices (may be non-contiguous after pruning) | ||
| layer_indices = [] | ||
| for i_name in io_config["input_names"]: | ||
| num_layers += int(re.match(kv_format, i_name) is not None) | ||
| m = re.match(kv_format, i_name) | ||
| if m: | ||
| idx = int(m.group(1)) | ||
| if idx not in layer_indices: | ||
| layer_indices.append(idx) | ||
| layer_indices.sort() |
There was a problem hiding this comment.
nit: declare layer_indicies as a set and convert to list with sorted after iteration.
| if "position_ids" in self.io_config["input_names"]: | ||
| idx = self.io_config["input_names"].index("position_ids") | ||
| self.position_ids_rank = len(self.io_config["input_shapes"][idx]) |
There was a problem hiding this comment.
You could merge this condition with the loop below. That would avoid multiple iterations thru' the list.
| self.hybrid_states = {} | ||
| for idx, name in enumerate(self.io_config["input_names"]): | ||
| if "conv_state" in name or "recurrent_state" in name: | ||
| shape = self.io_config["input_shapes"][idx] | ||
| dtype = self.io_config["input_types"][idx] | ||
| self.hybrid_states[name] = {"shape": shape, "dtype": dtype} | ||
|
|
||
| # detect hybrid state outputs | ||
| self.hybrid_state_outputs = {} | ||
| for idx, name in enumerate(self.io_config["output_names"]): | ||
| if "conv_state" in name or "recurrent_state" in name: | ||
| shape = self.io_config["output_shapes"][idx] | ||
| dtype = self.io_config["output_types"][idx] | ||
| self.hybrid_state_outputs[name] = {"shape": shape, "dtype": dtype} | ||
|
|
There was a problem hiding this comment.
These loops can be merged into one!
| def _find_embed_node(model, op_type, label): | ||
| """Find the embed_tokens node of the given op_type and its index.""" | ||
| for i, node in enumerate(model.graph.node): | ||
| if node.op_type == op_type and "embed_tokens" in node.name: | ||
| return node, i | ||
| logger.warning("No embed_tokens %s node found, skipping %s", op_type, label) | ||
| return None, None | ||
|
|
||
|
|
||
| def _find_lm_head_node(model): | ||
| """Find the lm_head MatMulNBits node and its index.""" | ||
| for i, node in enumerate(model.graph.node): | ||
| if node.op_type == "MatMulNBits" and "lm_head" in node.name: | ||
| return node, i | ||
| logger.warning("No lm_head MatMulNBits found") | ||
| return None, None | ||
|
|
||
|
|
||
| def _find_initializer(model, name): | ||
| """Find an initializer by name.""" | ||
| for init in model.graph.initializer: | ||
| if init.name == name: | ||
| return init | ||
| return None | ||
|
|
||
|
|
||
| def _get_node_attrs(node, *attr_names): | ||
| """Extract integer attributes from a node by name.""" | ||
| result = {} | ||
| for attr in node.attribute: | ||
| if attr.name in attr_names: | ||
| result[attr.name] = attr.i | ||
| return result | ||
|
|
||
|
|
||
| def _ensure_msft_opset(model): | ||
| """Ensure com.microsoft opset import is present in the model.""" | ||
| for opset in model.opset_import: | ||
| if opset.domain == "com.microsoft": | ||
| return | ||
| model.opset_import.append(onnx.helper.make_opsetid("com.microsoft", 1)) |
| model.graph.initializer.append(numpy_helper.from_array(q_flat, name=qweight_name)) | ||
| model.graph.initializer.append(numpy_helper.from_array(scales, name=scales_name)) | ||
| model.graph.initializer.append(numpy_helper.from_array(zero_points, name=zp_name)) |
Add QuantizeEmbeddingInt8 and ShareEmbeddingLmHead graph surgeries
Summary
Adds two new graph surgeries for post-hoc INT8 embedding quantization and weight sharing, along with evaluator fixes for hybrid attention architectures (e.g., Qwen3.5-2B with GatedDeltaNet + standard attention).
Motivation
Models with large vocabularies (e.g., Qwen3.5-2B with 248K tokens) have FP16 embeddings that dominate model size (~970 MB out of 2.0 GB for INT4 weights). The ModelBuilder's default quantizer (Neural Compressor) only quantizes
MatMulops, leavingGather(embedding) as FP16. RTN-based quantizers that support INT8 embedding natively (k_quant_last) destroy accuracy on hybrid architectures (26% vs 59% MMLU).Changes
New Graph Surgeries (
graph_surgeries.py)QuantizeEmbeddingInt8: Converts FP16Gatherembedding to INT8GatherBlockQuantizedwith per-block asymmetric quantization (zero_point=128, block_size=32). Reduces embedding from ~970 MB to ~530 MB with negligible accuracy loss.ShareEmbeddingLmHead: Replaces lm_head's INT4MatMulNBitswith INT8MatMulNBitssharing the embedding weight viaReshape, eliminating duplicate storage. Saves ~250 MB._find_embed_node,_find_lm_head_node,_find_initializer,_get_node_attrsEvaluator Fixes
lmeval_ort.py: Support for 3Dposition_ids(mRoPE) and hybridconv_state/recurrent_stateinputs for models with mixed attention + linear attention layersolive_evaluator.py: Fix metric parsing for lm-eval results with non-comma metric keys and non-numeric valuesonnx_io.py: Fix KV cache layer index detection for non-contiguous indices (e.g., attention at layers 3,7,11,15,19,23 only)Results (Qwen3.5-2B)
Testing
test/passes/onnx/test_quantize_embedding.py