@@ -4423,6 +4423,176 @@ def test_if_nested():
44234423 )
44244424
44254425
# Helper that builds the ONNX graph for MatMulInteger so the tests don't
# repeat boilerplate code every time.
def _make_matmulinteger_model(A_shape, B_shape, A_dtype, B_dtype, a_zp_array=None, b_zp_array=None):
    """Build a minimal single-node ONNX graph for MatMulInteger.

    Parameters
    ----------
    A_shape, B_shape : list of int
        Static shapes for the graph inputs "A" and "B".
    A_dtype, B_dtype : numpy type or dtype
        Element type of A and B; must be int8 or uint8 (the only types
        MatMulInteger accepts).
    a_zp_array, b_zp_array : np.ndarray, optional
        Zero-point values baked into the graph as initializers. When only
        b_zp_array is given, an empty-string input name is inserted so that
        b_zero_point still lands in the node's 4th input slot, as ONNX
        requires for skipped optional inputs.

    Returns
    -------
    onnx.ModelProto
        An opset-10 model with a single MatMulInteger node whose int32
        result is named "output".
    """

    def np_dtype_to_onnx(dt):
        # Normalize through np.dtype so both the scalar classes (np.int8)
        # and dtype instances (np.dtype("int8")) are accepted.
        return {
            np.dtype(np.int8): TensorProto.INT8,
            np.dtype(np.uint8): TensorProto.UINT8,
        }[np.dtype(dt)]

    A_info = helper.make_tensor_value_info("A", np_dtype_to_onnx(A_dtype), A_shape)
    B_info = helper.make_tensor_value_info("B", np_dtype_to_onnx(B_dtype), B_shape)
    graph_inputs = [A_info, B_info]
    node_inputs = ["A", "B"]
    initializers = []

    def _add_zp(name, arr, dtype):
        # Zero points are stored as initializers rather than runtime inputs.
        onnx_dtype = np_dtype_to_onnx(dtype)
        shape = list(arr.shape)
        initializers.append(helper.make_tensor(name, onnx_dtype, shape, arr.flatten().tolist()))
        node_inputs.append(name)

    if a_zp_array is not None:
        _add_zp("a_zero_point", a_zp_array, A_dtype)
    elif b_zp_array is not None:
        node_inputs.append("")  # placeholder only needed if b_zp is present

    if b_zp_array is not None:
        _add_zp("b_zero_point", b_zp_array, B_dtype)

    # MatMulInteger always accumulates into int32; leave the output shape
    # unspecified so shape inference fills it in.
    out_info = helper.make_tensor_value_info("output", TensorProto.INT32, None)
    node = helper.make_node("MatMulInteger", inputs=node_inputs, outputs=["output"])
    graph = helper.make_graph(
        [node], "matmulinteger", graph_inputs, [out_info], initializer=initializers
    )
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
    model.ir_version = 8
    return model
4461+
4462+
@pytest.mark.parametrize(
    "A_dtype,B_dtype,a_zp,b_zp",
    [
        (np.int8, np.int8, None, None),
        (np.uint8, np.uint8, None, None),
        (np.uint8, np.int8, None, None),
        pytest.param(
            np.int8,
            np.uint8,
            None,
            None,
            marks=pytest.mark.xfail(
                reason="Some older ORT versions doesn't support mixed int8/uint8 dtype combination for MatMulInteger",
                strict=False,  # not strict - may pass on newer ORT versions
            ),
        ),
        (np.uint8, np.uint8, np.uint8(128), np.uint8(128)),
        (np.int8, np.int8, np.int8(1), np.int8(2)),
    ],
)
def test_matmulinteger(A_dtype, B_dtype, a_zp, b_zp):
    """2-D matmul across all dtype combos and zero-point configurations."""

    def _maybe_array(value, dtype):
        # Zero points are optional: keep None as-is, wrap scalars as ndarray.
        if value is None:
            return None
        return np.array(value, dtype=dtype)

    np.random.seed(0)
    lhs = np.random.randint(-5, 5, (4, 8)).astype(A_dtype)
    rhs = np.random.randint(-5, 5, (8, 6)).astype(B_dtype)
    model = _make_matmulinteger_model(
        [4, 8],
        [8, 6],
        A_dtype,
        B_dtype,
        a_zp_array=_maybe_array(a_zp, A_dtype),
        b_zp_array=_maybe_array(b_zp, B_dtype),
    )
    check_correctness(model, inputs={"A": lhs, "B": rhs}, opset=10)
