Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion test/quantization/passes/test_fold_quant_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from tico.quantization.passes.fold_quant_ops import FoldQuantOps
from tico.serialize.quant_param import QPARAM_KEY

from test.modules.op.sub import SimpleSub
from test.support.helper import num_of_ops


Expand Down Expand Up @@ -52,6 +51,19 @@ def get_example_inputs(self):
return (torch.randn(3, 3),), {}


class MXFakeQuantize(torch.nn.Module):
    """Wrap a single `circle_custom.mx_fake_quantize` call.

    The test below decomposes this op into a `quantize_mx` /
    `dequantize_mx` pair and then folds that pair away.
    """

    def forward(self, inp):
        """Apply MX fake quantization ("int8" elements, axis -1) to `inp`."""
        fake_quantize = torch.ops.circle_custom.mx_fake_quantize
        return fake_quantize(inp, "int8", -1)

    def get_example_inputs(self):
        """Return `(args, kwargs)`: one random 2x32 tensor and no kwargs."""
        example_args = (torch.randn(2, 32),)
        example_kwargs = {}
        return example_args, example_kwargs


class FoldQuantOpsTest(unittest.TestCase):
def test_pass(self):
m = torch.nn.SiLU().eval()
Expand Down Expand Up @@ -153,3 +165,33 @@ def test_requantize(self):
== torch.ops.quantized_decomposed.quantize_per_tensor.default
):
self.assertTrue(QPARAM_KEY in node.meta)

def test_fold_mx_qdq_to_producer_qparam(self):
    """Fold an MX Q-DQ pair and expect its qparam on a placeholder node."""
    module = MXFakeQuantize().eval()
    example_args, example_kwargs = module.get_example_inputs()
    ep = torch.export.export(module, example_args, example_kwargs)

    mx_quantize = torch.ops.circle_custom.quantize_mx.default
    mx_dequantize = torch.ops.circle_custom.dequantize_mx.default

    # Decomposition turns the fake-quant op into exactly one Q and one DQ.
    DecomposeFakeQuantize().call(ep)
    self.assertEqual(num_of_ops(ep, [mx_quantize]), 1)
    self.assertEqual(num_of_ops(ep, [mx_dequantize]), 1)

    # Folding must remove both logical ops from the graph.
    FoldQuantOps().call(ep)
    self.assertEqual(num_of_ops(ep, [mx_quantize]), 0)
    self.assertEqual(num_of_ops(ep, [mx_dequantize]), 0)

    # Every placeholder that received a qparam must carry the MX settings,
    # and at least one such placeholder has to exist.
    annotated_inputs = [
        node
        for node in ep.graph.nodes
        if node.op == "placeholder" and QPARAM_KEY in node.meta
    ]
    for node in annotated_inputs:
        qparam = node.meta[QPARAM_KEY]
        self.assertEqual(qparam.dtype, "mxint8")
        self.assertEqual(qparam.quantized_dimension, -1)
    self.assertTrue(bool(annotated_inputs))
48 changes: 48 additions & 0 deletions test/unit_test/passes/test_decompose_fake_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,51 @@ def test_pass(self):
),
1,
)


class MXFakeQuantize(torch.nn.Module):
    """Minimal module whose forward is one MX fake-quantize call."""

    def forward(self, input):
        """Apply MX fake quantization ("int8" elements, axis -1) to `input`."""
        fake_quantize = torch.ops.circle_custom.mx_fake_quantize
        return fake_quantize(input, "int8", -1)

    def get_example_inputs(self):
        """Return `(args, kwargs)`: one random 2x32 tensor and no kwargs."""
        example_args = (torch.randn(2, 32),)
        example_kwargs = {}
        return example_args, example_kwargs


class DecomposeMXFakeQuantizeTest(SinglePassValueTest):
    """Checks that MX fake quant is decomposed into a logical Q-DQ pair."""

    def test_pass(self):
        """Expect one fake-quant op before the pass and one Q plus one DQ after."""
        custom_ops = torch.ops.circle_custom

        def occurrences(target):
            # Look the exported program up at call time instead of caching it.
            return num_of_ops(self.exported_program(), [target])

        self.setup(MXFakeQuantize())
        self.assertEqual(occurrences(custom_ops.mx_fake_quantize.default), 1)

        DecomposeFakeQuantize().call(self.exported_program())

        # The fused op is gone and replaced by exactly one Q and one DQ.
        self.assertEqual(occurrences(custom_ops.mx_fake_quantize.default), 0)
        self.assertEqual(occurrences(custom_ops.quantize_mx.default), 1)
        self.assertEqual(occurrences(custom_ops.dequantize_mx.default), 1)
