Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions test/unit_test/passes/test_fuse_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from tico.passes.fuse_rms_norm import FuseRmsNorm

from test.support.helper import num_of_ops
from test.support.pass_value_test import SinglePassValueTest


class RMSNormNet(torch.nn.Module):
def __init__(self, normalized_shape=16, eps=1e-6):
super().__init__()
self.rms_norm = torch.nn.RMSNorm(normalized_shape, eps=eps)

def forward(self, x):
return self.rms_norm(x)

def get_example_inputs(self):
return (torch.randn(1, 8, 16),), {}


class FuseRmsNormTest(SinglePassValueTest):
def test_pass(self):
self.setup(RMSNormNet())
self.assertEqual(
num_of_ops(self.exported_program(), [torch.ops.aten.rms_norm.default]), 1
)
self.assertEqual(
num_of_ops(
self.exported_program(), [torch.ops.circle_custom.rms_norm.default]
),
0,
)

self.run_value_test(FuseRmsNorm())

self.assertEqual(
num_of_ops(self.exported_program(), [torch.ops.aten.rms_norm.default]),
0,
)
self.assertEqual(
num_of_ops(
self.exported_program(), [torch.ops.circle_custom.rms_norm.default]
),
1,
)
100 changes: 100 additions & 0 deletions tico/passes/fuse_rms_norm.py

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks more like a direct op serialization case than a graph rewrite pass. Could you implenet similar implementation in tico/serialize/operators/?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I guess you're right :)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it will be easier to raise this change in separate PR, but before I do that, please let me know your opinion

I could keep serialize/operators/adapters/llama_rmsnorm.py and quantization/wrapq/wrappers/ops/quant_rmsnorm.py as-is (using circle_custom.rms_norm) and add AtenRMSNormVisitor alongside RMSNormVisitor. However, I think it makes sense to simplify and:

  • replace all occurrences of circle_custom.rms_norm with aten.rms_norm
  • remove CircleRMSNorm from register_custom_op.py completely

This gives us a clean workflow:

  • normalize to aten.rms_norm when needed (e.g. for HuggingFace's LlamaRMSNorm which uses manual pow/mean/rsqrt/mul)
  • serialize/operators/op_rmsnorm.py has a single pipeline for all RMSNorm ops (regardless of whether they came from torch.nn.RMSNorm or were normalized by an adapter)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replacing all circle_custom.rms_norm users with aten.rms_norm and removing CircleRMSNorm sounds like a good cleanup. But it touches the HF adapter, quant wrapper, custom-op registration/fake impl, and tests. It also requires synthesizing normalized_shape from weight.shape and preserving the same restrictions (weight is not None, len(normalized_shape) == 1, eps is not None).

Actually, I think that would be better handled in a follow-up PR or when we really need it.

For this PR, just following codes looks enough.

# tico/serialize/operators/op_rmsnorm.py

from tico.serialize.circle_mapping import extract_shape
from tico.utils.errors import NotYetSupportedError
from tico.utils.validate_args_kwargs import CircleRMSNormArgs, RMSNormArgs

@register_node_visitor
class RMSNormVisitor(NodeVisitor):
    target = [
        torch.ops.circle_custom.rms_norm.default,
        torch.ops.aten.rms_norm.default,
    ]

    def _parse_args(self, node):
        if node.target == torch.ops.aten.rms_norm.default:
            args = RMSNormArgs(*node.args, **node.kwargs)

            if args.weight is None:
                raise NotYetSupportedError("RMSNorm without weight is not supported")

            if len(args.normalized_shape) != 1:
                raise NotYetSupportedError(
                    "Only 1-D normalized_shape RMSNorm is supported"
                )

            if list(extract_shape(args.weight)) != list(args.normalized_shape):
                raise NotYetSupportedError(
                    "RMSNorm weight shape should match normalized_shape"
                )

            eps = args.eps
            if eps is None:
                raise NotYetSupportedError("RMSNorm eps=None is not supported yet")

            return args.input, args.weight, eps

        args = CircleRMSNormArgs(*node.args, **node.kwargs)
        return args.input, args.weight, args.eps

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

implemented in #766

Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
import torch.fx
import torch
from torch.export import ExportedProgram

from tico.utils.graph import create_node
from tico.utils.passes import PassBase, PassResult
from tico.utils.trace_decorators import trace_graph_diff_on_pass
from tico.utils.utils import is_target_node
from tico.utils.validate_args_kwargs import RMSNormArgs


@trace_graph_diff_on_pass
class FuseRmsNorm(PassBase):
"""
This pass fuses aten.rms_norm into circle_custom.rms_norm.

aten.rms_norm already includes the weight multiplication internally,
so its output is equivalent to circle_custom.rms_norm.

[before]

input (tensor, normalized_shape, weight, eps)
|
aten.rms_norm
|
output

[after]

input (tensor, weight, eps)
|
circle_custom.rms_norm
|
output

Note: normalized_shape is dropped because circle_custom.rms_norm
infers the normalization dimension from the weight shape.
Comment on lines +53 to +54

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CircleRMSNorm computes RMS over the last axis only. Therefore, following check shold be added.

if len(args.normalized_shape) != 1:
    raise NotYetSupportedError(
        "Only 1-D normalized_shape RMSNorm is supported"
    )

"""

def __init__(self):
super().__init__()

def call(self, exported_program: ExportedProgram) -> PassResult:
gm = exported_program.graph_module
graph: torch.fx.Graph = gm.graph
modified = False

for node in graph.nodes:
if not is_target_node(node, [torch.ops.aten.rms_norm.default]):
continue

args = RMSNormArgs(*node.args, **node.kwargs)
input_ = args.input
weight = args.weight
eps = args.eps

# weight is required for circle_custom.rms_norm
if weight is None:
continue

# Use default eps if not provided (PyTorch default is 1e-5)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the spec, eps is set to None.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just disallow this case.

if eps is None:
    raise NotYetSupportedError("aten.rms_norm with eps=None is not supported yet")

if eps is None:
eps = 1e-5

with gm.graph.inserting_before(node):
rms_norm_node = create_node(
graph,
torch.ops.circle_custom.rms_norm.default,
(input_, weight, eps),
origin=node,
)

# Reset meta to allow propagate_meta=True in replace_all_uses_with
# (same pattern as DecomposeGroupNorm)
rms_norm_node.meta = {}
node.replace_all_uses_with(rms_norm_node, propagate_meta=True)
modified = True

gm.graph.eliminate_dead_code()
gm.graph.lint()
gm.recompile()

return PassResult(modified)
3 changes: 3 additions & 0 deletions tico/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from tico.passes.fill_meta_val import FillMetaVal
from tico.passes.fuse_leading_unsqueeze_reshape import FuseLeadingUnsqueezeReshape
from tico.passes.fuse_redundant_reshape_to_mean import FuseRedundantReshapeToMean
from tico.passes.fuse_rms_norm import FuseRmsNorm
from tico.passes.legalize_causal_mask_value import LegalizeCausalMaskValue
from tico.passes.legalize_predefined_layout_operators import (
LegalizePreDefinedLayoutOperators,
Expand Down Expand Up @@ -137,6 +138,7 @@ def run_decompositions(ep: ExportedProgram):
torch.ops.aten.prelu.default,
torch.ops.aten.linear.default,
torch.ops.aten.upsample_nearest2d.vec,
torch.ops.aten.rms_norm.default,
)
for op in _preserve_ops:
if op in _decomp_table:
Expand Down Expand Up @@ -254,6 +256,7 @@ def convert_exported_module_to_circle(
DecomposeSliceScatter(),
DecomposeGroupNorm(),
DecomposeBatchNorm(),
FuseRmsNorm(),
DecomposeGroupedConv2d(),
CastATenWhereArgType(),
ConvertRepeatToExpandCopy(),
Expand Down
Loading