-
Notifications
You must be signed in to change notification settings - Fork 31
[passes] Fuse RMSNorm to Circle RMSNorm op #754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,59 @@ | ||
| # Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import torch | ||
|
|
||
| from tico.passes.fuse_rms_norm import FuseRmsNorm | ||
|
|
||
| from test.support.helper import num_of_ops | ||
| from test.support.pass_value_test import SinglePassValueTest | ||
|
|
||
|
|
||
| class RMSNormNet(torch.nn.Module): | ||
| def __init__(self, normalized_shape=16, eps=1e-6): | ||
| super().__init__() | ||
| self.rms_norm = torch.nn.RMSNorm(normalized_shape, eps=eps) | ||
|
|
||
| def forward(self, x): | ||
| return self.rms_norm(x) | ||
|
|
||
| def get_example_inputs(self): | ||
| return (torch.randn(1, 8, 16),), {} | ||
|
|
||
|
|
||
| class FuseRmsNormTest(SinglePassValueTest): | ||
| def test_pass(self): | ||
| self.setup(RMSNormNet()) | ||
| self.assertEqual( | ||
| num_of_ops(self.exported_program(), [torch.ops.aten.rms_norm.default]), 1 | ||
| ) | ||
| self.assertEqual( | ||
| num_of_ops( | ||
| self.exported_program(), [torch.ops.circle_custom.rms_norm.default] | ||
| ), | ||
| 0, | ||
| ) | ||
|
|
||
| self.run_value_test(FuseRmsNorm()) | ||
|
|
||
| self.assertEqual( | ||
| num_of_ops(self.exported_program(), [torch.ops.aten.rms_norm.default]), | ||
| 0, | ||
| ) | ||
| self.assertEqual( | ||
| num_of_ops( | ||
| self.exported_program(), [torch.ops.circle_custom.rms_norm.default] | ||
| ), | ||
| 1, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| # Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| import torch.fx | ||
| import torch | ||
| from torch.export import ExportedProgram | ||
|
|
||
| from tico.utils.graph import create_node | ||
| from tico.utils.passes import PassBase, PassResult | ||
| from tico.utils.trace_decorators import trace_graph_diff_on_pass | ||
| from tico.utils.utils import is_target_node | ||
| from tico.utils.validate_args_kwargs import RMSNormArgs | ||
|
|
||
|
|
||
| @trace_graph_diff_on_pass | ||
| class FuseRmsNorm(PassBase): | ||
| """ | ||
| This pass fuses aten.rms_norm into circle_custom.rms_norm. | ||
|
|
||
| aten.rms_norm already includes the weight multiplication internally, | ||
| so its output is equivalent to circle_custom.rms_norm. | ||
|
|
||
| [before] | ||
|
|
||
| input (tensor, normalized_shape, weight, eps) | ||
| | | ||
| aten.rms_norm | ||
| | | ||
| output | ||
|
|
||
| [after] | ||
|
|
||
| input (tensor, weight, eps) | ||
| | | ||
| circle_custom.rms_norm | ||
| | | ||
| output | ||
|
|
||
| Note: normalized_shape is dropped because circle_custom.rms_norm | ||
| infers the normalization dimension from the weight shape. | ||
|
Comment on lines
+53
to
+54
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CircleRMSNorm computes RMS over the last axis only. Therefore, following check shold be added. if len(args.normalized_shape) != 1:
raise NotYetSupportedError(
"Only 1-D normalized_shape RMSNorm is supported"
) |
||
| """ | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
|
|
||
| def call(self, exported_program: ExportedProgram) -> PassResult: | ||
| gm = exported_program.graph_module | ||
| graph: torch.fx.Graph = gm.graph | ||
| modified = False | ||
|
|
||
| for node in graph.nodes: | ||
| if not is_target_node(node, [torch.ops.aten.rms_norm.default]): | ||
| continue | ||
|
|
||
| args = RMSNormArgs(*node.args, **node.kwargs) | ||
| input_ = args.input | ||
| weight = args.weight | ||
| eps = args.eps | ||
|
|
||
| # weight is required for circle_custom.rms_norm | ||
| if weight is None: | ||
| continue | ||
|
|
||
| # Use default eps if not provided (PyTorch default is 1e-5) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to the spec, eps is set to None.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's just disallow this case. if eps is None:
raise NotYetSupportedError("aten.rms_norm with eps=None is not supported yet") |
||
| if eps is None: | ||
| eps = 1e-5 | ||
|
|
||
| with gm.graph.inserting_before(node): | ||
| rms_norm_node = create_node( | ||
| graph, | ||
| torch.ops.circle_custom.rms_norm.default, | ||
| (input_, weight, eps), | ||
| origin=node, | ||
| ) | ||
|
|
||
| # Reset meta to allow propagate_meta=True in replace_all_uses_with | ||
| # (same pattern as DecomposeGroupNorm) | ||
| rms_norm_node.meta = {} | ||
| node.replace_all_uses_with(rms_norm_node, propagate_meta=True) | ||
| modified = True | ||
|
|
||
| gm.graph.eliminate_dead_code() | ||
| gm.graph.lint() | ||
| gm.recompile() | ||
|
|
||
| return PassResult(modified) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks more like a direct op serialization case than a graph rewrite pass. Could you implenet similar implementation in
tico/serialize/operators/?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hmm I guess you're right :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it will be easier to raise this change in separate PR, but before I do that, please let me know your opinion
I could keep
serialize/operators/adapters/llama_rmsnorm.pyandquantization/wrapq/wrappers/ops/quant_rmsnorm.pyas-is (usingcircle_custom.rms_norm) and addAtenRMSNormVisitoralongsideRMSNormVisitor. However, I think it makes sense to simplify and:circle_custom.rms_normwithaten.rms_normCircleRMSNormfromregister_custom_op.pycompletelyThis gives us a clean workflow:
aten.rms_normwhen needed (e.g. for HuggingFace's LlamaRMSNorm which uses manual pow/mean/rsqrt/mul)serialize/operators/op_rmsnorm.pyhas a single pipeline for all RMSNorm ops (regardless of whether they came fromtorch.nn.RMSNormor were normalized by an adapter)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replacing all
circle_custom.rms_normusers withaten.rms_normand removingCircleRMSNormsounds like a good cleanup. But it touches the HF adapter, quant wrapper, custom-op registration/fake impl, and tests. It also requires synthesizingnormalized_shapefromweight.shapeand preserving the same restrictions (weight is not None,len(normalized_shape) == 1,eps is not None).Actually, I think that would be better handled in a follow-up PR or when we really need it.
For this PR, just following codes looks enough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
implemented in #766