3 changes: 1 addition & 2 deletions test/unit_test/utils/test_mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def test_random_values(self):
self.assertFalse(torch.allclose(input_, output))
self.assertLess(compute_peir(input_, output), 0.01)

# Check if exported program includes circle_custom::quantize_mx Op
def test_export(self):
m = SimpleMXINT8(axis=2)
args, kwargs = m.get_example_inputs()
Expand All @@ -80,5 +79,5 @@ def test_export(self):
ep = export(m.eval(), args, kwargs)

self.assertEqual(
1, num_of_ops(ep, [torch.ops.circle_custom.quantize_mx.default])
1, num_of_ops(ep, [torch.ops.circle_custom.mx_fake_quantize.default])
)
53 changes: 25 additions & 28 deletions test/unit_test/utils/test_register_custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,43 +314,38 @@ def test_circle_instance_norm_with_custom_params(self):
# Check output shape
self.assertEqual(list(result.shape), list(input_tensor.shape))

def test_circle_quantize_mx_int8(self):
"""Test CircleQuantizeMX with int8 format"""
def test_circle_mx_fake_quantize_basic(self):
"""Test CircleMXFakeQuantize basic functionality."""
input_tensor = torch.randn(2, 32, 32, 3)
elem_format = "int8"
axis = -1

result = torch.ops.circle_custom.quantize_mx(input_tensor, elem_format, axis)

# Check output shape
result = torch.ops.circle_custom.mx_fake_quantize(input_tensor, "int8", -1)
self.assertEqual(list(result.shape), list(input_tensor.shape))

def test_circle_quantize_mx_unsupported_format(self):
"""Test CircleQuantizeMX with unsupported format"""
def test_circle_mx_fake_quantize_unsupported_format(self):
"""Test CircleMXFakeQuantize with unsupported format."""
input_tensor = torch.randn(2, 32, 32, 3)
elem_format = "unsupported_format"
axis = -1

with self.assertRaises(RuntimeError) as context:
torch.ops.circle_custom.quantize_mx(input_tensor, elem_format, axis)

self.assertIn("Unsupported elem_format in quantize_mx", str(context.exception))
torch.ops.circle_custom.mx_fake_quantize(
input_tensor, "unsupported_format", -1
)
self.assertIn(
"Unsupported elem_format in mx_fake_quantize", str(context.exception)
)

def test_circle_quantize_mx_with_custom_params(self):
"""Test CircleQuantizeMX with custom parameters"""
def test_circle_mx_fake_quantize_with_custom_params(self):
"""Test CircleMXFakeQuantize with custom parameters."""
input_tensor = torch.randn(2, 32, 32, 3)
elem_format = "int8"
axis = -1
shared_exp_method = "max"
round_method = "nearest"

result = torch.ops.circle_custom.quantize_mx(
input_tensor, elem_format, axis, shared_exp_method, round_method
result = torch.ops.circle_custom.mx_fake_quantize(
input_tensor, "int8", -1, "max", "nearest"
)

# Check output shape
self.assertEqual(list(result.shape), list(input_tensor.shape))

def test_circle_quantize_mx_is_logical_only(self):
"""Test that logical MX quantize is not an eager fake-quant API."""
input_tensor = torch.randn(2, 32, 32, 3)
with self.assertRaises(RuntimeError) as context:
torch.ops.circle_custom.quantize_mx(input_tensor, "int8", -1)
self.assertIn("internal logical quantize op", str(context.exception))

def test_circle_rms_norm_basic(self):
"""Test CircleRMSNorm basic functionality"""
hidden_states = torch.randn(2, 32, 3)
Expand Down Expand Up @@ -402,7 +397,9 @@ def test_circle_quantize_mx_mocked(self, mock_quantize_mx):
elem_format = "int8"
axis = -1

result = torch.ops.circle_custom.quantize_mx(input_tensor, elem_format, axis)
result = torch.ops.circle_custom.mx_fake_quantize(
input_tensor, elem_format, axis
)

# Check that _quantize_mx was called with correct parameters
mock_quantize_mx.assert_called_once_with(
Expand Down
19 changes: 19 additions & 0 deletions tico/passes/decompose_fake_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,25 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
node.replace_all_uses_with(dequnt, propagate_meta=True)
modified = True

if node.target in [torch.ops.circle_custom.mx_fake_quantize.default]:
assert len(node.args) >= 3
with gm.graph.inserting_before(node):
quant = create_node(
g,
torch.ops.circle_custom.quantize_mx.default,
args=node.args,
kwargs=node.kwargs,
origin=node,
)
dequnt = create_node(
g,
torch.ops.circle_custom.dequantize_mx.default,
args=(quant, *quant.args[1:]),
kwargs=quant.kwargs,
)
node.replace_all_uses_with(dequnt, propagate_meta=True)
modified = True

gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()
Expand Down
83 changes: 83 additions & 0 deletions tico/quantization/passes/fold_quant_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@

from tico.serialize.quant_param import QPARAM_KEY, QuantParam
from tico.utils import logging
from tico.utils.mx.dtypes import (
assert_supported_mx_export_options,
is_mx_dtype,
mx_dtype_from_elem_format,
normalize_mx_elem_format,
)
from tico.utils.passes import PassBase, PassResult
from tico.utils.trace_decorators import trace_graph_diff_on_pass
from tico.utils.utils import get_quant_dtype
Expand All @@ -31,6 +37,46 @@
)


def _mx_op_params(node) -> tuple[str, int, str, str]:
    """Extract the MX quantization settings carried by an FX custom op node.

    Positional layout read from `node.args`: index 1 is the element format,
    index 2 the axis, index 3 (optional) the shared-exponent method and
    index 4 (optional) the rounding mode. The keyword entries
    `shared_exp_method` and `round` take precedence over the positional
    values; when neither is present the defaults are "max" and "nearest".

    Returns:
        `(elem_format, axis, shared_exp_method, round_mode)` where
        `elem_format` has been passed through `normalize_mx_elem_format`.

    Raises:
        AssertionError: fewer than three args, a non-int axis, or a
            non-str shared-exponent method / rounding mode.
    """
    args = node.args
    kwargs = node.kwargs

    assert len(args) >= 3
    elem_format = normalize_mx_elem_format(args[1])
    axis = args[2]
    assert isinstance(axis, int)

    # Positional fallbacks, used only when the keyword form is absent.
    positional_exp_method = args[3] if len(args) > 3 else "max"
    shared_exp_method = kwargs.get("shared_exp_method", positional_exp_method)

    positional_round = args[4] if len(args) > 4 else "nearest"
    round_mode = kwargs.get("round", positional_round)

    assert isinstance(shared_exp_method, str)
    assert isinstance(round_mode, str)
    return elem_format, axis, shared_exp_method, round_mode


def _mx_qparam_from_quant_node(q) -> QuantParam:
    """Create the QuantParam described by a logical MX quantize node.

    The node's settings are read with `_mx_op_params` and checked with
    `assert_supported_mx_export_options` before the QuantParam is filled in.
    Only the dtype (derived from the element format) and the quantized
    dimension (the axis) are stored on the result.
    """
    elem_format, axis, shared_exp_method, round_mode = _mx_op_params(q)

    export_options = {
        "elem_format": elem_format,
        "shared_exp_method": shared_exp_method,
        "round": round_mode,
    }
    assert_supported_mx_export_options(**export_options)

    mx_qparam = QuantParam()
    mx_qparam.dtype = mx_dtype_from_elem_format(elem_format)
    mx_qparam.quantized_dimension = axis
    return mx_qparam


def _same_mx_qparam(lhs: QuantParam, rhs: QuantParam) -> bool:
"""Return True when two QuantParams describe the same MX quantization."""
return (
lhs.dtype == rhs.dtype
and is_mx_dtype(lhs.dtype)
and lhs.quantized_dimension == rhs.quantized_dimension
)


@trace_graph_diff_on_pass
class FoldQuantOps(PassBase):
"""
Expand Down Expand Up @@ -145,6 +191,43 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
dq.replace_all_uses_with(op, propagate_meta=False)
logger.debug(f"Removed redundant {dq.name}")

for dq in graph.nodes:
if dq.op != "call_function":
continue
if dq.target != torch.ops.circle_custom.dequantize_mx.default:
continue

q = dq.args[0]
if not isinstance(q, torch.fx.Node):
continue
if q.target != torch.ops.circle_custom.quantize_mx.default:
continue

op = q.args[0]
if not isinstance(op, torch.fx.Node):
continue

if _mx_op_params(q) != _mx_op_params(dq):
continue

qparam = _mx_qparam_from_quant_node(q)

if QPARAM_KEY not in op.meta:
op.meta[QPARAM_KEY] = qparam
dq.replace_all_uses_with(op, propagate_meta=False)
logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
else:
op_qparam = op.meta[QPARAM_KEY]
if not _same_mx_qparam(op_qparam, qparam):
if QPARAM_KEY not in q.meta:
q.meta[QPARAM_KEY] = qparam
assert len(q.users) == 1, "Fix me unless"
dq.replace_all_uses_with(q, propagate_meta=False)
logger.debug(f"{dq.name} is folded ({q.name} is left).")
else:
dq.replace_all_uses_with(op, propagate_meta=False)
logger.debug(f"Removed redundant {dq.name}")

graph.eliminate_dead_code()
graph.lint()
graph_module.recompile()
Expand Down
Loading
Loading