[passes] Fuse RMSNorm to Circle RMSNorm op #754
sample code used for testing:
import torch
import torch.nn as nn

import tico


class RMSNormModel(nn.Module):
    """Simple model with RMSNorm for LLaMA/Qwen-style normalization."""

    def __init__(self):
        super().__init__()
        self.rms_norm = nn.RMSNorm(normalized_shape=64, eps=1e-6)

    def forward(self, x):
        return self.rms_norm(x)


if __name__ == "__main__":
    torch_module = RMSNormModel()
    example_inputs = (torch.randn(1, 8, 64, dtype=torch.float32),)
    circle_model = tico.convert(torch_module.eval(), example_inputs)
    circle_model.save("rmsnorm.circle")
    print("Circle model saved to: rmsnorm.circle")
local run of test added:
:~/git/fork/TICO$ ./ccex test -k FuseRmsNorm
RUN unit tests with -k FuseRmsNorm ...
2026-06-02 16:28:08.363913: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-06-02 16:28:08.390370: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-06-02 16:28:09.272325: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
test_pass (unit_test.passes.test_fuse_rms_norm.FuseRmsNormTest) ... ok
----------------------------------------------------------------------
Ran 1 test in 0.023s
OK
:~/git/fork/TICO$
This commit adds FuseRmsNorm() to the circle_legalize PassManager

TICO-DCO-1.0-Signed-off-by: Marcin Słowiński <m.slowinski2@samsung.com>
Just out of curiosity, which model needs this feature?
No specific model. I was investigating why tests designed to use RMSNorm were failing and realized that only a Llama-specific pass exists (introduced in #266), so a "clean" RMSNorm is not fused into a single Circle op. As a result, model compilation fails because the compiler lacks support for rsqrt.
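For context, here is a minimal pure-Python sketch of where the rsqrt comes from, assuming the standard RMSNorm definition y = x * rsqrt(mean(x^2) + eps) * weight over a single axis. The function name `rms_norm_ref` is made up for illustration and is not part of TICO or PyTorch.

```python
import math


def rms_norm_ref(x, weight, eps):
    """Reference RMSNorm over one axis, written with plain Python lists.

    Hypothetical helper for illustration only.
    """
    # Mean of squares over the normalized axis.
    mean_sq = sum(v * v for v in x) / len(x)
    # This reciprocal square root is the rsqrt op that shows up when
    # RMSNorm is decomposed instead of being kept as a single op.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Scale each element by the inverse RMS and its weight.
    return [v * inv_rms * w for v, w in zip(x, weight)]


y = rms_norm_ref([3.0, 4.0], [1.0, 1.0], eps=0.0)
print(y)
```

With unit weights and eps=0 the output has a root mean square of 1, which is a quick sanity check on the formula.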
It looks more like a direct op serialization case than a graph rewrite pass. Could you implement something similar in tico/serialize/operators/?
hmm I guess you're right :)
I guess it will be easier to raise this change in a separate PR, but before I do that, please let me know your opinion.
I could keep serialize/operators/adapters/llama_rmsnorm.py and quantization/wrapq/wrappers/ops/quant_rmsnorm.py as-is (using circle_custom.rms_norm) and add AtenRMSNormVisitor alongside RMSNormVisitor. However, I think it makes sense to simplify and:
- replace all occurrences of circle_custom.rms_norm with aten.rms_norm
- remove CircleRMSNorm from register_custom_op.py completely
This gives us a clean workflow:
- normalize to aten.rms_norm when needed (e.g. for HuggingFace's LlamaRMSNorm, which uses manual pow/mean/rsqrt/mul)
- serialize/operators/op_rmsnorm.py has a single pipeline for all RMSNorm ops (regardless of whether they came from torch.nn.RMSNorm or were normalized by an adapter)
Replacing all circle_custom.rms_norm users with aten.rms_norm and removing CircleRMSNorm sounds like a good cleanup. But it touches the HF adapter, quant wrapper, custom-op registration/fake impl, and tests. It also requires synthesizing normalized_shape from weight.shape and preserving the same restrictions (weight is not None, len(normalized_shape) == 1, eps is not None).
Actually, I think that would be better handled in a follow-up PR or when we really need it.
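To make the follow-up idea concrete, a rough sketch of what synthesizing normalized_shape from weight.shape while keeping the same restrictions might look like. Everything here is hypothetical: `AtenStyleRMSNormArgs` and `to_aten_style_args` are invented names, and the local `NotYetSupportedError` only stands in for tico.utils.errors.NotYetSupportedError so the snippet runs on its own.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


class NotYetSupportedError(Exception):
    """Local stand-in for tico.utils.errors.NotYetSupportedError."""


@dataclass
class AtenStyleRMSNormArgs:
    """Hypothetical container for the arguments aten.rms_norm needs."""

    normalized_shape: List[int]
    eps: float


def to_aten_style_args(
    weight_shape: Optional[Sequence[int]], eps: Optional[float]
) -> AtenStyleRMSNormArgs:
    """Derive normalized_shape from the weight shape, keeping the restrictions."""
    if weight_shape is None:
        raise NotYetSupportedError("RMSNorm without weight is not supported")
    if len(weight_shape) != 1:
        # The Circle op normalizes over the last axis only.
        raise NotYetSupportedError("Only 1-D normalized_shape RMSNorm is supported")
    if eps is None:
        raise NotYetSupportedError("RMSNorm eps=None is not supported yet")
    return AtenStyleRMSNormArgs(normalized_shape=list(weight_shape), eps=eps)


args = to_aten_style_args(weight_shape=(64,), eps=1e-6)
print(args)
```

Keeping all three checks in one place would mean the HF adapter and the quant wrapper cannot silently loosen them during such a migration.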
For this PR, the following code looks sufficient.
# tico/serialize/operators/op_rmsnorm.py
import torch

from tico.serialize.circle_mapping import extract_shape
from tico.utils.errors import NotYetSupportedError
from tico.utils.validate_args_kwargs import CircleRMSNormArgs, RMSNormArgs


@register_node_visitor
class RMSNormVisitor(NodeVisitor):
    target = [
        torch.ops.circle_custom.rms_norm.default,
        torch.ops.aten.rms_norm.default,
    ]

    def _parse_args(self, node):
        if node.target == torch.ops.aten.rms_norm.default:
            args = RMSNormArgs(*node.args, **node.kwargs)
            if args.weight is None:
                raise NotYetSupportedError("RMSNorm without weight is not supported")
            if len(args.normalized_shape) != 1:
                raise NotYetSupportedError(
                    "Only 1-D normalized_shape RMSNorm is supported"
                )
            if list(extract_shape(args.weight)) != list(args.normalized_shape):
                raise NotYetSupportedError(
                    "RMSNorm weight shape should match normalized_shape"
                )
            eps = args.eps
            if eps is None:
                raise NotYetSupportedError("RMSNorm eps=None is not supported yet")
            return args.input, args.weight, eps
        args = CircleRMSNormArgs(*node.args, **node.kwargs)
        return args.input, args.weight, args.eps

Note: normalized_shape is dropped because circle_custom.rms_norm
infers the normalization dimension from the weight shape.
CircleRMSNorm computes RMS over the last axis only. Therefore, the following check should be added.
if len(args.normalized_shape) != 1:
    raise NotYetSupportedError(
        "Only 1-D normalized_shape RMSNorm is supported"
    )

if weight is None:
    continue

# Use default eps if not provided (PyTorch default is 1e-5)
Let's just disallow this case.
if eps is None:
    raise NotYetSupportedError("aten.rms_norm with eps=None is not supported yet")
Delivered with #766.

