diff --git a/test/unit_test/passes/test_fuse_rms_norm.py b/test/unit_test/passes/test_fuse_rms_norm.py new file mode 100644 index 00000000..255a7543 --- /dev/null +++ b/test/unit_test/passes/test_fuse_rms_norm.py @@ -0,0 +1,59 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from tico.passes.fuse_rms_norm import FuseRmsNorm + +from test.support.helper import num_of_ops +from test.support.pass_value_test import SinglePassValueTest + + +class RMSNormNet(torch.nn.Module): + def __init__(self, normalized_shape=16, eps=1e-6): + super().__init__() + self.rms_norm = torch.nn.RMSNorm(normalized_shape, eps=eps) + + def forward(self, x): + return self.rms_norm(x) + + def get_example_inputs(self): + return (torch.randn(1, 8, 16),), {} + + +class FuseRmsNormTest(SinglePassValueTest): + def test_pass(self): + self.setup(RMSNormNet()) + self.assertEqual( + num_of_ops(self.exported_program(), [torch.ops.aten.rms_norm.default]), 1 + ) + self.assertEqual( + num_of_ops( + self.exported_program(), [torch.ops.circle_custom.rms_norm.default] + ), + 0, + ) + + self.run_value_test(FuseRmsNorm()) + + self.assertEqual( + num_of_ops(self.exported_program(), [torch.ops.aten.rms_norm.default]), + 0, + ) + self.assertEqual( + num_of_ops( + self.exported_program(), [torch.ops.circle_custom.rms_norm.default] + ), + 1, + ) diff --git a/tico/passes/fuse_rms_norm.py b/tico/passes/fuse_rms_norm.py new file mode 100644 index 00000000..1e131027 --- /dev/null +++ b/tico/passes/fuse_rms_norm.py @@ -0,0 +1,100 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import torch.fx +import torch +from torch.export import ExportedProgram + +from tico.utils.graph import create_node +from tico.utils.passes import PassBase, PassResult +from tico.utils.trace_decorators import trace_graph_diff_on_pass +from tico.utils.utils import is_target_node +from tico.utils.validate_args_kwargs import RMSNormArgs + + +@trace_graph_diff_on_pass +class FuseRmsNorm(PassBase): + """ + This pass fuses aten.rms_norm into circle_custom.rms_norm. + + aten.rms_norm already includes the weight multiplication internally, + so its output is equivalent to circle_custom.rms_norm. + + [before] + + input (tensor, normalized_shape, weight, eps) + | + aten.rms_norm + | + output + + [after] + + input (tensor, weight, eps) + | + circle_custom.rms_norm + | + output + + Note: normalized_shape is dropped because circle_custom.rms_norm + infers the normalization dimension from the weight shape. + """ + + def __init__(self): + super().__init__() + + def call(self, exported_program: ExportedProgram) -> PassResult: + gm = exported_program.graph_module + graph: torch.fx.Graph = gm.graph + modified = False + + for node in graph.nodes: + if not is_target_node(node, [torch.ops.aten.rms_norm.default]): + continue + + args = RMSNormArgs(*node.args, **node.kwargs) + input_ = args.input + weight = args.weight + eps = args.eps + + # weight is required for circle_custom.rms_norm + if weight is None: + continue + + # Use default eps if not provided (PyTorch default is 1e-5) + if eps is None: + eps = 1e-5 + + with gm.graph.inserting_before(node): + rms_norm_node = create_node( + graph, + torch.ops.circle_custom.rms_norm.default, + (input_, weight, eps), + origin=node, + ) + + # Reset meta to allow propagate_meta=True in replace_all_uses_with + # (same pattern as DecomposeGroupNorm) + rms_norm_node.meta = {} + node.replace_all_uses_with(rms_norm_node, propagate_meta=True) + modified = True + + gm.graph.eliminate_dead_code() + gm.graph.lint() + gm.recompile() + + return PassResult(modified) diff --git a/tico/utils/convert.py b/tico/utils/convert.py index 8bee6410..3d17f34d 100644 --- a/tico/utils/convert.py +++ b/tico/utils/convert.py @@ -46,6 +46,7 @@ from tico.passes.fill_meta_val import FillMetaVal from tico.passes.fuse_leading_unsqueeze_reshape import FuseLeadingUnsqueezeReshape from tico.passes.fuse_redundant_reshape_to_mean import FuseRedundantReshapeToMean +from tico.passes.fuse_rms_norm import FuseRmsNorm from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue from tico.passes.legalize_predefined_layout_operators import ( LegalizePreDefinedLayoutOperators, @@ -137,6 +138,7 @@ def run_decompositions(ep: ExportedProgram): torch.ops.aten.prelu.default, torch.ops.aten.linear.default, torch.ops.aten.upsample_nearest2d.vec, + torch.ops.aten.rms_norm.default, ) for op in _preserve_ops: if op in _decomp_table: @@ -254,6 +256,7 @@ def convert_exported_module_to_circle( DecomposeSliceScatter(), DecomposeGroupNorm(), DecomposeBatchNorm(), + FuseRmsNorm(), DecomposeGroupedConv2d(), CastATenWhereArgType(), ConvertRepeatToExpandCopy(),