4497+
4498+
@pytest.mark.parametrize(
    "A_shape,B_shape,a_zp,b_zp",
    [
        ((2, 4, 8), (2, 8, 6), np.int8(1), np.int8(2)),  # 3-D batched
        ((2, 3, 4, 8), (2, 3, 8, 6), np.int8(1), np.int8(2)),  # 4-D batched
    ],
)
def test_matmulinteger_batched(A_shape, B_shape, a_zp, b_zp):
    """Batched matmul — verifies the op generalizes beyond 2-D."""
    np.random.seed(1)
    dtype = np.int8
    lhs = np.random.randint(-5, 5, A_shape).astype(dtype)
    rhs = np.random.randint(-5, 5, B_shape).astype(dtype)
    model = _make_matmulinteger_model(
        list(A_shape),
        list(B_shape),
        dtype,
        dtype,
        a_zp_array=np.array(a_zp, dtype=dtype),
        b_zp_array=np.array(b_zp, dtype=dtype),
    )
    check_correctness(model, inputs={"A": lhs, "B": rhs}, opset=10)
4520+
4521+
def test_matmulinteger_per_channel_zp():
    """
    1-D zero points: per-row for A ([M]) and per-col for B ([N]).
    Exercises the expand_dims path in the converter.
    Note: ORT CPU does not support per-row a_zero_point despite the ONNX spec
    allowing it, so we verify TVM output against a NumPy reference instead.
    """
    np.random.seed(2)
    A = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
    B = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
    a_zp = np.arange(4, dtype=np.int8)  # shape [M=4], per-row
    b_zp = np.arange(6, dtype=np.int8)  # shape [N=6], per-col

    # NumPy reference: mirrors the converter's expand_dims logic
    lhs_shifted = A.astype(np.int32) - a_zp.astype(np.int32)[:, np.newaxis]
    rhs_shifted = B.astype(np.int32) - b_zp.astype(np.int32)[np.newaxis, :]
    expected = np.matmul(lhs_shifted, rhs_shifted).astype(np.int32)

    model = _make_matmulinteger_model(
        [4, 8], [8, 6], np.int8, np.int8, a_zp_array=a_zp, b_zp_array=b_zp
    )

    # Run TVM only — ORT doesn't support per-row a_zero_point
    tvm_model = from_onnx(model, opset=10, keep_params_in_input=True)
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    tvm_model = relax.transform.LegalizeOps()(tvm_model)
    tvm_model, params = relax.frontend.detach_params(tvm_model)

    with tvm.transform.PassContext(opt_level=3):
        ex = tvm.compile(tvm_model, target="llvm")
        vm = relax.VirtualMachine(ex, tvm.cpu())

    # Feed graph inputs in the order the compiled main() declares them,
    # then append any detached parameters (the zero-point initializers).
    feed = {"A": A, "B": B}
    input_list = []
    for param in tvm_model["main"].params:
        if param.name_hint in feed:
            input_list.append(feed[param.name_hint])
    if params:
        input_list += params["main"]

    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main").numpy()

    tvm.testing.assert_allclose(tvm_output, expected)
4566+
4567+
@pytest.mark.xfail(
    reason=(
        "ORT doesn't support per-row a_zero_point of shape [M] "
        "despite the ONNX spec explicitly allowing it. "
        "See: matmul_integer.cc:63 IsScalarOr1ElementVector(a_zero_point)"
    ),
    strict=True,  # must fail, if ORT ever fixes this, the test will alert us
)
def test_matmulinteger_per_channel_zp_ort_limitation():
    """
    Documents that ORT CPU rejects per-row a_zero_point of shape [M].
    Marked xfail because this is a valid ONNX spec case that ORT simply
    hasn't implemented. If this test starts passing, ORT has fixed the
    limitation and test_matmulinteger_per_channel_zp can be simplified
    to use check_correctness instead of a manual TVM-only reference.
    """
    np.random.seed(2)
    lhs = np.random.randint(-5, 5, (4, 8)).astype(np.int8)
    rhs = np.random.randint(-5, 5, (8, 6)).astype(np.int8)
    row_zp = np.arange(4, dtype=np.int8)  # shape [M=4], per-row
    col_zp = np.arange(6, dtype=np.int8)  # shape [N=6], per-col

    model = _make_matmulinteger_model(
        [4, 8], [8, 6], np.int8, np.int8, a_zp_array=row_zp, b_zp_array=col_zp
    )
    check_correctness(model, inputs={"A": lhs, "B": rhs}, opset=10)
4594+
4595+
44264596@pytest .mark .parametrize (
44274597 ("pooled_shape" , "rois" ),
44284598 [
0 commit comments