Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions test/unit_test/utils/test_register_custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import tico.utils.register_custom_op as register_custom_op
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


class TestRegisterCustomOp(unittest.TestCase):
Expand Down Expand Up @@ -415,6 +416,74 @@ def test_circle_quantize_mx_mocked(self, mock_quantize_mx):
round="nearest",
)

def test_circle_quantize_mx_decomposed_fake_tensor(self):
    """quantize_mx_decomposed keeps the input shape under FakeTensorMode."""
    # The real kernel of this op only raises, so the op can be exercised
    # solely through its fake (meta) kernel.
    mode = FakeTensorMode()
    fake_input = mode.from_tensor(torch.randn(2, 32, 32, 3))
    op_args = ("int8", -1)  # (elem_format, axis)

    with mode:
        output = torch.ops.circle_custom.quantize_mx_decomposed(
            fake_input, *op_args
        )

    # The fake kernel must report an output shaped like the input.
    expected_shape = list(fake_input.shape)
    self.assertEqual(list(output.shape), expected_shape)

def test_circle_dequantize_mx_decomposed_fake_tensor(self):
    """dequantize_mx_decomposed keeps the input shape under FakeTensorMode."""
    # The real kernel of this op only raises, so the op can be exercised
    # solely through its fake (meta) kernel.
    mode = FakeTensorMode()
    fake_input = mode.from_tensor(torch.randn(2, 32, 32, 3))
    op_args = ("int8", -1)  # (elem_format, axis)

    with mode:
        output = torch.ops.circle_custom.dequantize_mx_decomposed(
            fake_input, *op_args
        )

    # The fake kernel must report an output shaped like the input.
    expected_shape = list(fake_input.shape)
    self.assertEqual(list(output.shape), expected_shape)

def test_circle_quantize_mx_decomposed_with_params_fake_tensor(self):
    """quantize_mx_decomposed accepts every parameter explicitly under FakeTensorMode."""
    mode = FakeTensorMode()
    fake_input = mode.from_tensor(torch.randn(2, 32, 32, 3))
    # (elem_format, axis, shared_exp_method, round)
    op_args = ("int8", -1, "max", "nearest")

    with mode:
        output = torch.ops.circle_custom.quantize_mx_decomposed(
            fake_input, *op_args
        )

    # The fake kernel must report an output shaped like the input.
    expected_shape = list(fake_input.shape)
    self.assertEqual(list(output.shape), expected_shape)

def test_circle_dequantize_mx_decomposed_with_params_fake_tensor(self):
    """dequantize_mx_decomposed accepts every parameter explicitly under FakeTensorMode."""
    mode = FakeTensorMode()
    fake_input = mode.from_tensor(torch.randn(2, 32, 32, 3))
    # (elem_format, axis, shared_exp_method, round)
    op_args = ("int8", -1, "max", "nearest")

    with mode:
        output = torch.ops.circle_custom.dequantize_mx_decomposed(
            fake_input, *op_args
        )

    # The fake kernel must report an output shaped like the input.
    expected_shape = list(fake_input.shape)
    self.assertEqual(list(output.shape), expected_shape)


# Run the unittest test runner when this module is executed directly as a script.
if __name__ == "__main__":
unittest.main()
54 changes: 54 additions & 0 deletions tico/utils/register_custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,58 @@ def _(
return input_


def CircleQuantizeMXDecomposed():
@custom_op("circle_custom::quantize_mx_decomposed", mutates_args=())
def quantize_mx(
input_: torch.Tensor,
elem_format: str,
axis: int,
shared_exp_method: str = "max",
round: str = "nearest",
) -> torch.Tensor:
# this op should be fake one, so please consider different quantization scheme in case it failed here
raise RuntimeError(
"circle_custom::quantize_mx_decomposed is a fake-only op and must not be executed with real tensors"
)
return input_.new_empty(input_.size())

@register_fake("circle_custom::quantize_mx_decomposed")
def _(
input_: torch.Tensor,
elem_format: str,
axis: int,
shared_exp_method: str = "max", # Fixed
round: str = "nearest", # Fixed
) -> torch.Tensor:
return input_


def CircleDeQuantizeMXDecomposed():
@custom_op("circle_custom::dequantize_mx_decomposed", mutates_args=())
def dequantize_mx(
input_: torch.Tensor,
elem_format: str,
axis: int,
shared_exp_method: str = "max",
round: str = "nearest",
) -> torch.Tensor:
# this op should be fake one, so please consider different quantization scheme in case it failed here
raise RuntimeError(
"circle_custom::dequantize_mx_decomposed is a fake-only op and must not be executed with real tensors"
)
return input_.clone()

@register_fake("circle_custom::dequantize_mx_decomposed")
def _(
input_: torch.Tensor,
elem_format: str,
axis: int,
shared_exp_method: str = "max", # Fixed
round: str = "nearest", # Fixed;
) -> torch.Tensor:
return input_


def CircleRMSNorm():
@custom_op("circle_custom::rms_norm", mutates_args=())
def rms_norm(
Expand Down Expand Up @@ -800,6 +852,8 @@ def RegisterOps():
CircleAvgPool2D()
CircleInstanceNorm()
CircleQuantizeMX()
CircleQuantizeMXDecomposed()
CircleDeQuantizeMXDecomposed()
CircleRMSNorm()
CircleAttention()
CircleShape()
Loading