Skip to content

Commit 8597d21

Browse files
authored
[Frontend][ONNX] Add MatMulInteger support to Relax ONNX frontend (#18951)
### Summary Implements the `MatMulInteger` operator (opset 10) in the Relax ONNX frontend — INT8 matrix multiplication. Required for quantized model inference (e.g. ONNX QDQ models). Closes #18945 (Tier 1 — MatMulInteger operator) ### Tests - All 4 `int8`/`uint8` dtype combinations, with and without scalar zero points - 3-D and 4-D batched matmul --------- Signed-off-by: OmarAzizi <oalazizi75@gmail.com>
1 parent 4df6b17 commit 8597d21

2 files changed

Lines changed: 225 additions & 1 deletion

File tree

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4054,6 +4054,60 @@ def _impl_v16(cls, bb, inputs, attr, params):
40544054
)
40554055

40564056

4057+
class MatMulInteger(OnnxOpConverter):
    """
    Converts ONNX MatMulInteger (INT8/UINT8 quantized matrix multiply).

    Computes ``output = (A - a_zero_point) * (B - b_zero_point)`` with int32
    accumulation, per ONNX opset 10.

    Zero-point shapes per spec:
      * ``a_zero_point``: scalar | [M] (per-row) | [D1, D2, M, 1] (N-D per-row)
      * ``b_zero_point``: scalar | [N] (per-col) | [D1, D2, 1, N] (N-D per-col)
    """

    @classmethod
    def _impl_v10(cls, bb, inputs, attr, params):
        def _optional_input(idx):
            # Zero points are optional trailing inputs; a missing one acts as 0.
            return inputs[idx] if len(inputs) > idx and inputs[idx] is not None else None

        def _subtract_zero_point(operand, zero_point, unsqueeze_axis):
            # Cast the zero point to int32 so the subtraction happens in the
            # accumulation dtype, then normalize so struct_info is populated
            # before we inspect the rank.
            zp = bb.normalize(relax.op.astype(zero_point, "int32"))
            # A 1-D zero point needs an explicit broadcast axis; the spec's
            # N-D forms ([..., M, 1] / [..., 1, N]) already broadcast as-is.
            if len(zp.struct_info.shape) == 1:
                zp = relax.op.expand_dims(zp, axis=unsqueeze_axis)
            return relax.op.subtract(operand, zp)

        # Widen both operands to int32 before any arithmetic to prevent overflow.
        lhs = relax.op.astype(inputs[0], "int32")
        rhs = relax.op.astype(inputs[1], "int32")

        a_zero_point = _optional_input(2)
        if a_zero_point is not None:
            # Per-row case: [M] -> [M, 1] so it broadcasts over [M, K] row-wise.
            lhs = _subtract_zero_point(lhs, a_zero_point, -1)

        b_zero_point = _optional_input(3)
        if b_zero_point is not None:
            # Per-col case: [N] -> [1, N] so it broadcasts over [K, N] column-wise.
            rhs = _subtract_zero_point(rhs, b_zero_point, 0)

        # Output is int32 per the ONNX spec.
        return relax.op.matmul(lhs, rhs, out_dtype="int32")
40574111
def _get_convert_map():
40584112
return {
40594113
# defs/experimental
@@ -4129,7 +4183,7 @@ def _get_convert_map():
41294183
"Cast": Cast,
41304184
"Gemm": Gemm,
41314185
"MatMul": MatMul,
4132-
# "MatMulInteger": MatMulInteger,
4186+
"MatMulInteger": MatMulInteger,
41334187
# "MatMulInteger16": MatMulInteger16,
41344188
"Reshape": Reshape,
41354189
"Sigmoid": Sigmoid,

tests/python/relax/test_frontend_onnx.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4423,6 +4423,176 @@ def test_if_nested():
44234423
)
44244424

44254425

