Expected behavior
`relax.op.nn.dropout` is publicly exported and discoverable via `tvm.relax.op.nn`. Constructing an IR module that contains it and calling `relax.build` should either:
- Compile to a working module (at `rate=0`, semantically an identity), or
- Raise a clear, actionable error mentioning `nn.dropout`.
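The claim that `rate=0` is an identity follows from the usual inverted-dropout definition. The pure-Python sketch below is a reference model, not TVM code. `dropout_reference` is a hypothetical helper, and the sketch assumes that the mask is 1.0 for kept elements and 0.0 for dropped ones, and that kept elements are scaled by `1 / (1 - rate)` in training mode. In inference mode, or whenever `rate=0`, it returns the input unchanged plus an all-ones mask, which is the tuple a legalization would need to emit.

```python
import random
from typing import List, Optional, Sequence, Tuple


def dropout_reference(
    x: Sequence[float],
    rate: float,
    training: bool,
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], List[float]]:
    """Return (output, mask) for inverted dropout (hypothetical reference helper).

    Assumption: mask is 1.0 for kept elements and 0.0 for dropped ones, and
    kept elements are scaled by 1 / (1 - rate) during training.
    """
    if not 0.0 <= rate < 1.0:
        raise ValueError("rate must be in [0, 1)")
    if not training or rate == 0.0:
        # Inference mode, or nothing to drop: identity output, all-ones mask.
        return list(x), [1.0] * len(x)
    if rng is None:
        rng = random.Random()
    scale = 1.0 / (1.0 - rate)
    # Drop each element independently with probability `rate`.
    mask = [0.0 if rng.random() < rate else 1.0 for _ in x]
    out = [v * m * scale for v, m in zip(x, mask)]
    return out, mask
```

For example, `dropout_reference([1.0, 2.0], rate=0.0, training=True)` returns `([1.0, 2.0], [1.0, 1.0])`.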
Actual behavior
The legalization function registered for `relax.nn.dropout` only logs an `INFO` message and returns the call unchanged:

```python
@register_legalize("relax.nn.dropout")
def _nn_dropout(bb, call):
    logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
    return call
```
The unlegalized intrinsic then crashes VM codegen with:
```
tvm.error.InternalError: CodeGenVM cannot handle this intrinsic now:
Op(relax.nn.dropout)
```
There is no Dropout converter in the Relax ONNX frontend either, so the implicit assumption that "frontends will strip dropout" doesn't even hold for ONNX-imported models.
Reproduction
```python
import numpy as np
import tvm
from tvm import relax

arr = np.ones((4,), dtype=np.float32)
bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo(arr.shape, "float32"))
with bb.function("main", [x]):
    with bb.dataflow():
        tup = bb.emit(relax.op.nn.dropout(x, rate=0.0))  # even rate=0 fails!
        y0 = bb.emit(relax.TupleGetItem(tup, 0))
        gv = bb.emit_output(y0)
    bb.emit_func_output(gv)
mod = bb.get()
mod = relax.transform.LegalizeOps()(mod)
relax.build(mod, target="llvm")
# InternalError: CodeGenVM cannot handle this intrinsic now: Op(relax.nn.dropout)
```
ONNX path is also broken:
```python
import onnx
from onnx import helper, TensorProto
from tvm.relax.frontend.onnx import from_onnx

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4])
Y = onnx.ValueInfoProto()
Y.name = "Y"
node = helper.make_node("Dropout", ["X"], ["Y"])  # opset 7 form
g = helper.make_graph([node], "g", [X], [Y])
m = helper.make_model(g, opset_imports=[helper.make_opsetid("", 7)])
mod = from_onnx(onnx.shape_inference.infer_shapes(m))
# OpNotImplemented: The following operators are not supported for frontend ONNX: Dropout
```
Root cause
Two places:
- `python/tvm/relax/transform/legalize_ops/nn.py`:

```python
@register_legalize("relax.nn.dropout")
def _nn_dropout(bb, call):
    logging.info("Dropout is handled by frontend translator at this moment and is not legalized.")
    return call
```

- `python/tvm/relax/frontend/onnx/onnx_frontend.py`: there is no converter class for `Dropout`.
Suggested fix
Most compile flows target inference, where dropout needs no random sampling, so an inference-mode identity legalization covers the vast majority of cases:
```python
import tvm
from tvm import relax, topi
# Inside legalize_ops/nn.py this comes from `from .common import register_legalize`.
from tvm.relax.transform.legalize_ops.common import register_legalize


@register_legalize("relax.nn.dropout")
def _nn_dropout(bb, call):
    # dropout returns a tuple (output, mask). At rate=0 it is an identity:
    # output == input and mask == ones_like(input).
    rate = call.attrs.rate
    if isinstance(rate, (float, int)) and float(rate) == 0.0:
        x = call.args[0]
        mask = bb.call_te(topi.full_like, x, 1.0)
        return relax.Tuple([x, mask])
    raise tvm.error.OpNotImplemented(
        "nn.dropout with rate > 0 is not yet lowered; the frontend translator "
        "must strip dropout before reaching LegalizeOps."
    )
```
And add a Dropout converter to the ONNX frontend that, in inference mode (the only mode TVM compiles), returns the input unchanged.
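One detail such a converter has to respect: per the ONNX operator spec, Dropout in inference mode produces `Y` as a simple copy of `X`, but a node may also declare the optional second output `mask`. Returning only the input is enough when the graph uses just `Y` (as in the reproduction above); when a node has two outputs, the converter must also produce a mask, which the ONNX reference implementation fills with all `True` (for opset 10 and later, where the mask is boolean). The pure-Python model below sketches that contract with plain lists in place of tensors. `onnx_dropout_inference` is a hypothetical helper for illustration, not the TVM converter API.

```python
from typing import List, Sequence, Tuple, Union


def onnx_dropout_inference(
    x: Sequence[float], num_outputs: int
) -> Union[Tuple[List[float]], Tuple[List[float], List[bool]]]:
    """Model ONNX Dropout in inference mode (hypothetical helper).

    Y is an unchanged copy of X; if the optional `mask` output is requested,
    it is all True (assumes opset >= 10, where mask is boolean).
    """
    y = list(x)  # identity copy of the input
    if num_outputs == 1:
        return (y,)
    if num_outputs == 2:
        return (y, [True] * len(y))
    raise ValueError("ONNX Dropout has one or two outputs, got %d" % num_outputs)
```

Keeping the mask case explicit avoids a converter that silently returns a single value for a two-output node.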
Environment
- TVM: latest `main` (commit b172d5e)
- Python: 3.11
cc @junrushao