Skip to content

Commit dcc8db4

Browse files
author
Your Name
committed
resolve conflicts
1 parent 3a64068 commit dcc8db4

5 files changed

Lines changed: 160 additions & 3 deletions

File tree

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
## main
2+
3+
### Added
4+
5+
* GRU Cell and BN-GRU Cell
6+
17
## 0.4
28

39
### Fixed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@ Following [[C. Trabelsi et al., International Conference on Learning Representat
2323
* Conv2d
2424
* MaxPool2d
2525
* Relu (ℂRelu)
26+
* Sigmoid
27+
* Tanh
2628
* BatchNorm1d (Naive and Covariance approach)
2729
* BatchNorm2d (Naive and Covariance approach)
30+
* GRU/BN-GRU Cell
2831

2932
## Citing the code
3033

complexPyTorch/complexFunctions.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
@author: spopoff
66
"""
77

8-
from torch.nn.functional import relu, max_pool2d, avg_pool2d, dropout, dropout2d, interpolate
8+
from torch.nn.functional import relu, max_pool2d, avg_pool2d, dropout, dropout2d, interpolate, sigmoid, tanh
99
import torch
1010

1111
def complex_matmul(A, B):
@@ -40,6 +40,20 @@ def complex_normalize(input):
4040
def complex_relu(input):
4141
return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64)
4242

43+
def complex_sigmoid(input):
    """Split sigmoid for complex tensors.

    Applies the real-valued sigmoid independently to the real and the
    imaginary part and recombines them:
    ``sigmoid(Re(z)) + 1j * sigmoid(Im(z))``.

    Parameters
    ----------
    input : complex tensor

    Returns
    -------
    torch.complex64 tensor of the same shape as ``input``.
    """
    # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
    return torch.sigmoid(input.real).type(torch.complex64) \
        + 1j * torch.sigmoid(input.imag).type(torch.complex64)
45+
46+
def complex_tanh(input):
    """Split tanh for complex tensors.

    Applies the real-valued tanh independently to the real and the
    imaginary part and recombines them:
    ``tanh(Re(z)) + 1j * tanh(Im(z))``.

    Parameters
    ----------
    input : complex tensor

    Returns
    -------
    torch.complex64 tensor of the same shape as ``input``.
    """
    # torch.tanh replaces the deprecated torch.nn.functional.tanh.
    return torch.tanh(input.real).type(torch.complex64) \
        + 1j * torch.tanh(input.imag).type(torch.complex64)
48+
49+
def complex_opposite(input):
    """Return the element-wise negation of a complex tensor as complex64.

    Equivalent to ``-input`` for complex input: both the real and the
    imaginary parts are negated.
    """
    neg_real = -(input.real).type(torch.complex64)
    neg_imag = -(input.imag).type(torch.complex64)
    return neg_real + 1j * neg_imag
51+
52+
def complex_stack(input, dim):
    """Stack a sequence of complex tensors along dimension ``dim``.

    The real and imaginary parts are stacked separately and recombined,
    yielding a torch.complex64 result.
    """
    stacked_real = torch.stack([t.real for t in input], dim).type(torch.complex64)
    stacked_imag = torch.stack([t.imag for t in input], dim).type(torch.complex64)
    return stacked_real + 1j * stacked_imag
56+
4357
def _retrieve_elements_from_indices(tensor, indices):
4458
flattened_tensor = tensor.flatten(start_dim=-2)
4559
output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-2)).view_as(indices)
@@ -113,3 +127,5 @@ def complex_dropout2d(input, p=0.5, training=True):
113127
mask = dropout2d(mask, p, training)*1/(1-p)
114128
mask.type(input.dtype)
115129
return mask*input
130+
131+

complexPyTorch/complexLayers.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ class ComplexReLU(Module):
8383

8484
def forward(self,input):
8585
return complex_relu(input)
86+
87+
class ComplexSigmoid(Module):
    """Module wrapper around the functional ``complex_sigmoid``.

    The sigmoid is applied independently to the real and imaginary parts.
    """

    def forward(self, input):
        # Stateless: simply delegate to the functional form.
        return complex_sigmoid(input)
91+
92+
class ComplexTanh(Module):
    """Module wrapper around the functional ``complex_tanh``.

    The tanh is applied independently to the real and imaginary parts.
    """

    def forward(self, input):
        # Stateless: simply delegate to the functional form.
        return complex_tanh(input)
8696

8797
class ComplexConvTranspose2d(Module):
8898

@@ -342,3 +352,125 @@ def forward(self, input):
342352

343353
del Crr, Cri, Cii, Rrr, Rii, Rri, det, s, t
344354
return input
355+
356+
# class complexGruCell(Module):
357+
358+
# def __init__(self):
359+
360+
class ComplexGruCell(Module):
    """A GRU cell for complex-valued inputs.

    All linear maps are ComplexLinear layers and the gate activations are
    the split (real/imaginary) ComplexSigmoid / ComplexTanh modules.
    """

    def __init__(self, input_length=10, hidden_length=20):
        super(ComplexGruCell, self).__init__()
        self.input_length = input_length
        self.hidden_length = hidden_length

        # Reset-gate parameters.
        self.linear_reset_w1 = ComplexLinear(self.input_length, self.hidden_length)
        self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length)

        # NOTE(review): despite the "reset" prefix, w2/r2 feed the update
        # gate in update_gate() below. Names kept as-is so existing
        # checkpoints / state dicts stay loadable.
        self.linear_reset_w2 = ComplexLinear(self.input_length, self.hidden_length)
        self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length)

        # Candidate-state parameters.
        self.linear_gate_w3 = ComplexLinear(self.input_length, self.hidden_length)
        self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length)

        self.activation_gate = ComplexSigmoid()
        self.activation_candidate = ComplexTanh()

    def reset_gate(self, x, h):
        # r = sigmoid(W1 x + R1 h)
        return self.activation_gate(self.linear_reset_w1(x) + self.linear_reset_r1(h))

    def update_gate(self, x, h):
        # z = sigmoid(R2 h + W2 x)
        return self.activation_gate(self.linear_reset_r2(h) + self.linear_reset_w2(x))

    def update_component(self, x, h, r):
        # n = tanh(W3 x + r * R3 h) — the reset gate scales the hidden path
        # element-wise before the candidate activation.
        gated_hidden = r * self.linear_gate_r3(h)
        return self.activation_candidate(self.linear_gate_w3(x) + gated_hidden)

    def forward(self, x, h):
        r = self.reset_gate(x, h)           # Eq. 1: reset gate vector
        z = self.update_gate(x, h)          # Eq. 2: update gate vector
        n = self.update_component(x, h, r)  # Eq. 3: candidate hidden state

        # Eq. 4: interpolate between candidate and previous hidden state.
        # complex_opposite(z) == -z, so this is (1 - z) * n + z * h.
        return (1 + complex_opposite(z)) * n + z * h
417+
418+
class ComplexBNGruCell(Module):
    """
    A BN-GRU cell for complex-valued inputs
    """
    # Variant of ComplexGruCell in which gate pre-activations pass through
    # a single shared complex batch-norm module before the non-linearity.

    def __init__(self, input_length=10, hidden_length=20):
        # input_length:  size of the input feature vector x
        # hidden_length: size of the hidden state h
        super(ComplexBNGruCell, self).__init__()
        self.input_length = input_length
        self.hidden_length = hidden_length

        # reset gate components
        self.linear_reset_w1 = ComplexLinear(self.input_length, self.hidden_length)
        self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length)

        # NOTE(review): the "reset" prefix is misleading — w2/r2 actually
        # feed the update gate in update_gate() below.
        self.linear_reset_w2 = ComplexLinear(self.input_length, self.hidden_length)
        self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length)

        # update gate components
        self.linear_gate_w3 = ComplexLinear(self.input_length, self.hidden_length)
        self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length)

        self.activation_gate = ComplexSigmoid()
        self.activation_candidate = ComplexTanh()

        # One ComplexBatchNorm2d(1) shared by every gate; in training mode
        # its running statistics are updated several times per forward
        # pass. NOTE(review): a 2d batch-norm after a linear layer assumes
        # a 4-d (N, 1, H, W)-shaped activation — confirm the expected
        # input shape with callers.
        self.bn = ComplexBatchNorm2d(1)

    def reset_gate(self, x, h):
        # r = sigmoid(BN(W1 x) + BN(R1 h))
        x_1 = self.linear_reset_w1(x)
        h_1 = self.linear_reset_r1(h)
        # gate update
        reset = self.activation_gate(self.bn(x_1) + self.bn(h_1))
        return reset

    def update_gate(self, x, h):
        # z = sigmoid(BN(R2 h) + BN(W2 x))
        x_2 = self.linear_reset_w2(x)
        h_2 = self.linear_reset_r2(h)
        z = self.activation_gate(self.bn(h_2) + self.bn(x_2))
        return z

    def update_component(self, x, h, r):
        # n = tanh(BN(BN(W3 x) + r * BN(R3 h)))
        x_3 = self.linear_gate_w3(x)
        h_3 = r * self.bn(self.linear_gate_r3(h))  # element-wise multiplication
        # NOTE(review): bn is applied twice on this path (inner on x_3,
        # outer on the sum) — possibly intentional for a BN-GRU, possibly
        # a copy-paste slip; verify against the reference formulation.
        gate_update = self.activation_candidate(self.bn(self.bn(x_3) + h_3))
        return gate_update

    def forward(self, x, h):
        # Equation 1. reset gate vector
        r = self.reset_gate(x, h)

        # Equation 2: the update gate - the shared update gate vector z
        z = self.update_gate(x, h)

        # Equation 3: The almost output component
        n = self.update_component(x, h, r)

        # Equation 4: the new hidden state
        # complex_opposite(z) == -z, so this is (1 - z) * n + z * h.
        h_new = (1 + complex_opposite(z)) * n + z * h  # element-wise multiplication

        return h_new

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup, find_packages
22

33
setup(name='complexPyTorch',
4-
version='0.3',
4+
version='0.4.1',
55
description='A high-level toolbox for using complex valued neural networks in PyTorch.',
66
long_description=open('README.md').read().strip(),
77
long_description_content_type='text/markdown',
@@ -14,4 +14,4 @@
1414
license='MIT License',
1515
zip_safe=False,
1616
keywords='pytorch, deep learning, complex values',
17-
classifiers=[''])
17+
classifiers=[''])

0 commit comments

Comments
 (0)