@@ -293,6 +293,10 @@ def conv3d(
293293 out_dtype : Optional[Union[str, DataType]]
294294 Specifies the output data type for mixed precision conv2d.
295295
296+ See Also
297+ --------
298+ conv3d_transpose : Transposed 3D convolution; paired layouts default to ``NCDHW`` / ``IODHW``.
299+
296300 Returns
297301 -------
298302 result : relax.Expr
@@ -512,6 +516,108 @@ def conv2d_transpose(
512516 )
513517
514518
def conv3d_transpose(
    data: Expr,
    weight: Expr,
    strides: int | tuple[int, int, int] = (1, 1, 1),
    padding: int | tuple[int, ...] = (0, 0, 0),
    output_padding: int | tuple[int, int, int] = (0, 0, 0),
    dilation: int | tuple[int, int, int] = (1, 1, 1),
    groups: int = 1,
    data_layout: str = "NCDHW",
    kernel_layout: str = "IODHW",
    out_layout: str | None = None,
    out_dtype: str | DataType | None = None,
) -> Expr:
    r"""Three dimensional transposed convolution operator.

    This operator acts as the gradient of conv3d with respect to its data
    input: given

    `out = conv3d(data, weight, strides, padding, dilation)`,

    the data gradient is obtained via

    `data_grad = conv3d_transpose(out_grad, weight, strides, padding, output_padding, dilation)`,

    with `output_padding` disambiguating among the output shapes that map to
    the same forward-convolution input shape.

    In the default case, where `data_layout == "NCDHW"` and `kernel_layout == "IODHW"`, `data` has
    shape `(N, in_channel, in_d, in_h, in_w)`, `weight` has shape
    `(in_channel, out_channel, weight_d, weight_h, weight_w)`, with `in_channel % groups == 0`.
    The output shape is `(N, out_channel * groups, out_d, out_h, out_w)`.

    Parameters
    ----------
    data : relax.Expr
        The input data to the operator.

    weight : relax.Expr
        The weight expressions.

    strides : Union[int, Tuple[int, int, int]]
        The strides of convolution. It is required to have length either 1 or 3.

    padding : Union[int, Tuple[int, ...]]
        The padding of convolution on both sides of inputs before convolution.
        It is required to have length either 1, 3 or 6.

    output_padding : Union[int, Tuple[int, ...]], optional
        Used to disambiguate the output shape.

    dilation : Union[int, Tuple[int, int, int]]
        Specifies the dilation rate to be used for dilated convolution.
        It is required to have length either 1 or 3.

    groups : int
        Number of groups to split the input into for grouped convolution.
        The number of input and output channels should be divisible by the number of groups.

    data_layout : str
        Layout of the input.

    kernel_layout : str
        Layout of the weight.

    out_layout : Optional[str]
        Layout of the output. If not specified, it is the same as data_layout

    out_dtype : Optional[Union[str, DataType]]
        Specifies the output data type for mixed precision conv3d_transpose.

    See Also
    --------
    conv3d : Forward 3D convolution (default ``OIDHW`` weights vs. ``IODHW`` here).
    conv2d_transpose : The two dimensional analogue of this operator.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    # Broadcast scalar arguments to per-dimension tuples before crossing
    # into the C++ FFI, which expects explicit per-axis values.
    if isinstance(strides, int):
        strides = (strides,) * 3
    if isinstance(dilation, int):
        dilation = (dilation,) * 3
    if isinstance(padding, int):
        # A single pad value applies to both sides of all three spatial axes.
        padding = (padding,) * 6
    if isinstance(output_padding, int):
        output_padding = (output_padding,) * 3

    return _ffi_api.conv3d_transpose(  # type: ignore
        data,
        weight,
        strides,
        padding,
        output_padding,
        dilation,
        groups,
        data_layout,
        kernel_layout,
        out_layout,
        out_dtype,
    )
619+
620+
515621def pad (
516622 data : Expr ,
517623 pad_width : list [int ] | tuple [int , ...],
0 commit comments