Skip to content

Commit 4de1f11

Browse files
authored
[Relax] Add conv3d_transpose and ONNX ConvTranspose 3D support (#18948)
Introduce relax.nn.conv3d_transpose (attrs, C++ inference/layout, Python API) and lower it to TOPI group_conv3d_transpose_ncdhw when using NCDHW/IODHW with dilation 1, matching the conv2d_transpose legalization policy. Wire the Relax ONNX frontend to emit conv3d_transpose for 5D inputs. Extend tests for ONNX, struct info, LegalizeOps, and TVMScript round-trip; fix ConvTranspose test output spatial size to include output_padding. Fixes #18945.
1 parent 52b5d55 commit 4de1f11

12 files changed

Lines changed: 687 additions & 7 deletions

File tree

include/tvm/relax/attrs/nn.h

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,58 @@ struct Conv2DTransposeAttrs : public AttrsNodeReflAdapter<Conv2DTransposeAttrs>
267267
BaseAttrsNode);
268268
}; // struct Conv2DTransposeAttrs
269269

270+
/*! \brief Attributes used in Conv3dTranspose operator */
271+
struct Conv3DTransposeAttrs : public AttrsNodeReflAdapter<Conv3DTransposeAttrs> {
272+
ffi::Array<int64_t> strides;
273+
ffi::Array<int64_t> padding;
274+
ffi::Array<int64_t> output_padding;
275+
ffi::Array<int64_t> dilation;
276+
int groups;
277+
ffi::String data_layout;
278+
ffi::String kernel_layout;
279+
ffi::String out_layout;
280+
DataType out_dtype;
281+
282+
static void RegisterReflection() {
283+
namespace refl = tvm::ffi::reflection;
284+
refl::ObjectDef<Conv3DTransposeAttrs>()
285+
.def_ro("strides", &Conv3DTransposeAttrs::strides,
286+
"Specifies the strides of the convolution.")
287+
.def_ro("padding", &Conv3DTransposeAttrs::padding,
288+
"If padding is non-zero, then the input is implicitly zero-padded"
289+
"Padding support both symmetric and asymmetric as"
290+
"one int : same padding used on all sides"
291+
"three int : back/bottom/right will use same padding as front/top/left"
292+
"six int : padding width in the order of (front, top, left, back, bottom, right)")
293+
.def_ro("output_padding", &Conv3DTransposeAttrs::output_padding,
294+
"Used to disambiguate the output shape.")
295+
.def_ro("dilation", &Conv3DTransposeAttrs::dilation,
296+
"Specifies the dilation rate to use for dilated convolution.")
297+
.def_ro("groups", &Conv3DTransposeAttrs::groups,
298+
"Number of groups to split the input into for grouped convolution. The number of "
299+
"input and "
300+
"output channels should be divisible by the number of groups.")
301+
.def_ro("data_layout", &Conv3DTransposeAttrs::data_layout,
302+
"Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
303+
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
304+
"dimensions respectively. Convolution is applied on the 'D', 'H', and"
305+
"'W' dimensions.")
306+
.def_ro("kernel_layout", &Conv3DTransposeAttrs::kernel_layout,
307+
"Dimension ordering of weight. Can be 'IODHW', etc."
308+
"'I', 'O', 'D', 'H', 'W' stands for input_channel, output_channel, depth, height, and "
309+
"width"
310+
"dimensions respectively.")
311+
.def_ro("out_layout", &Conv3DTransposeAttrs::out_layout,
312+
"Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
313+
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
314+
"dimensions respectively. Default to be same as input layout.")
315+
.def_ro("out_dtype", &Conv3DTransposeAttrs::out_dtype,
316+
"Output data type, set to explicit type under mixed precision setting");
317+
}
318+
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Conv3DTransposeAttrs", Conv3DTransposeAttrs,
319+
BaseAttrsNode);
320+
}; // struct Conv3DTransposeAttrs
321+
270322
/*! \brief Attributes used in max_pool1d and avg_pool1d operator */
271323
struct Pool1DAttrs : public AttrsNodeReflAdapter<Pool1DAttrs> {
272324
ffi::Array<int64_t> pool_size;

python/tvm/relax/frontend/onnx/onnx_frontend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1364,7 +1364,9 @@ def _impl_v1(cls, bb, inputs, attr, params):
13641364
data_layout = "NCHW"
13651365
kernel_layout = "IOHW"
13661366
elif ndim == 5:
1367-
raise NotImplementedError("Relax ConvTranspose3d not supported yet")
1367+
op = relax.op.nn.conv3d_transpose
1368+
data_layout = "NCDHW"
1369+
kernel_layout = "IODHW"
13681370
else:
13691371
raise NotImplementedError("Ndim > 5 not supported for convolution.")
13701372

python/tvm/relax/op/nn/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
conv2d,
3535
conv2d_transpose,
3636
conv3d,
37+
conv3d_transpose,
3738
cross_entropy_with_logits,
3839
dropout,
3940
gelu,

python/tvm/relax/op/nn/nn.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,10 @@ def conv3d(
293293
out_dtype : Optional[Union[str, DataType]]
294294
Specifies the output data type for mixed precision conv2d.
295295
296+
See Also
297+
--------
298+
conv3d_transpose : Transposed 3D convolution; paired layouts default to ``NCDHW`` / ``IODHW``.
299+
296300
Returns
297301
-------
298302
result : relax.Expr
@@ -512,6 +516,108 @@ def conv2d_transpose(
512516
)
513517

514518

519+
def conv3d_transpose(
    data: Expr,
    weight: Expr,
    strides: int | tuple[int, int, int] = (1, 1, 1),
    padding: int | tuple[int, ...] = (0, 0, 0),
    output_padding: int | tuple[int, int, int] = (0, 0, 0),
    dilation: int | tuple[int, int, int] = (1, 1, 1),
    groups: int = 1,
    data_layout: str = "NCDHW",
    kernel_layout: str = "IODHW",
    out_layout: str | None = None,
    out_dtype: str | DataType | None = None,
) -> Expr:
    r"""Three dimensional transposed convolution operator.

    Conceptually this is the gradient operator of conv3d: given

    `out = conv3d(data, weight, strides, padding, dilation)`,

    the gradient with respect to ``data`` is

    `data_grad = conv3d_transpose(out_grad, weight, strides, padding, output_padding, dilation)`,

    with `output_padding` resolving the ambiguity in the output shape.

    In the default case (`data_layout == "NCDHW"`, `kernel_layout == "IODHW"`), `data`
    has shape `(N, in_channel, in_d, in_h, in_w)` and `weight` has shape
    `(in_channel, out_channel, weight_d, weight_h, weight_w)`, where
    `in_channel % groups == 0`. The output has shape
    `(N, out_channel * groups, out_d, out_h, out_w)`.

    Parameters
    ----------
    data : relax.Expr
        The input data to the operator.

    weight : relax.Expr
        The weight expressions.

    strides : Union[int, Tuple[int, int, int]]
        The strides of convolution. It is required to have length either 1 or 3.

    padding : Union[int, Tuple[int, ...]]
        The padding of convolution on both sides of inputs before convolution.
        It is required to have length either 1, 3 or 6.

    output_padding : Union[int, Tuple[int, ...]], optional
        Used to disambiguate the output shape.

    dilation : Union[int, Tuple[int, int, int]]
        Specifies the dilation rate to be used for dilated convolution.
        It is required to have length either 1 or 3.

    groups : int
        Number of groups to split the input into for grouped convolution.
        The number of input and output channels should be divisible by the number of groups.

    data_layout : str
        Layout of the input.

    kernel_layout : str
        Layout of the weight.

    out_layout : Optional[str]
        Layout of the output. If not specified, it is the same as data_layout

    out_dtype : Optional[Union[str, DataType]]
        Specifies the output data type for mixed precision conv3d_transpose.

    See Also
    --------
    conv3d : Forward 3D convolution (default ``OIDHW`` weights vs. ``IODHW`` here).
    conv2d_transpose : 2D analogue; legalization supports the same TOPI subset (canonical layout, dilation 1).

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    # Broadcast scalar arguments to one value per spatial axis
    # (padding covers all six sides: front, top, left, back, bottom, right).
    if isinstance(strides, int):
        strides = (strides,) * 3
    if isinstance(dilation, int):
        dilation = (dilation,) * 3
    if isinstance(padding, int):
        padding = (padding,) * 6
    if isinstance(output_padding, int):
        output_padding = (output_padding,) * 3

    return _ffi_api.conv3d_transpose(  # type: ignore
        data,
        weight,
        strides,
        padding,
        output_padding,
        dilation,
        groups,
        data_layout,
        kernel_layout,
        out_layout,
        out_dtype,
    )
619+
620+
515621
def pad(
516622
data: Expr,
517623
pad_width: list[int] | tuple[int, ...],

python/tvm/relax/op/op_attrs.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ class Conv2DTransposeAttrs(Attrs):
7171
"""Attributes for nn.conv2d_transpose"""
7272

7373

74+
# Python-side handle for the C++ Conv3DTransposeAttrs node; the registration key must
# match the type key declared by TVM_FFI_DECLARE_OBJECT_INFO_FINAL on the C++ side.
@tvm_ffi.register_object("relax.attrs.Conv3DTransposeAttrs")
class Conv3DTransposeAttrs(Attrs):
    """Attributes for nn.conv3d_transpose"""
77+
78+
7479
@tvm_ffi.register_object("relax.attrs.Pool2DAttrs")
7580
class Pool2DAttrs(Attrs):
7681
"""Attributes for nn.max_pool2d"""

python/tvm/relax/transform/legalize_ops/nn.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
200200
)
201201
return call
202202
dilation = call.attrs.dilation
203-
if len(dilation) != 2 or dilation[0] != 1 or dilation[1] != 1:
203+
if len(dilation) != 2 or any(d != 1 for d in dilation):
204204
logging.info(
205205
"TOPI conv2d_transpose does not support dilations other than 1, "
206206
"and thus cannot be legalized by TOPI"
@@ -220,6 +220,42 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
220220
)
221221

222222

223+
@register_legalize("relax.nn.conv3d_transpose")
def _nn_conv3d_transpose(bb: BlockBuilder, call: Call) -> Expr:
    # Mirrors the _nn_conv2d_transpose policy: legalize only the subset TOPI
    # supports (matching in/out layouts, canonical NCDHW/IODHW, dilation 1);
    # otherwise return the call unchanged so it stays un-legalized.
    attrs = call.attrs

    if attrs.out_layout != attrs.data_layout:
        logging.info(
            "TOPI conv3d_transpose does not support different input-output "
            "layouts, and thus cannot be legalized by TOPI"
        )
        return call

    if attrs.data_layout != "NCDHW" or attrs.kernel_layout != "IODHW":
        logging.info(
            "TOPI conv3d_transpose does not support input layout other than NCDHW, "
            "and kernel layout other than IODHW, so cannot be legalized by TOPI"
        )
        return call

    dilation = attrs.dilation
    if len(dilation) != 3 or any(d != 1 for d in dilation):
        logging.info(
            "TOPI conv3d_transpose does not support dilations other than 1, "
            "and thus cannot be legalized by TOPI"
        )
        return call

    data, weight = call.args[0], call.args[1]
    return bb.call_te(
        topi.nn.group_conv3d_transpose_ncdhw,
        data,
        weight,
        strides=attrs.strides,
        padding=attrs.padding,
        out_dtype=call.struct_info.dtype,
        output_padding=attrs.output_padding,
        groups=attrs.groups,
        primfunc_name_hint="conv3d_transpose",
    )
257+
258+
223259
@register_legalize("relax.nn.pad")
224260
def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
225261
pad_mode = call.attrs.pad_mode

0 commit comments

Comments
 (0)