|
13 | 13 | from torch.nn import Module, Parameter, init |
14 | 14 | from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d |
15 | 15 | from torch.nn import ConvTranspose2d |
16 | | -from complexFunctions import complex_relu, complex_max_pool2d |
| 16 | +from complexFunctions import complex_relu, complex_max_pool2d, complex_avg_pool2d |
17 | 17 | from complexFunctions import complex_dropout, complex_dropout2d |
18 | 18 |
|
19 | 19 | def apply_complex(fr, fi, input): |
@@ -59,6 +59,25 @@ def forward(self,input): |
59 | 59 | stride = self.stride, padding = self.padding, |
60 | 60 | dilation = self.dilation, ceil_mode = self.ceil_mode, |
61 | 61 | return_indices = self.return_indices) |
| 62 | + |
| 63 | + |
| 64 | +class ComplexAvgPool2d(Module): |
| 65 | + |
| 66 | + def __init__(self,kernel_size, stride= None, padding = 0, |
| 67 | + dilation = 1, return_indices = False, ceil_mode = False): |
| 68 | + super(ComplexAvgPool2d,self).__init__() |
| 69 | + self.kernel_size = kernel_size |
| 70 | + self.stride = stride |
| 71 | + self.padding = padding |
| 72 | + self.dilation = dilation |
| 73 | + self.ceil_mode = ceil_mode |
| 74 | + self.return_indices = return_indices |
| 75 | + |
| 76 | + def forward(self,input): |
| 77 | + return complex_avg_pool2d(input,kernel_size = self.kernel_size, |
| 78 | + stride = self.stride, padding = self.padding, |
| 79 | + dilation = self.dilation, ceil_mode = self.ceil_mode, |
| 80 | + return_indices = self.return_indices) |
62 | 81 |
|
63 | 82 | class ComplexReLU(Module): |
64 | 83 |
|
|
0 commit comments