4426+
# Helper that builds the ONNX graph for MatMulInteger so the tests don't repeat
# boilerplate code every time.
def _make_matmulinteger_model(A_shape, B_shape, A_dtype, B_dtype, a_zp_array=None, b_zp_array=None):
    """Build a minimal single-node ONNX graph for MatMulInteger."""
    # NumPy scalar type -> ONNX tensor element type.
    onnx_dtype_of = {np.int8: TensorProto.INT8, np.uint8: TensorProto.UINT8}

    graph_inputs = [
        helper.make_tensor_value_info("A", onnx_dtype_of[A_dtype], A_shape),
        helper.make_tensor_value_info("B", onnx_dtype_of[B_dtype], B_shape),
    ]
    node_inputs = ["A", "B"]
    initializers = []

    def _append_zero_point(name, arr, dtype):
        # Zero points are supplied as graph initializers, not runtime inputs.
        tensor = helper.make_tensor(
            name, onnx_dtype_of[dtype], list(arr.shape), arr.flatten().tolist()
        )
        initializers.append(tensor)
        node_inputs.append(name)

    if a_zp_array is not None:
        _append_zero_point("a_zero_point", a_zp_array, A_dtype)
    elif b_zp_array is not None:
        # "" marks a skipped optional input; only needed when b_zp follows.
        node_inputs.append("")

    if b_zp_array is not None:
        _append_zero_point("b_zero_point", b_zp_array, B_dtype)

    node = helper.make_node("MatMulInteger", inputs=node_inputs, outputs=["output"])
    out_info = helper.make_tensor_value_info("output", TensorProto.INT32, None)
    graph = helper.make_graph(
        [node], "matmulinteger", graph_inputs, [out_info], initializer=initializers
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
    model.ir_version = 8
    return model
4462+
4463+
@pytest.mark.parametrize(
    "A_dtype,B_dtype,a_zp,b_zp",
    [
        (np.int8, np.int8, None, None),
        (np.uint8, np.uint8, None, None),
        (np.uint8, np.int8, None, None),
        pytest.param(
            np.int8,
            np.uint8,
            None,
            None,
            marks=pytest.mark.xfail(
                reason="Some older ORT versions doesn't support mixed int8/uint8 dtype combination for MatMulInteger",
                strict=False,  # not strict - may pass on newer ORT versions
            ),
        ),
        (np.uint8, np.uint8, np.uint8(128), np.uint8(128)),
        (np.int8, np.int8, np.int8(1), np.int8(2)),
    ],
)
def test_matmulinteger(A_dtype, B_dtype, a_zp, b_zp):
    """2-D matmul across all dtype combos and zero-point configurations."""
    np.random.seed(0)
    # Keep the A-then-B draw order so the seeded data matches across runs.
    A = np.random.randint(-5, 5, (4, 8)).astype(A_dtype)
    B = np.random.randint(-5, 5, (8, 6)).astype(B_dtype)

    # Materialize the optional scalar zero points as 0-D arrays when present.
    a_zp_arr = None if a_zp is None else np.array(a_zp, dtype=A_dtype)
    b_zp_arr = None if b_zp is None else np.array(b_zp, dtype=B_dtype)

    model = _make_matmulinteger_model(
        [4, 8], [8, 6], A_dtype, B_dtype, a_zp_array=a_zp_arr, b_zp_array=b_zp_arr
    )
    check_correctness(model, inputs={"A": A, "B": B}, opset=10)
4498+
4499+
@pytest.mark.parametrize(
    "A_shape,B_shape,a_zp,b_zp",
    [
        ((2, 4, 8), (2, 8, 6), np.int8(1), np.int8(2)),  # 3-D batched
        ((2, 3, 4, 8), (2, 3, 8, 6), np.int8(1), np.int8(2)),  # 4-D batched
    ],
)
def test_matmulinteger_batched(A_shape, B_shape, a_zp, b_zp):
    """Batched matmul — verifies the op generalizes beyond 2-D."""
    np.random.seed(1)
    # A is drawn before B so the seeded values match the original fixtures.
    A = np.random.randint(-5, 5, A_shape).astype(np.int8)
    B = np.random.randint(-5, 5, B_shape).astype(np.int8)

    model = _make_matmulinteger_model(
        list(A_shape),
        list(B_shape),
        np.int8,
        np.int8,
        a_zp_array=np.array(a_zp, dtype=np.int8),
        b_zp_array=np.array(b_zp, dtype=np.int8),
    )
    check_correctness(model, inputs={"A": A, "B": B}, opset=10)
4521+
4522+
def test_matmulinteger_per_channel_zp():
    """
    1-D zero points: per-row for A ([M]) and per-col for B ([N]).
    Exercises the expand_dims path in the converter.
    Note: ORT CPU does not support per-row a_zero_point despite the ONNX spec
    allowing it, so we verify TVM output against a NumPy reference instead.
    """
    np.random.seed(2)
    A = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
    B = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
    a_zp = np.arange(4, dtype=np.int8)  # shape [M=4], per-row
    b_zp = np.arange(6, dtype=np.int8)  # shape [N=6], per-col

    # NumPy reference: mirrors the converter's zero-point broadcasting.
    a_centered = A.astype(np.int32) - a_zp.astype(np.int32)[:, np.newaxis]
    b_centered = B.astype(np.int32) - b_zp.astype(np.int32)[np.newaxis, :]
    expected = np.matmul(a_centered, b_centered).astype(np.int32)

    model = _make_matmulinteger_model(
        [4, 8], [8, 6], np.int8, np.int8, a_zp_array=a_zp, b_zp_array=b_zp
    )

    # Run TVM only — ORT doesn't support per-row a_zero_point.
    tvm_model = from_onnx(model, opset=10, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())

    # Feed graph inputs in the order main() declares them, then any detached
    # parameters (the zero-point initializers).
    feed = {"A": A, "B": B}
    input_list = [
        feed[param.name_hint]
        for param in tvm_model["main"].params
        if param.name_hint in feed
    ]
    if params:
        input_list += params["main"]

    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main").numpy()

    tvm.testing.assert_allclose(tvm_output, expected)
4567+
4568+
@pytest.mark.xfail(
    reason=(
        "ORT doesn't support per-row a_zero_point of shape [M] "
        "despite the ONNX spec explicitly allowing it. "
        "See: matmul_integer.cc:63 IsScalarOr1ElementVector(a_zero_point)"
    ),
    strict=True,  # must fail, if ORT ever fixes this, the test will alert us
)
def test_matmulinteger_per_channel_zp_ort_limitation():
    """
    Documents that ORT CPU rejects per-row a_zero_point of shape [M].
    Marked xfail because this is a valid ONNX spec case that ORT simply
    hasn't implemented. If this test starts passing, ORT has fixed the
    limitation and test_matmulinteger_per_channel_zp can be simplified
    to use check_correctness instead of a manual TVM-only reference.
    """
    np.random.seed(2)
    # Same seeded fixtures as test_matmulinteger_per_channel_zp (A before B).
    A = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
    B = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
    a_zp = np.arange(4, dtype=np.int8)  # shape [M=4], per-row
    b_zp = np.arange(6, dtype=np.int8)  # shape [N=6], per-col

    model = _make_matmulinteger_model(
        [4, 8], [8, 6], np.int8, np.int8, a_zp_array=a_zp, b_zp_array=b_zp
    )
    check_correctness(model, inputs={"A": A, "B": B}, opset=10)
4595+
44264596
@pytest.mark.parametrize(
44274597
("pooled_shape", "rois"),
44284598
[

0 commit comments

Comments
 (0)