From 7dc00e575c8117e820dc678443b2f8e7cfa070bc Mon Sep 17 00:00:00 2001 From: seongwoo Date: Sat, 6 Jun 2026 18:15:14 +0900 Subject: [PATCH 1/3] Canonicalize MX fake quant export through Q-DQ Introduce a separate MX fake-quant frontend op and lower it to a logical quantize_mx/dequantize_mx pair before Circle export. TICO-DCO-1.0-Signed-off-by: seongwoo --- .../passes/test_fold_quant_ops.py | 44 ++++++++- .../passes/test_decompose_fake_quantize.py | 48 ++++++++++ test/unit_test/utils/test_mx.py | 3 +- .../utils/test_register_custom_op.py | 49 +++++----- tico/passes/decompose_fake_quantize.py | 19 ++++ tico/quantization/passes/fold_quant_ops.py | 83 +++++++++++++++++ tico/serialize/circle_mapping.py | 35 +++---- .../operators/op_quantize_per_tensor.py | 37 ++++++++ tico/utils/mx/dtypes.py | 93 +++++++++++++++++++ tico/utils/mx/mx_ops.py | 27 ++---- tico/utils/register_custom_op.py | 93 +++++++++++++++---- tico/utils/utils.py | 35 ++++--- 12 files changed, 460 insertions(+), 106 deletions(-) create mode 100644 tico/utils/mx/dtypes.py diff --git a/test/quantization/passes/test_fold_quant_ops.py b/test/quantization/passes/test_fold_quant_ops.py index 7b88cfc3..c4ca43b2 100644 --- a/test/quantization/passes/test_fold_quant_ops.py +++ b/test/quantization/passes/test_fold_quant_ops.py @@ -24,7 +24,6 @@ from tico.quantization.passes.fold_quant_ops import FoldQuantOps from tico.serialize.quant_param import QPARAM_KEY -from test.modules.op.sub import SimpleSub from test.support.helper import num_of_ops @@ -52,6 +51,19 @@ def get_example_inputs(self): return (torch.randn(3, 3),), {} +class MXFakeQuantize(torch.nn.Module): + """Module that produces an MX Q-DQ pattern after decomposition.""" + + def __init__(self): + super().__init__() + + def forward(self, inp): + return torch.ops.circle_custom.mx_fake_quantize(inp, "int8", -1) + + def get_example_inputs(self): + return (torch.randn(2, 32),), {} + + class FoldQuantOpsTest(unittest.TestCase): def test_pass(self): m = torch.nn.SiLU().eval() @@ -153,3 +165,33 @@ def test_requantize(self): == torch.ops.quantized_decomposed.quantize_per_tensor.default ): self.assertTrue(QPARAM_KEY in node.meta) + + def test_fold_mx_qdq_to_producer_qparam(self): + m = MXFakeQuantize().eval() + args, kwargs = m.get_example_inputs() + ep = torch.export.export(m, args, kwargs) + + DecomposeFakeQuantize().call(ep) + self.assertEqual( + num_of_ops(ep, [torch.ops.circle_custom.quantize_mx.default]), 1 + ) + self.assertEqual( + num_of_ops(ep, [torch.ops.circle_custom.dequantize_mx.default]), 1 + ) + + FoldQuantOps().call(ep) + + self.assertEqual( + num_of_ops(ep, [torch.ops.circle_custom.quantize_mx.default]), 0 + ) + self.assertEqual( + num_of_ops(ep, [torch.ops.circle_custom.dequantize_mx.default]), 0 + ) + + found_input_qparam = False + for node in ep.graph.nodes: + if node.op == "placeholder" and QPARAM_KEY in node.meta: + self.assertEqual(node.meta[QPARAM_KEY].dtype, "mxint8") + self.assertEqual(node.meta[QPARAM_KEY].quantized_dimension, -1) + found_input_qparam = True + self.assertTrue(found_input_qparam) diff --git a/test/unit_test/passes/test_decompose_fake_quantize.py b/test/unit_test/passes/test_decompose_fake_quantize.py index 85353097..57c56fe2 100644 --- a/test/unit_test/passes/test_decompose_fake_quantize.py +++ b/test/unit_test/passes/test_decompose_fake_quantize.py @@ -74,3 +74,51 @@ def test_pass(self): ), 1, ) + + +class MXFakeQuantize(torch.nn.Module): + """Simple module with MX fake quantization.""" + + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.ops.circle_custom.mx_fake_quantize(input, "int8", -1) + + def get_example_inputs(self): + return (torch.randn(2, 32),), {} + + +class DecomposeMXFakeQuantizeTest(SinglePassValueTest): + """Tests MX fake quant decomposition into logical Q-DQ.""" + + def test_pass(self): + self.setup(MXFakeQuantize()) + self.assertEqual( + num_of_ops( + self.exported_program(), + [torch.ops.circle_custom.mx_fake_quantize.default], + ), + 1, + ) + + DecomposeFakeQuantize().call(self.exported_program()) + self.assertEqual( + num_of_ops( + self.exported_program(), + [torch.ops.circle_custom.mx_fake_quantize.default], + ), + 0, + ) + self.assertEqual( + num_of_ops( + self.exported_program(), [torch.ops.circle_custom.quantize_mx.default] + ), + 1, + ) + self.assertEqual( + num_of_ops( + self.exported_program(), [torch.ops.circle_custom.dequantize_mx.default] + ), + 1, + ) diff --git a/test/unit_test/utils/test_mx.py b/test/unit_test/utils/test_mx.py index edc7d9e6..84eed7e1 100644 --- a/test/unit_test/utils/test_mx.py +++ b/test/unit_test/utils/test_mx.py @@ -71,7 +71,6 @@ def test_random_values(self): self.assertFalse(torch.allclose(input_, output)) self.assertLess(compute_peir(input_, output), 0.01) - # Check if exported program includes circle_custom::quantize_mx Op def test_export(self): m = SimpleMXINT8(axis=2) args, kwargs = m.get_example_inputs() @@ -80,5 +79,5 @@ def test_export(self): ep = export(m.eval(), args, kwargs) self.assertEqual( - 1, num_of_ops(ep, [torch.ops.circle_custom.quantize_mx.default]) + 1, num_of_ops(ep, [torch.ops.circle_custom.mx_fake_quantize.default]) ) diff --git a/test/unit_test/utils/test_register_custom_op.py b/test/unit_test/utils/test_register_custom_op.py index 7a8bc318..5a2460a9 100644 --- a/test/unit_test/utils/test_register_custom_op.py +++ b/test/unit_test/utils/test_register_custom_op.py @@ -314,43 +314,38 @@ def test_circle_instance_norm_with_custom_params(self): # Check output shape self.assertEqual(list(result.shape), list(input_tensor.shape)) - def test_circle_quantize_mx_int8(self): - """Test CircleQuantizeMX with int8 format""" + def test_circle_mx_fake_quantize_basic(self): + """Test CircleMXFakeQuantize basic functionality.""" input_tensor = torch.randn(2, 32, 32, 3) - elem_format = "int8" - axis = -1 - - result = torch.ops.circle_custom.quantize_mx(input_tensor, elem_format, axis) - - # Check output shape + result = torch.ops.circle_custom.mx_fake_quantize(input_tensor, "int8", -1) self.assertEqual(list(result.shape), list(input_tensor.shape)) - def test_circle_quantize_mx_unsupported_format(self): - """Test CircleQuantizeMX with unsupported format""" + def test_circle_mx_fake_quantize_unsupported_format(self): + """Test CircleMXFakeQuantize with unsupported format.""" input_tensor = torch.randn(2, 32, 32, 3) - elem_format = "unsupported_format" - axis = -1 - with self.assertRaises(RuntimeError) as context: - torch.ops.circle_custom.quantize_mx(input_tensor, elem_format, axis) - - self.assertIn("Unsupported elem_format in quantize_mx", str(context.exception)) + torch.ops.circle_custom.mx_fake_quantize( + input_tensor, "unsupported_format", -1 + ) + self.assertIn( + "Unsupported elem_format in mx_fake_quantize", str(context.exception) + ) - def test_circle_quantize_mx_with_custom_params(self): - """Test CircleQuantizeMX with custom parameters""" + def test_circle_mx_fake_quantize_with_custom_params(self): + """Test CircleMXFakeQuantize with custom parameters.""" input_tensor = torch.randn(2, 32, 32, 3) - elem_format = "int8" - axis = -1 - shared_exp_method = "max" - round_method = "nearest" - - result = torch.ops.circle_custom.quantize_mx( - input_tensor, elem_format, axis, shared_exp_method, round_method + result = torch.ops.circle_custom.mx_fake_quantize( + input_tensor, "int8", -1, "max", "nearest" ) - - # Check output shape self.assertEqual(list(result.shape), list(input_tensor.shape)) + def test_circle_quantize_mx_is_logical_only(self): + """Test that logical MX quantize is not an eager fake-quant API.""" + input_tensor = torch.randn(2, 32, 32, 3) + with self.assertRaises(RuntimeError) as context: + torch.ops.circle_custom.quantize_mx(input_tensor, "int8", -1) + self.assertIn("internal logical quantize op", str(context.exception)) + def test_circle_rms_norm_basic(self): """Test CircleRMSNorm basic functionality""" hidden_states = torch.randn(2, 32, 3) diff --git a/tico/passes/decompose_fake_quantize.py b/tico/passes/decompose_fake_quantize.py index e26dda3d..f2c017a5 100644 --- a/tico/passes/decompose_fake_quantize.py +++ b/tico/passes/decompose_fake_quantize.py @@ -124,6 +124,25 @@ def call(self, exported_program: ExportedProgram) -> PassResult: node.replace_all_uses_with(dequnt, propagate_meta=True) modified = True + if node.target in [torch.ops.circle_custom.mx_fake_quantize.default]: + assert len(node.args) >= 3 + with gm.graph.inserting_before(node): + quant = create_node( + g, + torch.ops.circle_custom.quantize_mx.default, + args=node.args, + kwargs=node.kwargs, + origin=node, + ) + dequnt = create_node( + g, + torch.ops.circle_custom.dequantize_mx.default, + args=(quant, *quant.args[1:]), + kwargs=quant.kwargs, + ) + node.replace_all_uses_with(dequnt, propagate_meta=True) + modified = True + gm.graph.eliminate_dead_code() gm.graph.lint() gm.recompile() diff --git a/tico/quantization/passes/fold_quant_ops.py b/tico/quantization/passes/fold_quant_ops.py index 48afa7d0..9c237bb7 100644 --- a/tico/quantization/passes/fold_quant_ops.py +++ b/tico/quantization/passes/fold_quant_ops.py @@ -22,6 +22,12 @@ from tico.serialize.quant_param import QPARAM_KEY, QuantParam from tico.utils import logging +from tico.utils.mx.dtypes import ( + assert_supported_mx_export_options, + is_mx_dtype, + mx_dtype_from_elem_format, + normalize_mx_elem_format, +) from tico.utils.passes import PassBase, PassResult from tico.utils.trace_decorators import trace_graph_diff_on_pass from tico.utils.utils import get_quant_dtype @@ -31,6 +37,46 @@ ) +def _mx_op_params(node) -> tuple[str, int, str, str]: + """Return normalized MX quantization parameters from an FX custom op node.""" + assert len(node.args) >= 3 + elem_format = normalize_mx_elem_format(node.args[1]) + axis = node.args[2] + assert isinstance(axis, int) + shared_exp_method = node.kwargs.get( + "shared_exp_method", node.args[3] if len(node.args) > 3 else "max" + ) + round_mode = node.kwargs.get( + "round", node.args[4] if len(node.args) > 4 else "nearest" + ) + assert isinstance(shared_exp_method, str) + assert isinstance(round_mode, str) + return elem_format, axis, shared_exp_method, round_mode + + +def _mx_qparam_from_quant_node(q) -> QuantParam: + """Build a QuantParam from a logical MX quantize node.""" + elem_format, axis, shared_exp_method, round_mode = _mx_op_params(q) + assert_supported_mx_export_options( + elem_format=elem_format, + shared_exp_method=shared_exp_method, + round=round_mode, + ) + qparam = QuantParam() + qparam.dtype = mx_dtype_from_elem_format(elem_format) + qparam.quantized_dimension = axis + return qparam + + +def _same_mx_qparam(lhs: QuantParam, rhs: QuantParam) -> bool: + """Return True when two QuantParams describe the same MX quantization.""" + return ( + lhs.dtype == rhs.dtype + and is_mx_dtype(lhs.dtype) + and lhs.quantized_dimension == rhs.quantized_dimension + ) + + @trace_graph_diff_on_pass class FoldQuantOps(PassBase): """ @@ -145,6 +191,43 @@ def call(self, exported_program: ExportedProgram) -> PassResult: dq.replace_all_uses_with(op, propagate_meta=False) logger.debug(f"Removed redundant {dq.name}") + for dq in graph.nodes: + if dq.op != "call_function": + continue + if dq.target != torch.ops.circle_custom.dequantize_mx.default: + continue + + q = dq.args[0] + if not isinstance(q, torch.fx.Node): + continue + if q.target != torch.ops.circle_custom.quantize_mx.default: + continue + + op = q.args[0] + if not isinstance(op, torch.fx.Node): + continue + + if _mx_op_params(q) != _mx_op_params(dq): + continue + + qparam = _mx_qparam_from_quant_node(q) + + if QPARAM_KEY not in op.meta: + op.meta[QPARAM_KEY] = qparam + dq.replace_all_uses_with(op, propagate_meta=False) + logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.") + else: + op_qparam = op.meta[QPARAM_KEY] + if not _same_mx_qparam(op_qparam, qparam): + if QPARAM_KEY not in q.meta: + q.meta[QPARAM_KEY] = qparam + assert len(q.users) == 1, "Fix me unless" + dq.replace_all_uses_with(q, propagate_meta=False) + logger.debug(f"{dq.name} is folded ({q.name} is left).") + else: + dq.replace_all_uses_with(op, propagate_meta=False) + logger.debug(f"Removed redundant {dq.name}") + graph.eliminate_dead_code() graph.lint() graph_module.recompile() diff --git a/tico/serialize/circle_mapping.py b/tico/serialize/circle_mapping.py index 20336778..f8532f7e 100644 --- a/tico/serialize/circle_mapping.py +++ b/tico/serialize/circle_mapping.py @@ -47,10 +47,10 @@ def to_circle_dtype( return circle_type -# Convert str dtype used in QuantParam to circle dtype def str_to_circle_dtype( str_dtype: str, ) -> int: + """Convert a QuantParam dtype string to a Circle tensor dtype.""" dmap = { "float32": circle.TensorType.TensorType.FLOAT32, "float": circle.TensorType.TensorType.FLOAT32, @@ -65,10 +65,20 @@ def str_to_circle_dtype( "uint4": circle.TensorType.TensorType.UINT4, # TODO Add more dtypes } - + optional_dtypes = { + "mxint8": "MXINT8", + "mxfp4": "MXFP4", + "mxfp6_e3m2": "MXFP6_E3M2", + "mxfp6_e2m3": "MXFP6_E2M3", + "mxfp8_e4m3": "MXFP8_E4M3", + "mxfp8_e5m2": "MXFP8_E5M2", + } + for dtype, circle_name in optional_dtypes.items(): + circle_dtype = getattr(circle.TensorType.TensorType, circle_name, None) + if circle_dtype is not None: + dmap[dtype] = circle_dtype if str_dtype not in dmap: raise RuntimeError(f"Unsupported dtype {str_dtype}") - circle_type = dmap[str_dtype] assert circle_type is not None return circle_type @@ -94,38 +104,31 @@ def np_dtype_from_circle_dtype(circle_dtype: int): return np_dtype -# Return dtype of node def extract_torch_dtype(node: torch.fx.Node) -> torch.dtype: + """Return the torch dtype encoded in a node's meta value.""" assert node.meta is not None assert node.meta.get("val") is not None val = node.meta.get("val") - val_dtype = None if isinstance(val, torch.Tensor): assert isinstance(val.dtype, torch.dtype) - val_dtype = val.dtype - else: - val_dtype = torch.tensor(val).dtype - return val_dtype + return val.dtype + return torch.tensor(val).dtype def extract_circle_dtype(node: torch.fx.Node) -> int: return to_circle_dtype(extract_torch_dtype(node)) -# Return shape of node def extract_shape(node: torch.fx.Node) -> torch.Size: + """Return the shape encoded in a node's meta value.""" assert node.meta is not None assert node.meta.get("val") is not None val = node.meta.get("val") - val_shape = None if isinstance(val, torch.Tensor): - val_shape = val.size() - else: - val_shape = torch.tensor(val).shape - - return val_shape + return val.size() + return torch.tensor(val).shape def extract_circle_shape(node: torch.fx.Node) -> Tuple[List[int], Optional[List[int]]]: diff --git a/tico/serialize/operators/op_quantize_per_tensor.py b/tico/serialize/operators/op_quantize_per_tensor.py index 84665516..5e7c1a3b 100644 --- a/tico/serialize/operators/op_quantize_per_tensor.py +++ b/tico/serialize/operators/op_quantize_per_tensor.py @@ -25,6 +25,8 @@ from tico.serialize.operators.hashable_opcode import OpCode from tico.serialize.operators.node_visitor import NodeVisitor, register_node_visitor from tico.serialize.operators.utils import create_builtin_operator, get_op_index +from tico.serialize.quant_param import QPARAM_KEY +from tico.utils.mx.dtypes import is_mx_dtype, mx_dtype_from_elem_format from tico.utils.validate_args_kwargs import QuantizePerTensorArgs @@ -78,3 +80,38 @@ def define_node( operator.builtinOptions = option return operator + + +@register_node_visitor +class QuantizeMXDefaultVisitor(NodeVisitor): + """Serialize a logical MX quantize node as a Circle QUANTIZE operator.""" + + target: List[torch._ops.OpOverload] = [ + torch.ops.circle_custom.quantize_mx.default, + ] + + def __init__(self, op_codes: Dict[OpCode, int], graph: CircleSubgraph): + super().__init__(op_codes, graph) + + def define_node(self, node: torch.fx.Node) -> circle.Operator.OperatorT: + tensor = node.args[0] + elem_format = node.args[1] + axis = node.args[2] + expected_dtype = mx_dtype_from_elem_format(elem_format) + + output_tensor: circle.Tensor.TensorT = self.graph.get_tensor(node) + assert output_tensor.quantization is not None + assert output_tensor.quantization.quantizedDimension == axis + assert QPARAM_KEY in node.meta + assert is_mx_dtype(node.meta[QPARAM_KEY].dtype) + assert node.meta[QPARAM_KEY].dtype == expected_dtype + + op_index = get_op_index( + circle.BuiltinOperator.BuiltinOperator.QUANTIZE, self._op_codes + ) + operator = create_builtin_operator(self.graph, op_index, [tensor], [node]) + operator.builtinOptionsType = ( + circle.BuiltinOptions.BuiltinOptions.QuantizeOptions + ) + operator.builtinOptions = circle.QuantizeOptions.QuantizeOptionsT() + return operator diff --git a/tico/utils/mx/dtypes.py b/tico/utils/mx/dtypes.py new file mode 100644 index 00000000..854d0284 --- /dev/null +++ b/tico/utils/mx/dtypes.py @@ -0,0 +1,93 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Final + + +MX_ELEM_FORMAT_TO_DTYPE: Final[dict[str, str]] = { + "int8": "mxint8", + "fp4": "mxfp4", + "fp4_e2m1": "mxfp4", + "fp6_e3m2": "mxfp6_e3m2", + "fp6_e2m3": "mxfp6_e2m3", + "fp8_e4m3": "mxfp8_e4m3", + "fp8_e5m2": "mxfp8_e5m2", +} + +MX_DTYPE_TO_ELEM_FORMAT: Final[dict[str, str]] = { + dtype: elem_format + for elem_format, dtype in MX_ELEM_FORMAT_TO_DTYPE.items() + if elem_format != "fp4_e2m1" +} + +SUPPORTED_MX_ELEM_FORMATS: Final[frozenset[str]] = frozenset( + MX_ELEM_FORMAT_TO_DTYPE.keys() +) +SUPPORTED_MX_DTYPES: Final[frozenset[str]] = frozenset(MX_DTYPE_TO_ELEM_FORMAT.keys()) + + +def mx_dtype_from_elem_format(elem_format: str) -> str: + """Return the Circle quantized dtype string for an MX element format.""" + try: + return MX_ELEM_FORMAT_TO_DTYPE[elem_format] + except KeyError as exc: + raise ValueError(f"Unsupported MX element format: {elem_format!r}") from exc + + +def elem_format_from_mx_dtype(dtype: str) -> str: + """Return the MX element format encoded by a Circle quantized dtype string.""" + try: + return MX_DTYPE_TO_ELEM_FORMAT[dtype] + except KeyError as exc: + raise ValueError(f"Unsupported MX dtype: {dtype!r}") from exc + + +def is_mx_dtype(dtype: str) -> bool: + """Return True when the dtype string denotes a supported MX dtype.""" + return dtype in SUPPORTED_MX_DTYPES + + +def normalize_mx_elem_format(elem_format: str) -> str: + """Return the canonical MX element format spelling used by Circle metadata.""" + if elem_format == "fp4_e2m1": + return "fp4" + if elem_format in SUPPORTED_MX_ELEM_FORMATS: + return elem_format + raise ValueError(f"Unsupported MX element format: {elem_format!r}") + + +def assert_supported_mx_export_options( + *, + elem_format: str, + shared_exp_method: str, + round: str, +) -> None: + """Validate MX fake-quant options that can be represented in Circle qparams. + + Circle tensor quantization metadata currently carries the MX dtype and the + quantized dimension. It does not have fields for the shared-exponent method + or the rounding mode, so this helper rejects non-default options before the + Q-DQ pair is folded into tensor metadata. + """ + normalize_mx_elem_format(elem_format) + if shared_exp_method != "max": + raise RuntimeError( + "Circle MX export currently supports only shared_exp_method='max'. " + f"Got {shared_exp_method!r}." + ) + if round != "nearest": + raise RuntimeError( + "Circle MX export currently supports only round='nearest'. " + f"Got {round!r}." + ) diff --git a/tico/utils/mx/mx_ops.py b/tico/utils/mx/mx_ops.py index d00161ec..427481ec 100644 --- a/tico/utils/mx/mx_ops.py +++ b/tico/utils/mx/mx_ops.py @@ -18,26 +18,10 @@ Copyright (c) Microsoft Corporation. Licensed under the MIT License. -Name: mx_ops.py - -Pytorch methods for MX quantization. - -Usage Notes: - - Use the "Exposed Methods" below to implement autograd functions - - Use autograd functions to then implement torch.nn.Module(s) - - Do *not* use methods in this file in Modules, they have no defined - backwards pass and will block gradient computation. - - Avoid importing internal function if at all possible. - -Exposed Methods: - quantize_mx_op - quantizes a tensor to MX format. - -Internal Methods: - _safe_lshift, _safe_rshift - fp16 compatible shifts - _shared_exponents - Returns MX shared exponent for the passed tensor - _reshape_to_blocks - tiles a tensor by splitting one dim into two - _undo_reshape_to_blocks - undos the above reshaping - _quantize_mx - quantizes a tensor to MX format +Pytorch helpers for MX fake quantization. + +The public `quantize_mx` helper intentionally calls the eager fake-quant custom +operator. During Circle export it is canonicalized into logical MX Q-DQ nodes. """ import torch @@ -265,6 +249,7 @@ def quantize_mx( shared_exp_method: str = "max", round: str = "nearest", ) -> torch.Tensor: - return torch.ops.circle_custom.quantize_mx( + """Call the eager MX fake-quantization custom operator.""" + return torch.ops.circle_custom.mx_fake_quantize( input_, elem_format, axis, shared_exp_method=shared_exp_method, round=round ) diff --git a/tico/utils/register_custom_op.py b/tico/utils/register_custom_op.py index 1b99de7c..7e0a933e 100644 --- a/tico/utils/register_custom_op.py +++ b/tico/utils/register_custom_op.py @@ -18,6 +18,7 @@ from torch._subclasses.fake_tensor import FakeTensor from torch.library import custom_op, register_fake +from tico.utils.mx.dtypes import normalize_mx_elem_format, SUPPORTED_MX_ELEM_FORMATS from tico.utils.mx.mx_ops import _quantize_mx # Note that an operator assumes input tensor has NHWC format. @@ -664,43 +665,93 @@ def _( return input.new_empty(input.size()) -def CircleQuantizeMX(): - # This operator conducts fake-quantization of microscaling - # NOTE Why using "quantize"_mx not "fake_quantize"_mx? - # To align with function name of microxcaling repo. - # https://github.com/microsoft/microxcaling/blob/v1.1.0/mx/mx_ops.py#L173 - @custom_op("circle_custom::quantize_mx", mutates_args=()) - def quantize_mx( +def CircleMXFakeQuantize(): + """Register the eager MX fake-quantization custom operator.""" + + @custom_op("circle_custom::mx_fake_quantize", mutates_args=()) + def mx_fake_quantize( input_: torch.Tensor, elem_format: str, axis: int, shared_exp_method: str = "max", round: str = "nearest", ) -> torch.Tensor: - if elem_format == "int8": - scale_bits = 8 - block_size = 32 - else: - raise RuntimeError(f"Unsupported elem_format in quantize_mx: {elem_format}") - - result = _quantize_mx( + if elem_format not in SUPPORTED_MX_ELEM_FORMATS: + raise RuntimeError( + f"Unsupported elem_format in mx_fake_quantize: {elem_format}" + ) + return _quantize_mx( input_, - scale_bits=scale_bits, - elem_format=elem_format, + scale_bits=8, + elem_format=normalize_mx_elem_format(elem_format), axes=[axis], - block_size=block_size, + block_size=32, shared_exp_method=shared_exp_method, round=round, ) - return result + + @register_fake("circle_custom::mx_fake_quantize") + def _( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", + round: str = "nearest", + ) -> torch.Tensor: + return input_ + + +def CircleQuantizeMX(): + """Register the internal logical MX quantize custom operator.""" + + @custom_op("circle_custom::quantize_mx", mutates_args=()) + def quantize_mx( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", + round: str = "nearest", + ) -> torch.Tensor: + raise RuntimeError( + "circle_custom::quantize_mx is an internal logical quantize op for " + "Circle export. Use circle_custom::mx_fake_quantize for eager MX " + "fake-quantization." + ) @register_fake("circle_custom::quantize_mx") def _( input_: torch.Tensor, elem_format: str, axis: int, - shared_exp_method: str = "max", # Fixed - round: str = "nearest", # Fixed + shared_exp_method: str = "max", + round: str = "nearest", + ) -> torch.Tensor: + return input_ + + +def CircleDequantizeMX(): + """Register the internal logical MX dequantize custom operator.""" + + @custom_op("circle_custom::dequantize_mx", mutates_args=()) + def dequantize_mx( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", + round: str = "nearest", + ) -> torch.Tensor: + raise RuntimeError( + "circle_custom::dequantize_mx is an internal logical dequantize op " + "for Circle export and should be folded before eager execution." + ) + + @register_fake("circle_custom::dequantize_mx") + def _( + input_: torch.Tensor, + elem_format: str, + axis: int, + shared_exp_method: str = "max", + round: str = "nearest", ) -> torch.Tensor: return input_ @@ -799,7 +850,9 @@ def RegisterOps(): CircleMaxPool2D() CircleAvgPool2D() CircleInstanceNorm() + CircleMXFakeQuantize() CircleQuantizeMX() + CircleDequantizeMX() CircleRMSNorm() CircleAttention() CircleShape() diff --git a/tico/utils/utils.py b/tico/utils/utils.py index 00125377..bbb5b95a 100644 --- a/tico/utils/utils.py +++ b/tico/utils/utils.py @@ -57,9 +57,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class ArgTypeError(Exception): - """ - Invalid argument type - """ + """Raised when a runtime argument type does not match a type hint.""" pass @@ -249,32 +247,31 @@ def run_bash_cmd(command: typing.List[str]) -> subprocess.CompletedProcess[str]: def has_quantization_ops(graph: torch.fx.Graph): - """ - Checks whether the given fx graph contains any quantization-related operations. - - This function inspects the provided graph to determine if it includes operations associated - with quantization (e.g., quantize, dequantize, fake quantize, etc.). The presence of such operations - can be used to decide whether to run subsequent quantization-specific passes on the graph. + """Return True if an FX graph contains affine or MX quantization operators.""" - Parameters: - graph: The fx graph to be examined. It is expected that the graph supports - iteration or traversal over its constituent operations. + def _maybe_custom_op(name: str): + """Return the default overload if the custom op exists.""" + try: + return getattr(torch.ops.circle_custom, name).default + except AttributeError: + return None - Returns: - bool: True if the graph contains one or more quantization-related operations, False otherwise. - """ quantized_ops = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_channel.default, torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.dequantize_per_channel.default, ] + for custom_op in ( + _maybe_custom_op("quantize_mx"), + _maybe_custom_op("dequantize_mx"), + ): + if custom_op is not None: + quantized_ops.append(custom_op) + for node in graph.nodes: - if node.op != "call_function": - continue - if node.target in quantized_ops: + if node.op == "call_function" and node.target in quantized_ops: return True - return False From c5123c7dfb5220cdd7a04b9a105a5d4018ca2fb0 Mon Sep 17 00:00:00 2001 From: seongwoo Date: Mon, 8 Jun 2026 12:12:12 +0900 Subject: [PATCH 2/3] fix test. --- test/unit_test/utils/test_register_custom_op.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/unit_test/utils/test_register_custom_op.py b/test/unit_test/utils/test_register_custom_op.py index 5a2460a9..af7c98f8 100644 --- a/test/unit_test/utils/test_register_custom_op.py +++ b/test/unit_test/utils/test_register_custom_op.py @@ -397,7 +397,9 @@ def test_circle_quantize_mx_mocked(self, mock_quantize_mx): elem_format = "int8" axis = -1 - result = torch.ops.circle_custom.quantize_mx(input_tensor, elem_format, axis) + result = torch.ops.circle_custom.mx_fake_quantize( + input_tensor, elem_format, axis + ) # Check that _quantize_mx was called with correct parameters mock_quantize_mx.assert_called_once_with( From 45af0c5e139a67f13577a81b6e4afd6deb123c12 Mon Sep 17 00:00:00 2001 From: seongwoo Date: Mon, 8 Jun 2026 21:42:17 +0900 Subject: [PATCH 3/3] apply comment. --- tico/serialize/circle_mapping.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tico/serialize/circle_mapping.py b/tico/serialize/circle_mapping.py index f8532f7e..89a7fb08 100644 --- a/tico/serialize/circle_mapping.py +++ b/tico/serialize/circle_mapping.py @@ -63,20 +63,9 @@ def str_to_circle_dtype( "int64": circle.TensorType.TensorType.INT64, "bool": circle.TensorType.TensorType.BOOL, "uint4": circle.TensorType.TensorType.UINT4, - # TODO Add more dtypes + "mxint8": circle.TensorType.TensorType.MXINT8, + "mxfp4": circle.TensorType.TensorType.MXFP4, } - optional_dtypes = { - "mxint8": "MXINT8", - "mxfp4": "MXFP4", - "mxfp6_e3m2": "MXFP6_E3M2", - "mxfp6_e2m3": "MXFP6_E2M3", - "mxfp8_e4m3": "MXFP8_E4M3", - "mxfp8_e5m2": "MXFP8_E5M2", - } - for dtype, circle_name in optional_dtypes.items(): - circle_dtype = getattr(circle.TensorType.TensorType, circle_name, None) - if circle_dtype is not None: - dmap[dtype] = circle_dtype if str_dtype not in dmap: raise RuntimeError(f"Unsupported dtype {str_dtype}") circle_type = dmap[str_dtype]