From c7e10b5631029d44b64e018a680056b36cece990 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Thu, 4 Jun 2026 11:08:06 +0300 Subject: [PATCH 1/3] [quantization] Introduce Quantize/Dequantize for MX This PR adds definitions of Quantize/Dequantize stubs for the MX format and adds tests for them. TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../utils/test_register_custom_op.py | 69 +++++++++++++++++++ tico/utils/register_custom_op.py | 50 ++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/test/unit_test/utils/test_register_custom_op.py b/test/unit_test/utils/test_register_custom_op.py index 7a8bc318..4d23ba98 100644 --- a/test/unit_test/utils/test_register_custom_op.py +++ b/test/unit_test/utils/test_register_custom_op.py @@ -17,6 +17,7 @@ import tico.utils.register_custom_op as register_custom_op import torch +from torch._subclasses.fake_tensor import FakeTensorMode class TestRegisterCustomOp(unittest.TestCase): @@ -415,6 +416,74 @@ def test_circle_quantize_mx_mocked(self, mock_quantize_mx): round="nearest", ) + def test_circle_quantize_mx_decomposed_fake_tensor(self): + """Test CircleQuantizeMXDecomposed with FakeTensorMode""" + # These ops are designed to be fake ops only (assert False in impl) + # so test using FakeTensorMode + fake_mode = FakeTensorMode() + input_tensor = fake_mode.from_tensor(torch.randn(2, 32, 32, 3)) + elem_format = "int8" + axis = -1 + + with fake_mode: + result = torch.ops.circle_custom.quantize_mx_decomposed( + input_tensor, elem_format, axis + ) + + # Check output shape matches input + self.assertEqual(list(result.shape), list(input_tensor.shape)) + + def test_circle_dequantize_mx_decomposed_fake_tensor(self): + """Test CircleDeQuantizeMXDecomposed with FakeTensorMode""" + # These ops are designed to be fake ops only (assert False in impl) + # so test using FakeTensorMode + fake_mode = FakeTensorMode() + input_tensor = fake_mode.from_tensor(torch.randn(2, 32, 32, 3)) + elem_format = "int8" + axis = -1 + + with fake_mode: + result = 
torch.ops.circle_custom.dequantize_mx_decomposed( + input_tensor, elem_format, axis + ) + + # Check output shape matches input + self.assertEqual(list(result.shape), list(input_tensor.shape)) + + def test_circle_quantize_mx_decomposed_with_params_fake_tensor(self): + """Test CircleQuantizeMXDecomposed with all parameters using FakeTensorMode""" + fake_mode = FakeTensorMode() + input_tensor = fake_mode.from_tensor(torch.randn(2, 32, 32, 3)) + elem_format = "int8" + axis = -1 + shared_exp_method = "max" + round_method = "nearest" + + with fake_mode: + result = torch.ops.circle_custom.quantize_mx_decomposed( + input_tensor, elem_format, axis, shared_exp_method, round_method + ) + + # Check output shape matches input + self.assertEqual(list(result.shape), list(input_tensor.shape)) + + def test_circle_dequantize_mx_decomposed_with_params_fake_tensor(self): + """Test CircleDeQuantizeMXDecomposed with all parameters using FakeTensorMode""" + fake_mode = FakeTensorMode() + input_tensor = fake_mode.from_tensor(torch.randn(2, 32, 32, 3)) + elem_format = "int8" + axis = -1 + shared_exp_method = "max" + round_method = "nearest" + + with fake_mode: + result = torch.ops.circle_custom.dequantize_mx_decomposed( + input_tensor, elem_format, axis, shared_exp_method, round_method + ) + + # Check output shape matches input + self.assertEqual(list(result.shape), list(input_tensor.shape)) + if __name__ == "__main__": unittest.main() diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index 1b99de7c..a89f1bae 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -705,6 +705,54 @@ def _( return input_ +def CircleQuantizeMXDecomposed(): + @custom_op("circle_custom::quantize_mx_decomposed", mutates_args=()) + def quantize_mx( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", + round: str = "nearest", + ) -> torch.Tensor: + # this op should be fake one, so please consider different 
quantization scheme in case it failed here + assert False + return input_.clone() + + @register_fake("circle_custom::quantize_mx_decomposed") + def _( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", # Fixed + round: str = "nearest", # Fixed + ) -> torch.Tensor: + return input_ + + +def CircleDeQuantizeMXDecomposed(): + @custom_op("circle_custom::dequantize_mx_decomposed", mutates_args=()) + def quantize_mx( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", + round: str = "nearest", + ) -> torch.Tensor: + # this op should be fake one, so please consider different quantization scheme in case it failed here + assert False + return input_.clone() + + @register_fake("circle_custom::dequantize_mx_decomposed") + def _( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", # Fixed + round: str = "nearest", # Fixed; + ) -> torch.Tensor: + return input_ + + def CircleRMSNorm(): @custom_op("circle_custom::rms_norm", mutates_args=()) def rms_norm( @@ -800,6 +848,8 @@ def RegisterOps(): CircleAvgPool2D() CircleInstanceNorm() CircleQuantizeMX() + CircleQuantizeMXDecomposed() + CircleDeQuantizeMXDecomposed() CircleRMSNorm() CircleAttention() CircleShape() From 120c63738286ef2dc981442c93595d8ecd291401 Mon Sep 17 00:00:00 2001 From: Stanislav Malakhov <112689352+stamalakhov@users.noreply.github.com> Date: Fri, 5 Jun 2026 08:30:39 +0300 Subject: [PATCH 2/3] Apply suggestions from code review Apply suggestions from code review Co-authored-by: seongwoo chae --- tico/utils/register_custom_op.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index a89f1bae..d69899af 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -715,8 +715,10 @@ def quantize_mx( round: str = "nearest", ) -> torch.Tensor: # this op should be fake one, so please 
consider different quantization scheme in case it failed here - assert False - return input_.clone() + raise RuntimeError( + "circle_custom::quantize_mx_decomposed is a fake-only op and must not be executed with real tensors" + ) + return input_.new_empty(input_.size()) @register_fake("circle_custom::quantize_mx_decomposed") def _( @@ -731,7 +733,7 @@ def _( def CircleDeQuantizeMXDecomposed(): @custom_op("circle_custom::dequantize_mx_decomposed", mutates_args=()) - def quantize_mx( + def dequantize_mx( input_: torch.Tensor, elem_format: str, axis: int, From eb53f2dce8bbc988e23e86bf8d466bf9e20a7842 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Fri, 5 Jun 2026 08:41:16 +0300 Subject: [PATCH 3/3] Apply suggestions from code review TICO-DCO-1.0-Signed-off-by: s.malakhov --- tico/utils/register_custom_op.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index d69899af..8e021539 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -716,9 +716,9 @@ def quantize_mx( ) -> torch.Tensor: # this op should be fake one, so please consider different quantization scheme in case it failed here raise RuntimeError( - "circle_custom::quantize_mx_decomposed is a fake-only op and must not be executed with real tensors" - ) - return input_.new_empty(input_.size()) + "circle_custom::quantize_mx_decomposed is a fake-only op and must not be executed with real tensors" + ) + return input_.new_empty(input_.size()) @register_fake("circle_custom::quantize_mx_decomposed") def _( @@ -741,7 +741,9 @@ def dequantize_mx( round: str = "nearest", ) -> torch.Tensor: # this op should be fake one, so please consider different quantization scheme in case it failed here - assert False + raise RuntimeError( + "circle_custom::dequantize_mx_decomposed is a fake-only op and must not be executed with real tensors" + ) return input_.clone() 
@register_fake("circle_custom::dequantize_mx_decomposed")