diff --git a/test/unit_test/utils/test_register_custom_op.py b/test/unit_test/utils/test_register_custom_op.py index 7a8bc318..4d23ba98 100644 --- a/test/unit_test/utils/test_register_custom_op.py +++ b/test/unit_test/utils/test_register_custom_op.py @@ -17,6 +17,7 @@ import tico.utils.register_custom_op as register_custom_op import torch +from torch._subclasses.fake_tensor import FakeTensorMode class TestRegisterCustomOp(unittest.TestCase): @@ -415,6 +416,74 @@ def test_circle_quantize_mx_mocked(self, mock_quantize_mx): round="nearest", ) + def test_circle_quantize_mx_decomposed_fake_tensor(self): + """Test CircleQuantizeMXDecomposed with FakeTensorMode""" + # These ops are designed to be fake ops only (assert False in impl) + # so test using FakeTensorMode + fake_mode = FakeTensorMode() + input_tensor = fake_mode.from_tensor(torch.randn(2, 32, 32, 3)) + elem_format = "int8" + axis = -1 + + with fake_mode: + result = torch.ops.circle_custom.quantize_mx_decomposed( + input_tensor, elem_format, axis + ) + + # Check output shape matches input + self.assertEqual(list(result.shape), list(input_tensor.shape)) + + def test_circle_dequantize_mx_decomposed_fake_tensor(self): + """Test CircleDeQuantizeMXDecomposed with FakeTensorMode""" + # These ops are designed to be fake ops only (assert False in impl) + # so test using FakeTensorMode + fake_mode = FakeTensorMode() + input_tensor = fake_mode.from_tensor(torch.randn(2, 32, 32, 3)) + elem_format = "int8" + axis = -1 + + with fake_mode: + result = torch.ops.circle_custom.dequantize_mx_decomposed( + input_tensor, elem_format, axis + ) + + # Check output shape matches input + self.assertEqual(list(result.shape), list(input_tensor.shape)) + + def test_circle_quantize_mx_decomposed_with_params_fake_tensor(self): + """Test CircleQuantizeMXDecomposed with all parameters using FakeTensorMode""" + fake_mode = FakeTensorMode() + input_tensor = fake_mode.from_tensor(torch.randn(2, 32, 32, 3)) + elem_format = "int8" + axis = -1 + shared_exp_method = "max" + round_method = "nearest" + + with fake_mode: + result = torch.ops.circle_custom.quantize_mx_decomposed( + input_tensor, elem_format, axis, shared_exp_method, round_method + ) + + # Check output shape matches input + self.assertEqual(list(result.shape), list(input_tensor.shape)) + + def test_circle_dequantize_mx_decomposed_with_params_fake_tensor(self): + """Test CircleDeQuantizeMXDecomposed with all parameters using FakeTensorMode""" + fake_mode = FakeTensorMode() + input_tensor = fake_mode.from_tensor(torch.randn(2, 32, 32, 3)) + elem_format = "int8" + axis = -1 + shared_exp_method = "max" + round_method = "nearest" + + with fake_mode: + result = torch.ops.circle_custom.dequantize_mx_decomposed( + input_tensor, elem_format, axis, shared_exp_method, round_method + ) + + # Check output shape matches input + self.assertEqual(list(result.shape), list(input_tensor.shape)) + if __name__ == "__main__": unittest.main() diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index 1b99de7c..8e021539 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -705,6 +705,58 @@ def _( return input_ +def CircleQuantizeMXDecomposed(): + @custom_op("circle_custom::quantize_mx_decomposed", mutates_args=()) + def quantize_mx( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", + round: str = "nearest", + ) -> torch.Tensor: + # this op should be fake one, so please consider different quantization scheme in case it failed here + raise RuntimeError( + "circle_custom::quantize_mx_decomposed is a fake-only op and must not be executed with real tensors" + ) + return input_.new_empty(input_.size()) + + @register_fake("circle_custom::quantize_mx_decomposed") + def _( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", # Fixed + round: str = "nearest", # Fixed + ) -> torch.Tensor: + return input_ + + +def CircleDeQuantizeMXDecomposed(): + @custom_op("circle_custom::dequantize_mx_decomposed", mutates_args=()) + def dequantize_mx( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", + round: str = "nearest", + ) -> torch.Tensor: + # this op should be fake one, so please consider different quantization scheme in case it failed here + raise RuntimeError( + "circle_custom::dequantize_mx_decomposed is a fake-only op and must not be executed with real tensors" + ) + return input_.clone() + + @register_fake("circle_custom::dequantize_mx_decomposed") + def _( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", # Fixed + round: str = "nearest", # Fixed; + ) -> torch.Tensor: + return input_ + + def CircleRMSNorm(): @custom_op("circle_custom::rms_norm", mutates_args=()) def rms_norm( @@ -800,6 +852,8 @@ def RegisterOps(): CircleAvgPool2D() CircleInstanceNorm() CircleQuantizeMX() + CircleQuantizeMXDecomposed() + CircleDeQuantizeMXDecomposed() CircleRMSNorm() CircleAttention() CircleShape()