[Bug][Relax][ONNX] LayerNormalization without bias synthesizes a wrong-shape zero tensor and fails #19691

@wuyii8941

Expected behavior

In the ONNX LayerNormalization spec, the bias input B is optional. When omitted, the operator should behave as if B is a tensor of zeros with the same shape as the scale W.

ONNX Runtime accepts the no-bias form fine.
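
As a point of reference, the expected semantics can be sketched in plain NumPy. This is a minimal illustration of the behavior described above, not ONNX Runtime's or TVM's implementation; the function name `layer_norm_reference` is hypothetical, and it assumes W and B cover the normalized trailing axes.

```python
import numpy as np

def layer_norm_reference(x, scale, bias=None, axis=-1, epsilon=1e-5):
    """Hypothetical NumPy reference: normalize over axes [axis, ..., rank-1]."""
    axis = axis % x.ndim                      # resolve negative axis
    axes = tuple(range(axis, x.ndim))         # trailing axes to normalize over
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    normalized = (x - mean) / np.sqrt(var + epsilon)
    if bias is None:
        # Omitted B behaves like zeros shaped like the scale W.
        bias = np.zeros(scale.shape, dtype=x.dtype)
    return normalized * scale + bias

x = np.random.default_rng(0).standard_normal((2, 3, 4, 8)).astype(np.float32)
w = np.ones((8,), dtype=np.float32)

y_no_bias = layer_norm_reference(x, w)
y_zero_bias = layer_norm_reference(x, w, np.zeros((8,), dtype=np.float32))
print(y_no_bias.shape, np.array_equal(y_no_bias, y_zero_bias))
```

Omitting the bias and passing an explicit zero bias of shape (8,) give the same result, which is what the frontend should reproduce.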

Actual behavior

The TVM frontend synthesizes a zero bias whose shape is [data.shape[1]] (i.e., the second dim of the input, unrelated to the normalization axes). This shape does not match W, and relax.nn.layer_norm then fails:

tvm.error.InternalError: Op(relax.nn.layer_norm) requires the input gamma,
beta, etc., to have size same as the lengths of the data on the given axes.
However, there exists [T.int64(8)] and [T.int64(3)] that are unequal.

Reproduction

import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
import onnxruntime as ort
from tvm.relax.frontend.onnx import from_onnx

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4, 8])
Y = onnx.ValueInfoProto(); Y.name = "Y"
W = numpy_helper.from_array(np.ones((8,), dtype=np.float32), "W")
# axis = -1 (default), no B input
node = helper.make_node("LayerNormalization", ["X", "W"], ["Y"], axis=-1, epsilon=1e-5)
g = helper.make_graph([node], "g", [X], [Y], initializer=[W])
m = helper.make_model(g, opset_imports=[helper.make_opsetid("", 17)])

x = np.random.randn(2, 3, 4, 8).astype(np.float32)
print("ORT shape:", ort.InferenceSession(m.SerializeToString()).run(None, {"X": x})[0].shape)
# ORT shape: (2, 3, 4, 8)

inf = onnx.shape_inference.infer_shapes(m)
mod = from_onnx(inf)  # InternalError: gamma/beta size mismatch

Root cause

python/tvm/relax/frontend/onnx/onnx_frontend.py, LayerNormalization._impl_v17:

gamma_shape = get_const_tuple(scale.struct_info.shape)

if bias is None:
    seq_len = data.struct_info.shape[1].value      # <-- wrong: uses data dim 1
    bias = relax.const([0.0] * seq_len, dtype="float32")
else:
    beta_shape = get_const_tuple(bias.struct_info.shape)
    if gamma_shape != beta_shape:
        raise ValueError("gamma and beta shapes do not match")

The synthesized bias should match gamma_shape, not data.shape[1]. With the example above, gamma_shape = (8,) but the bias is created with length data.shape[1] = 3.

A secondary issue: indexing data.struct_info.shape[1].value is fragile. It fails if that dimension is symbolic (so it has no .value) or if data has fewer than 2 dims.
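
The shape mismatch can be illustrated without TVM. The sketch below only mimics the shape arithmetic; `buggy_bias_shape` and `normalized_shape` are hypothetical helpers, not functions from the frontend.

```python
def buggy_bias_shape(data_shape):
    """Mimic the current behaviour: bias length is taken from data dim 1."""
    return (data_shape[1],)

def normalized_shape(data_shape, axis):
    """Shape of the trailing axes that LayerNormalization normalizes over."""
    axis = axis % len(data_shape)
    return tuple(data_shape[axis:])

# (data shape, axis) pairs; the first one is the reproduction above.
cases = [((2, 3, 4, 8), -1), ((2, 8), -1)]
results = [
    (shape, buggy_bias_shape(shape), normalized_shape(shape, axis))
    for shape, axis in cases
]
for shape, buggy, expected in results:
    print(shape, "buggy bias:", buggy, "expected:", expected)
```

For the reproduction input the buggy shape is (3,) while the normalized shape is (8,). For a rank-2 input with axis=-1 the two coincide, which may explain why the problem goes unnoticed on [N, C] inputs. A rank-1 input would raise IndexError in `buggy_bias_shape`.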

Suggested fix

if bias is None:
    bias = relax.const(np.zeros(gamma_shape, dtype="float32"))
else:
    beta_shape = get_const_tuple(bias.struct_info.shape)
    if gamma_shape != beta_shape:
        raise ValueError("gamma and beta shapes do not match")
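
One possible refinement, offered as an assumption rather than a verified requirement: the zero bias could also follow the scale's dtype instead of hard-coding float32, so that float16 models get a matching constant. The helper `make_zero_bias` below is hypothetical and uses only NumPy; in the frontend its result would presumably be wrapped with relax.const.

```python
import numpy as np

def make_zero_bias(gamma_shape, dtype="float32"):
    """Hypothetical helper: zero bias with the scale's shape and dtype."""
    # Convert each dimension to a plain int so integer-like dims are accepted.
    shape = tuple(int(dim) for dim in gamma_shape)
    return np.zeros(shape, dtype=dtype)

bias_f32 = make_zero_bias((8,))
bias_f16 = make_zero_bias((8,), dtype="float16")
print(bias_f32.shape, bias_f32.dtype, bias_f16.dtype)
```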

Impact

Any ONNX model that contains LayerNormalization without an explicit B fails to import unless data.shape[1] happens to match the shape of W. PyTorch's nn.LayerNorm(elementwise_affine=True, bias=False) (the bias argument was added in PyTorch 2.1 to control the bias separately) exports exactly this form, so transformer variants that drop the LayerNorm bias hit this path.

Environment

  • TVM: latest main (commit b172d5e)
  • Python: 3.11
  • ONNX Runtime: 1.24.4

cc @KJlaccHoeUM9l @junrushao

Labels: needs-triage, type: bug