Skip to content

[passes] Fuse RMSNorm to Circle RMSNorm op#754

Closed
dahlinPL wants to merge 1 commit into
Samsung:mainfrom
dahlinPL:rmsnorm
Closed

[passes] Fuse RMSNorm to Circle RMSNorm op#754
dahlinPL wants to merge 1 commit into
Samsung:mainfrom
dahlinPL:rmsnorm

Conversation

@dahlinPL

@dahlinPL dahlinPL commented Jun 2, 2026

Copy link
Copy Markdown
Contributor

This commit adds FuseRmsNorm() to the circle_legalize PassManager

@dahlinPL

dahlinPL commented Jun 2, 2026

Copy link
Copy Markdown
Contributor Author

sample code used for testing:

import torch
import torch.nn as nn

import tico


class RMSNormModel(nn.Module):
    """Simple model with RMSNorm for LLaMA/Qwen-style normalization."""

    def __init__(self):
        super().__init__()
        self.rms_norm = nn.RMSNorm(normalized_shape=64, eps=1e-6)

    def forward(self, x):
        return self.rms_norm(x)


if __name__ == "__main__":
    torch_module = RMSNormModel()
    example_inputs = (torch.randn(1, 8, 64, dtype=torch.float32),)

    circle_model = tico.convert(torch_module.eval(), example_inputs)
    circle_model.save("rmsnorm.circle")

    print("Circle model saved to: rmsnorm.circle")

result from main branch with tico==0.2.0.dev260601:
image

result from this PR:
image

@dahlinPL dahlinPL marked this pull request as ready for review June 2, 2026 14:03
@dahlinPL

dahlinPL commented Jun 2, 2026

Copy link
Copy Markdown
Contributor Author

local run of test added:

:~/git/fork/TICO$ ./ccex test -k FuseRmsNorm
RUN unit tests with -k FuseRmsNorm ...
2026-06-02 16:28:08.363913: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-06-02 16:28:08.390370: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-06-02 16:28:09.272325: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
test_pass (unit_test.passes.test_fuse_rms_norm.FuseRmsNormTest) ... ok

----------------------------------------------------------------------
Ran 1 test in 0.023s

OK
:~/git/fork/TICO$ 

@dahlinPL dahlinPL force-pushed the rmsnorm branch 3 times, most recently from df164b7 to aafa905 Compare June 2, 2026 14:38
This commit adds FuseRmsNorm() to the circle_legalize PassManager

TICO-DCO-1.0-Signed-off-by: Marcin Słowiński <m.slowinski2@samsung.com>
@mhs4670go

Copy link
Copy Markdown
Contributor

Just curiosity, which model need this feature?

@dahlinPL

dahlinPL commented Jun 4, 2026

Copy link
Copy Markdown
Contributor Author

Just curiosity, which model need this feature?

No specific model, I was just investigating why tests designed to use RMSNorm are failing and realized that only Llama customized pass exists (introduced in #266) and "clean" RMSNorm is not fused into Circle as a single op. Therefore, model compilation is failing because of lack of support for rsqrt in the compiler.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks more like a direct op serialization case than a graph rewrite pass. Could you implenet similar implementation in tico/serialize/operators/?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I guess you're right :)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it will be easier to raise this change in separate PR, but before I do that, please let me know your opinion

I could keep serialize/operators/adapters/llama_rmsnorm.py and quantization/wrapq/wrappers/ops/quant_rmsnorm.py as-is (using circle_custom.rms_norm) and add AtenRMSNormVisitor alongside RMSNormVisitor. However, I think it makes sense to simplify and:

  • replace all occurrences of circle_custom.rms_norm with aten.rms_norm
  • remove CircleRMSNorm from register_custom_op.py completely

This gives us a clean workflow:

  • normalize to aten.rms_norm when needed (e.g. for HuggingFace's LlamaRMSNorm which uses manual pow/mean/rsqrt/mul)
  • serialize/operators/op_rmsnorm.py has a single pipeline for all RMSNorm ops (regardless of whether they came from torch.nn.RMSNorm or were normalized by an adapter)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replacing all circle_custom.rms_norm users with aten.rms_norm and removing CircleRMSNorm sounds like a good cleanup. But it touches the HF adapter, quant wrapper, custom-op registration/fake impl, and tests. It also requires synthesizing normalized_shape from weight.shape and preserving the same restrictions (weight is not None, len(normalized_shape) == 1, eps is not None).

Actually, I think that would be better handled in a follow-up PR or when we really need it.

For this PR, just following codes looks enough.

# tico/serialize/operators/op_rmsnorm.py

from tico.serialize.circle_mapping import extract_shape
from tico.utils.errors import NotYetSupportedError
from tico.utils.validate_args_kwargs import CircleRMSNormArgs, RMSNormArgs

@register_node_visitor
class RMSNormVisitor(NodeVisitor):
    target = [
        torch.ops.circle_custom.rms_norm.default,
        torch.ops.aten.rms_norm.default,
    ]

    def _parse_args(self, node):
        if node.target == torch.ops.aten.rms_norm.default:
            args = RMSNormArgs(*node.args, **node.kwargs)

            if args.weight is None:
                raise NotYetSupportedError("RMSNorm without weight is not supported")

            if len(args.normalized_shape) != 1:
                raise NotYetSupportedError(
                    "Only 1-D normalized_shape RMSNorm is supported"
                )

            if list(extract_shape(args.weight)) != list(args.normalized_shape):
                raise NotYetSupportedError(
                    "RMSNorm weight shape should match normalized_shape"
                )

            eps = args.eps
            if eps is None:
                raise NotYetSupportedError("RMSNorm eps=None is not supported yet")

            return args.input, args.weight, eps

        args = CircleRMSNormArgs(*node.args, **node.kwargs)
        return args.input, args.weight, args.eps

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

implemented in #766

Comment on lines +53 to +54
Note: normalized_shape is dropped because circle_custom.rms_norm
infers the normalization dimension from the weight shape.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CircleRMSNorm computes RMS over the last axis only. Therefore, following check shold be added.

if len(args.normalized_shape) != 1:
    raise NotYetSupportedError(
        "Only 1-D normalized_shape RMSNorm is supported"
    )

if weight is None:
continue

# Use default eps if not provided (PyTorch default is 1e-5)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the spec, eps is set to None.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just disallow this case.

if eps is None:
    raise NotYetSupportedError("aten.rms_norm with eps=None is not supported yet")

@dahlinPL dahlinPL marked this pull request as draft June 8, 2026 13:45
@dahlinPL

dahlinPL commented Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

delivered with #766

@dahlinPL dahlinPL closed this Jun 9, 2026
@dahlinPL dahlinPL deleted the rmsnorm branch June 9, 2026 14:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants