Skip to content

Commit 8ccdd2d

Browse files
author
Your Name
committed
Add ComplexAvgPool2d
1 parent 22137ab commit 8ccdd2d

1 file changed

Lines changed: 20 additions & 1 deletion

File tree

complexLayers.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch.nn import Module, Parameter, init
1414
from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d
1515
from torch.nn import ConvTranspose2d
16-
from complexFunctions import complex_relu, complex_max_pool2d
16+
from complexFunctions import complex_relu, complex_max_pool2d, complex_avg_pool2d
1717
from complexFunctions import complex_dropout, complex_dropout2d
1818

1919
def apply_complex(fr, fi, input):
@@ -59,6 +59,25 @@ def forward(self,input):
5959
stride = self.stride, padding = self.padding,
6060
dilation = self.dilation, ceil_mode = self.ceil_mode,
6161
return_indices = self.return_indices)
62+
63+
64+
class ComplexAvgPool2d(Module):
65+
66+
def __init__(self,kernel_size, stride= None, padding = 0,
67+
dilation = 1, return_indices = False, ceil_mode = False):
68+
super(ComplexAvgPool2d,self).__init__()
69+
self.kernel_size = kernel_size
70+
self.stride = stride
71+
self.padding = padding
72+
self.dilation = dilation
73+
self.ceil_mode = ceil_mode
74+
self.return_indices = return_indices
75+
76+
def forward(self,input):
77+
return complex_avg_pool2d(input,kernel_size = self.kernel_size,
78+
stride = self.stride, padding = self.padding,
79+
dilation = self.dilation, ceil_mode = self.ceil_mode,
80+
return_indices = self.return_indices)
6281

6382
class ComplexReLU(Module):
6483

0 commit comments

Comments
 (0)