Skip to content

Add QuantizeEmbeddingInt8 and ShareEmbeddingLmHead graph surgeries for INT8 embedding quantization#2464

Open
apsonawane wants to merge 8 commits into
mainfrom
asonawane/tieword
Open

Add QuantizeEmbeddingInt8 and ShareEmbeddingLmHead graph surgeries for INT8 embedding quantization#2464
apsonawane wants to merge 8 commits into
mainfrom
asonawane/tieword

Conversation

@apsonawane
Copy link
Copy Markdown
Contributor

Add QuantizeEmbeddingInt8 and ShareEmbeddingLmHead graph surgeries

Summary

Adds two new graph surgeries for post-hoc INT8 embedding quantization and weight sharing, along with evaluator fixes for hybrid attention architectures (e.g., Qwen3.5-2B with GatedDeltaNet + standard attention).

Motivation

Models with large vocabularies (e.g., Qwen3.5-2B with 248K tokens) have FP16 embeddings that dominate model size (~970 MB out of 2.0 GB for INT4 weights). The ModelBuilder's default quantizer (Neural Compressor) only quantizes MatMul ops, leaving Gather (embedding) as FP16. RTN-based quantizers that support INT8 embedding natively (k_quant_last) destroy accuracy on hybrid architectures (26% vs 59% MMLU).

Changes

New Graph Surgeries (graph_surgeries.py)

  • QuantizeEmbeddingInt8: Converts FP16 Gather embedding to INT8 GatherBlockQuantized with per-block asymmetric quantization (zero_point=128, block_size=32). Reduces embedding from ~970 MB to ~530 MB with negligible accuracy loss.
  • ShareEmbeddingLmHead: Replaces lm_head's INT4 MatMulNBits with INT8 MatMulNBits sharing the embedding weight via Reshape, eliminating duplicate storage. Saves ~250 MB.
  • Helper functions: _find_embed_node, _find_lm_head_node, _find_initializer, _get_node_attrs

Evaluator Fixes

  • lmeval_ort.py: Support for 3D position_ids (mRoPE) and hybrid conv_state/recurrent_state inputs for models with mixed attention + linear attention layers
  • olive_evaluator.py: Fix metric parsing for lm-eval results with non-comma metric keys and non-numeric values
  • onnx_io.py: Fix KV cache layer index detection for non-contiguous indices (e.g., attention at layers 3,7,11,15,19,23 only)

Results (Qwen3.5-2B)

Configuration Size MMLU Δ vs FP16
Baseline FP16 4.3 GB 59.27%
INT4 weights + FP16 embed 2.0 GB 57.21% -2.06%
INT4 weights + INT8 embed 1.6 GB 57.19% -2.08%
INT4 weights + shared INT8 embed/lm_head 1.4 GB 57.11% -2.16%

Testing

  • 7 unit tests added in test/passes/onnx/test_quantize_embedding.py
  • All tests pass
  • End-to-end validated via Olive recipe with MMLU evaluation

Copilot AI review requested due to automatic review settings May 14, 2026 00:17
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds two new ONNX graph surgeries to enable post-hoc INT8 embedding quantization and embedding/lm_head weight sharing (to reduce model size for large-vocab LLMs), and updates the lm-eval ORT evaluator + IO utilities to better support hybrid attention architectures and pruned/non-contiguous KV-cache indices.

Changes:

  • Add QuantizeEmbeddingInt8 (FP16/FP32 Gather → INT8 GatherBlockQuantized) and ShareEmbeddingLmHead (reuse embedding quantization params/weights for INT8 MatMulNBits) graph surgeries.
  • Improve lmeval_ort runtime IO binding to support 3D position_ids (mRoPE) and hybrid state tensors (conv_state / recurrent_state).
  • Fix KV-cache layer index detection for non-contiguous layer indices and make LM-eval metric parsing more robust to varied key formats/values.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
olive/passes/onnx/graph_surgeries.py Adds two new embedding-focused graph surgeries and helper functions.
olive/passes/onnx/model_builder.py Removes a debug message about ignored tied-embedding flags in embedding construction.
olive/evaluator/lmeval_ort.py Adds support for mRoPE position_ids rank detection and hybrid state IO binding/buffers.
olive/evaluator/olive_evaluator.py Tightens parsing of lm-eval metric outputs (skip aliases/non-numeric, handle comma keys).
olive/common/onnx_io.py Detects actual KV-cache layer indices from input names (supports non-contiguous indices).
test/passes/onnx/test_quantize_embedding.py Adds unit tests covering the new embedding surgeries.
Comments suppressed due to low confidence (1)

