[Bug][Relax] nn.dropout legalize is a no-op, causing opaque VM codegen crash even at rate=0 #19695

@wuyii8941

Description

Expected behavior

relax.op.nn.dropout is publicly exported and discoverable via tvm.relax.op.nn. Constructing an IR module that contains it and calling relax.build should either:

  1. Compile to a working module (at rate=0, semantically an identity), or
  2. Raise a clear, actionable error mentioning nn.dropout.

Actual behavior

The legalization function registered for relax.nn.dropout only logs an INFO message and returns the call unchanged:

@register_legalize("relax.nn.dropout")
def _nn_dropout(bb, call):
    logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
    return call

The unlegalized intrinsic then crashes VM codegen with:

tvm.error.InternalError: CodeGenVM cannot handle this intrinsic now:
Op(relax.nn.dropout)

There is no Dropout converter in the Relax ONNX frontend either, so the implicit assumption that "frontends will strip dropout" doesn't even hold for ONNX-imported models.
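The failure mode can be illustrated with a toy model that does not use TVM at all. Every name here (`LEGALIZERS`, `legalize`, `codegen`, the tuple-based call representation) is hypothetical and exists only for illustration; it is a sketch of the mechanism, not of TVM's actual pass infrastructure. A legalizer that hands back its input leaves a high-level op in the program, and the later codegen stage is the first place that notices.

```python
# Toy model of the failure mode. None of these names are TVM APIs.
LEGALIZERS = {}


def register_legalize(op_name):
    """Register a rewrite function for a high-level op name."""
    def decorator(fn):
        LEGALIZERS[op_name] = fn
        return fn
    return decorator


@register_legalize("nn.relu")
def _relu(call):
    # A working legalizer rewrites the op into a form codegen understands.
    return ("call_tir", "relu_kernel", call[1])


@register_legalize("nn.dropout")
def _dropout_noop(call):
    # Mirrors the reported behavior: the call is returned unchanged.
    return call


def legalize(calls):
    """Apply the registered legalizer to each call; unknown ops pass through."""
    return [LEGALIZERS.get(c[0], lambda x: x)(c) for c in calls]


def codegen(calls):
    """Accept only lowered calls; anything else is an unhandled intrinsic."""
    for c in calls:
        if c[0] != "call_tir":
            raise RuntimeError(f"codegen cannot handle this intrinsic: {c[0]}")
    return "ok"


program = [("nn.relu", ("x",)), ("nn.dropout", ("x",))]
try:
    codegen(legalize(program))
except RuntimeError as err:
    # The error surfaces in codegen, far from the no-op legalizer that caused it.
    print(err)
```

In this sketch the error message names the op but not the pass that failed to lower it, which is the same loss of context the report describes.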

Reproduction

import numpy as np
import tvm
from tvm import relax

arr = np.ones((4,), dtype=np.float32)
bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo(arr.shape, "float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        tup = bb.emit(relax.op.nn.dropout(x, rate=0.0))   # even rate=0 fails!
        y0 = bb.emit(relax.TupleGetItem(tup, 0))
        gv = bb.emit_output(y0)
    bb.emit_func_output(gv)

mod = bb.get()
mod = relax.transform.LegalizeOps()(mod)
relax.build(mod, target="llvm")
# InternalError: CodeGenVM cannot handle this intrinsic now: Op(relax.nn.dropout)

The ONNX path is also broken:

import onnx
from onnx import helper, TensorProto
from tvm.relax.frontend.onnx import from_onnx
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4])
Y = onnx.ValueInfoProto(); Y.name = "Y"
node = helper.make_node("Dropout", ["X"], ["Y"])    # opset 7 form
g = helper.make_graph([node], "g", [X], [Y])
m = helper.make_model(g, opset_imports=[helper.make_opsetid("", 7)])
mod = from_onnx(onnx.shape_inference.infer_shapes(m))
# OpNotImplemented: The following operators are not supported for frontend ONNX: Dropout

Root cause

Two places:

  1. python/tvm/relax/transform/legalize_ops/nn.py, where the registered legalizer is a no-op:

    @register_legalize("relax.nn.dropout")
    def _nn_dropout(bb, call):
        logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
        return call

  2. python/tvm/relax/frontend/onnx/onnx_frontend.py, which has no converter class for Dropout.

Suggested fix

Most models do not need any random sampling at inference time, so an inference-mode identity legalization is enough for the vast majority of compile flows:

@register_legalize("relax.nn.dropout")
def _nn_dropout(bb, call):
    # dropout returns a tuple (output, mask); in inference mode (rate=0 or training_mode=False)
    # output == input, mask == ones_like(input).
    rate = call.attrs.rate
    if isinstance(rate, (float, int)) and float(rate) == 0.0:
        x = call.args[0]
        mask = bb.call_te(topi.full_like, x, 1.0)
        return relax.Tuple([x, mask])
    raise tvm.error.OpNotImplemented(
        "nn.dropout with rate > 0 is not yet lowered; the frontend translator "
        "must strip dropout before reaching LegalizeOps."
    )

In addition, the ONNX frontend should gain a Dropout converter that, in inference mode (the only mode TVM compiles), returns the input unchanged.
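The inference-mode semantics that both suggestions rely on can be pinned down with a small pure-Python reference. This is a sketch only: `dropout_reference`, its `training` and `seed` parameters, and the choice of a 0/1 mask with the output scaled by 1/(1 - rate) in training mode are assumptions made for illustration, not TVM's or ONNX's implementation. The property that matters for this issue is the first branch: with rate 0 (or outside training), the output equals the input and the mask is all ones.

```python
import random


def dropout_reference(x, rate, training=False, seed=0):
    """Return (output, mask) for a flat list of floats.

    Hypothetical reference for illustration; not a TVM or ONNX API.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")

    # Inference mode, or nothing to drop: identity output and an all-ones mask.
    if not training or rate == 0.0:
        return list(x), [1.0] * len(x)

    # Training mode (assumed convention): keep each element with probability
    # 1 - rate and rescale the survivors so the expected value is unchanged.
    rng = random.Random(seed)
    keep = 1.0 - rate
    mask = [1.0 if rng.random() < keep else 0.0 for _ in x]
    output = [v * m / keep for v, m in zip(x, mask)]
    return output, mask


out, mask = dropout_reference([1.0, 2.0, 3.0, 4.0], rate=0.0)
print(out, mask)
```

A legalizer or frontend converter that covers only the first branch would already handle the rate=0 reproduction above.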

Environment

  • TVM: latest main (commit b172d5e)
  • Python: 3.11

cc @junrushao

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug
