Expected behavior
In the ONNX LayerNormalization spec, the bias input B is optional. When omitted, the operator should behave as if B is a tensor of zeros with the same shape as the scale W.
ONNX Runtime runs the no-bias form without error.
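As a reference for the expected semantics, here is a minimal pure-Python sketch of layer normalization over the last axis with an optional bias. The function name `layer_norm_last_axis` and its list-of-rows input are illustrative assumptions, not part of ONNX or TVM, and it covers only the `axis=-1` case with a 1-D scale.

```python
import math


def layer_norm_last_axis(rows, scale, bias=None, epsilon=1e-5):
    """Normalize each row over its last axis, then apply scale and bias.

    Hypothetical reference helper: `rows` is a list of equal-length lists,
    `scale` plays the role of W and `bias` the role of the optional B.
    """
    if bias is None:
        # An omitted bias behaves like zeros with the same shape as the scale.
        bias = [0.0] * len(scale)
    if len(bias) != len(scale):
        raise ValueError("scale and bias shapes do not match")

    out = []
    for row in rows:
        if len(row) != len(scale):
            raise ValueError("scale length must equal the normalized axis length")
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        inv_std = 1.0 / math.sqrt(var + epsilon)
        out.append([(v - mean) * inv_std * w + b
                    for v, w, b in zip(row, scale, bias)])
    return out


rows = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.5, 2.5, 0.0]]
scale = [1.0, 1.0, 1.0, 1.0]

without_bias = layer_norm_last_axis(rows, scale)
with_zero_bias = layer_norm_last_axis(rows, scale, bias=[0.0] * len(scale))
```

Under these assumptions, `without_bias` and `with_zero_bias` are identical, which is the equivalence the frontend is expected to preserve.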
Actual behavior
The TVM frontend synthesizes a zero bias whose shape is [data.shape[1]] (i.e., the second dim of the input, unrelated to the normalization axes). This shape does not match W, and relax.nn.layer_norm then fails:
tvm.error.InternalError: Op(relax.nn.layer_norm) requires the input gamma,
beta, etc., to have size same as the lengths of the data on the given axes.
However, there exists [T.int64(8)] and [T.int64(3)] that are unequal.
Reproduction
import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper
import onnxruntime as ort
from tvm.relax.frontend.onnx import from_onnx
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3, 4, 8])
Y = onnx.ValueInfoProto(); Y.name = "Y"
W = numpy_helper.from_array(np.ones((8,), dtype=np.float32), "W")
# axis = -1 (default), no B input
node = helper.make_node("LayerNormalization", ["X", "W"], ["Y"], axis=-1, epsilon=1e-5)
g = helper.make_graph([node], "g", [X], [Y], initializer=[W])
m = helper.make_model(g, opset_imports=[helper.make_opsetid("", 17)])
x = np.random.randn(2, 3, 4, 8).astype(np.float32)
print("ORT shape:", ort.InferenceSession(m.SerializeToString()).run(None, {"X": x})[0].shape)
# ORT shape: (2, 3, 4, 8)
inf = onnx.shape_inference.infer_shapes(m)
mod = from_onnx(inf) # InternalError: gamma/beta size mismatch
Root cause
python/tvm/relax/frontend/onnx/onnx_frontend.py, LayerNormalization._impl_v17:
gamma_shape = get_const_tuple(scale.struct_info.shape)
if bias is None:
    seq_len = data.struct_info.shape[1].value  # <-- wrong: uses data dim 1
    bias = relax.const([0.0] * seq_len, dtype="float32")
else:
    beta_shape = get_const_tuple(bias.struct_info.shape)
    if gamma_shape != beta_shape:
        raise ValueError("gamma and beta shapes do not match")
The synthesized bias should match gamma_shape, not data.shape[1]. With the example above, gamma_shape = (8,) but the bias is created with length data.shape[1] = 3.
A secondary issue: reading data.struct_info.shape[1].value is fragile. It fails if that dimension is symbolic (not an integer constant) or if data has fewer than 2 dims.
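The mismatch can be checked with plain shape arithmetic. This sketch assumes, in line with the error message above, that gamma and beta must match the lengths of the normalized axes, which run from `axis` to the last dimension. `normalized_shape` is a hypothetical helper, not a TVM function.

```python
def normalized_shape(data_shape, axis):
    """Return the dims covered by normalization: data_shape[axis:].

    Hypothetical helper for illustration; negative axes count from the end.
    """
    rank = len(data_shape)
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    start = axis % rank  # maps e.g. -1 to rank - 1
    return tuple(data_shape[start:])


data_shape = (2, 3, 4, 8)  # shape of X in the reproduction
axis = -1

expected_bias_shape = normalized_shape(data_shape, axis)  # what gamma covers
buggy_bias_shape = (data_shape[1],)  # what the frontend currently synthesizes

print("expected:", expected_bias_shape, "buggy:", buggy_bias_shape)
```

The two shapes agree only when `data_shape[1]` happens to equal the normalized length, for example a 3-D input of shape `(2, 8, 8)` with `axis=-1`, which may be why the bug can go unnoticed on some models.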
Suggested fix
if bias is None:
    bias = relax.const(np.zeros(gamma_shape, dtype="float32"))
else:
    beta_shape = get_const_tuple(bias.struct_info.shape)
    if gamma_shape != beta_shape:
        raise ValueError("gamma and beta shapes do not match")
Impact
Any ONNX model that exports LayerNormalization without an explicit B fails to import, unless the synthesized shape [data.shape[1]] happens to match W. PyTorch's nn.LayerNorm(elementwise_affine=True, bias=False) (the bias argument was added in PyTorch 2.1 to control the bias separately) exports exactly this form, so transformer variants that use a bias-free LayerNorm hit this path.
Environment
- TVM: latest main (commit b172d5e)
- Python: 3.11
- ONNX Runtime: 1.24.4
cc @KJlaccHoeUM9l @junrushao