test/passes/onnx/test_quantize_embedding.py:176

  • old_init_names is assigned but never used, which will fail linting (ruff F841). Remove the variable or assert on it (e.g., compare old vs new initializers) so the assignment is meaningful.

        old_init_names = {init.name for init in model.graph.initializer}

Comment thread test/passes/onnx/test_quantize_embedding.py Outdated
Comment thread olive/passes/onnx/graph_surgeries.py
Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment thread olive/passes/onnx/graph_surgeries.py Fixed
Comment thread olive/passes/onnx/graph_surgeries.py Fixed
Comment thread olive/passes/onnx/graph_surgeries.py Fixed
Comment thread olive/passes/onnx/graph_surgeries.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread test/passes/onnx/test_quantize_embedding.py Fixed
Comment thread olive/common/onnx_io.py Outdated
Comment on lines +92 to +100
# find the actual layer indices (may be non-contiguous after pruning)
layer_indices = []
for i_name in io_config["input_names"]:
num_layers += int(re.match(kv_format, i_name) is not None)
m = re.match(kv_format, i_name)
if m:
idx = int(m.group(1))
if idx not in layer_indices:
layer_indices.append(idx)
layer_indices.sort()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: declare layer_indicies as a set and convert to list with sorted after iteration.

Comment thread olive/evaluator/lmeval_ort.py Outdated
Comment on lines +298 to +300
if "position_ids" in self.io_config["input_names"]:
idx = self.io_config["input_names"].index("position_ids")
self.position_ids_rank = len(self.io_config["input_shapes"][idx])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could merge this condition with the loop below. That would avoid multiple iterations thru' the list.

Comment on lines +303 to +317
self.hybrid_states = {}
for idx, name in enumerate(self.io_config["input_names"]):
if "conv_state" in name or "recurrent_state" in name:
shape = self.io_config["input_shapes"][idx]
dtype = self.io_config["input_types"][idx]
self.hybrid_states[name] = {"shape": shape, "dtype": dtype}

# detect hybrid state outputs
self.hybrid_state_outputs = {}
for idx, name in enumerate(self.io_config["output_names"]):
if "conv_state" in name or "recurrent_state" in name:
shape = self.io_config["output_shapes"][idx]
dtype = self.io_config["output_types"][idx]
self.hybrid_state_outputs[name] = {"shape": shape, "dtype": dtype}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These loops can be merged into one!

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment on lines +2374 to +2414
def _find_embed_node(model, op_type, label):
"""Find the embed_tokens node of the given op_type and its index."""
for i, node in enumerate(model.graph.node):
if node.op_type == op_type and "embed_tokens" in node.name:
return node, i
logger.warning("No embed_tokens %s node found, skipping %s", op_type, label)
return None, None


def _find_lm_head_node(model):
"""Find the lm_head MatMulNBits node and its index."""
for i, node in enumerate(model.graph.node):
if node.op_type == "MatMulNBits" and "lm_head" in node.name:
return node, i
logger.warning("No lm_head MatMulNBits found")
return None, None


def _find_initializer(model, name):
"""Find an initializer by name."""
for init in model.graph.initializer:
if init.name == name:
return init
return None


def _get_node_attrs(node, *attr_names):
"""Extract integer attributes from a node by name."""
result = {}
for attr in node.attribute:
if attr.name in attr_names:
result[attr.name] = attr.i
return result


def _ensure_msft_opset(model):
"""Ensure com.microsoft opset import is present in the model."""
for opset in model.opset_import:
if opset.domain == "com.microsoft":
return
model.opset_import.append(onnx.helper.make_opsetid("com.microsoft", 1))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use OnnxDAG instead.

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment on lines +2489 to +2491
model.graph.initializer.append(numpy_helper.from_array(q_flat, name=qweight_name))
model.graph.initializer.append(numpy_helper.from_array(scales, name=scales_name))
model.graph.initializer.append(numpy_helper.from_array(zero_points, name=zp_name))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use OnnxDAG to manmipulate the graph.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or onnx-ir

@apsonawane apsonawane requested review from shaahji and xiaoyu-work May 20, 2026 22:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants