Skip to content

Commit b091a31

Browse files
author
Sebastien M. Popoff
committed
Merge branch 'master' of https://github.com/octaveguinebretiere/complexPyTorch into octaveguinebretiere-master
2 parents 3a64068 + d00e50c commit b091a31

2 files changed

Lines changed: 280 additions & 0 deletions

File tree

complexPyTorch/complexFunctions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,23 @@ def complex_normalize(input):
4040
def complex_relu(input):
4141
return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64)
4242

43+
def complex_relu(input):
44+
return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64)
45+
46+
def complex_sigmoid(input):
47+
return sigmoid(input.real).type(torch.complex64)+1j*sigmoid(input.imag).type(torch.complex64)
48+
49+
def complex_tanh(input):
50+
return tanh(input.real).type(torch.complex64)+1j*tanh(input.imag).type(torch.complex64)
51+
52+
def complex_opposite(input):
53+
return -(input.real).type(torch.complex64)+1j*(-(input.imag).type(torch.complex64))
54+
55+
def complex_stack(input, dim):
56+
input_real = [x.real for x in input]
57+
input_imag = [x.imag for x in input]
58+
return torch.stack(input_real, dim).type(torch.complex64)+1j*torch.stack(input_imag, dim).type(torch.complex64)
59+
4360
def _retrieve_elements_from_indices(tensor, indices):
4461
flattened_tensor = tensor.flatten(start_dim=-2)
4562
output = flattened_tensor.gather(dim=-1, index=indices.flatten(start_dim=-2)).view_as(indices)

complexPyTorch/complexLayers.py

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,21 @@
1313
from torch.nn import Module, Parameter, init
1414
from torch.nn import Conv2d, Linear, BatchNorm1d, BatchNorm2d
1515
from torch.nn import ConvTranspose2d
16+
<<<<<<< HEAD:complexPyTorch/complexLayers.py
1617
from .complexFunctions import complex_relu, complex_max_pool2d, complex_avg_pool2d
1718
from .complexFunctions import complex_dropout, complex_dropout2d
19+
=======
20+
from complexFunctions import (
21+
complex_relu,
22+
complex_tanh,
23+
complex_sigmoid,
24+
complex_max_pool2d,
25+
complex_avg_pool2d,
26+
complex_dropout,
27+
complex_dropout2d,
28+
complex_opposite,
29+
)
30+
>>>>>>> d00e50c5d93386dd85b363a39c24cf783b7914aa:complexLayers.py
1831

1932
def apply_complex(fr, fi, input, dtype = torch.complex64):
2033
return (fr(input.real)-fi(input.imag)).type(dtype) \
@@ -83,6 +96,16 @@ class ComplexReLU(Module):
8396

8497
def forward(self,input):
8598
return complex_relu(input)
99+
100+
class ComplexSigmoid(Module):
101+
102+
def forward(self,input):
103+
return complex_sigmoid(input)
104+
105+
class ComplexTanh(Module):
106+
107+
def forward(self,input):
108+
return complex_tanh(input)
86109

87110
class ComplexConvTranspose2d(Module):
88111

@@ -342,3 +365,243 @@ def forward(self, input):
342365

343366
del Crr, Cri, Cii, Rrr, Rii, Rri, det, s, t
344367
return input
368+
369+
# class complexGruCell(Module):
370+
371+
# def __init__(self):
372+
373+
class ComplexGruCell(Module):
374+
"""
375+
A GRU cell for complex-valued inputs
376+
"""
377+
378+
def __init__(self, input_length=10, hidden_length=20):
379+
super(ComplexGruCell, self).__init__()
380+
self.input_length = input_length
381+
self.hidden_length = hidden_length
382+
383+
# reset gate components
384+
self.linear_reset_w1 = ComplexLinear(self.input_length, self.hidden_length)
385+
self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length)
386+
387+
self.linear_reset_w2 = ComplexLinear(self.input_length, self.hidden_length)
388+
self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length)
389+
390+
# update gate components
391+
self.linear_gate_w3 = ComplexLinear(self.input_length, self.hidden_length)
392+
self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length)
393+
394+
self.activation_gate = ComplexSigmoid()
395+
self.activation_candidate = ComplexTanh()
396+
397+
def reset_gate(self, x, h):
398+
x_1 = self.linear_reset_w1(x)
399+
h_1 = self.linear_reset_r1(h)
400+
# gate update
401+
reset = self.activation_gate(x_1 + h_1)
402+
return reset
403+
404+
def update_gate(self, x, h):
405+
x_2 = self.linear_reset_w2(x)
406+
h_2 = self.linear_reset_r2(h)
407+
z = self.activation_gate(h_2 + x_2)
408+
return z
409+
410+
def update_component(self, x, h, r):
411+
x_3 = self.linear_gate_w3(x)
412+
h_3 = r * self.linear_gate_r3(h) # element-wise multiplication
413+
gate_update = self.activation_candidate(x_3 + h_3)
414+
return gate_update
415+
416+
def forward(self, x, h):
417+
# Equation 1. reset gate vector
418+
r = self.reset_gate(x, h)
419+
420+
# Equation 2: the update gate - the shared update gate vector z
421+
z = self.update_gate(x, h)
422+
423+
# Equation 3: The almost output component
424+
n = self.update_component(x, h, r)
425+
426+
# Equation 4: the new hidden state
427+
h_new = (1 + complex_opposite(z)) * n + z * h # element-wise multiplication
428+
429+
return h_new
430+
431+
class ComplexBNGruCell(Module):
432+
"""
433+
A BN-GRU cell for complex-valued inputs
434+
"""
435+
436+
def __init__(self, input_length=10, hidden_length=20):
437+
super(ComplexBNGruCell, self).__init__()
438+
self.input_length = input_length
439+
self.hidden_length = hidden_length
440+
441+
# reset gate components
442+
self.linear_reset_w1 = ComplexLinear(self.input_length, self.hidden_length)
443+
self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length)
444+
445+
self.linear_reset_w2 = ComplexLinear(self.input_length, self.hidden_length)
446+
self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length)
447+
448+
# update gate components
449+
self.linear_gate_w3 = ComplexLinear(self.input_length, self.hidden_length)
450+
self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length)
451+
452+
self.activation_gate = ComplexSigmoid()
453+
self.activation_candidate = ComplexTanh()
454+
455+
self.bn = ComplexBatchNorm2d(1)
456+
457+
def reset_gate(self, x, h):
458+
x_1 = self.linear_reset_w1(x)
459+
h_1 = self.linear_reset_r1(h)
460+
# gate update
461+
reset = self.activation_gate(self.bn(x_1) + self.bn(h_1))
462+
return reset
463+
464+
def update_gate(self, x, h):
465+
x_2 = self.linear_reset_w2(x)
466+
h_2 = self.linear_reset_r2(h)
467+
z = self.activation_gate(self.bn(h_2) + self.bn(x_2))
468+
return z
469+
470+
def update_component(self, x, h, r):
471+
x_3 = self.linear_gate_w3(x)
472+
h_3 = r * self.bn(self.linear_gate_r3(h)) # element-wise multiplication
473+
gate_update = self.activation_candidate(self.bn(self.bn(x_3) + h_3))
474+
return gate_update
475+
476+
def forward(self, x, h):
477+
# Equation 1. reset gate vector
478+
r = self.reset_gate(x, h)
479+
480+
# Equation 2: the update gate - the shared update gate vector z
481+
z = self.update_gate(x, h)
482+
483+
# Equation 3: The almost output component
484+
n = self.update_component(x, h, r)
485+
486+
# Equation 4: the new hidden state
487+
h_new = (1 + complex_opposite(z)) * n + z * h # element-wise multiplication
488+
489+
return h_new
490+
491+
class ComplexGRUCell(Module):
492+
"""
493+
A GRU cell for complex-valued inputs
494+
"""
495+
496+
def __init__(self, input_length=10, hidden_length=20):
497+
super(ComplexGRUCell, self).__init__()
498+
self.input_length = input_length
499+
self.hidden_length = hidden_length
500+
501+
# reset gate components
502+
self.linear_reset_w1 = ComplexLinear(self.input_length, self.hidden_length)
503+
self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length)
504+
505+
self.linear_reset_w2 = ComplexLinear(self.input_length, self.hidden_length)
506+
self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length)
507+
508+
# update gate components
509+
self.linear_gate_w3 = ComplexLinear(self.input_length, self.hidden_length)
510+
self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length)
511+
512+
self.activation_gate = ComplexSigmoid()
513+
self.activation_candidate = ComplexTanh()
514+
515+
def reset_gate(self, x, h):
516+
x_1 = self.linear_reset_w1(x)
517+
h_1 = self.linear_reset_r1(h)
518+
# gate update
519+
reset = self.activation_gate(x_1 + h_1)
520+
return reset
521+
522+
def update_gate(self, x, h):
523+
x_2 = self.linear_reset_w2(x)
524+
h_2 = self.linear_reset_r2(h)
525+
z = self.activation_gate(h_2 + x_2)
526+
return z
527+
528+
def update_component(self, x, h, r):
529+
x_3 = self.linear_gate_w3(x)
530+
h_3 = r * self.linear_gate_r3(h) # element-wise multiplication
531+
gate_update = self.activation_candidate(x_3 + h_3)
532+
return gate_update
533+
534+
def forward(self, x, h):
535+
# Equation 1. reset gate vector
536+
r = self.reset_gate(x, h)
537+
538+
# Equation 2: the update gate - the shared update gate vector z
539+
z = self.update_gate(x, h)
540+
541+
# Equation 3: The almost output component
542+
n = self.update_component(x, h, r)
543+
544+
# Equation 4: the new hidden state
545+
h_new = (1 + complex_opposite(z)) * n + z * h # element-wise multiplication
546+
547+
return h_new
548+
549+
class ComplexBNGRUCell(Module):
550+
"""
551+
A BN-GRU cell for complex-valued inputs
552+
"""
553+
554+
def __init__(self, input_length=10, hidden_length=20):
555+
super(ComplexBNGRUCell, self).__init__()
556+
self.input_length = input_length
557+
self.hidden_length = hidden_length
558+
559+
# reset gate components
560+
self.linear_reset_w1 = ComplexLinear(self.input_length, self.hidden_length)
561+
self.linear_reset_r1 = ComplexLinear(self.hidden_length, self.hidden_length)
562+
563+
self.linear_reset_w2 = ComplexLinear(self.input_length, self.hidden_length)
564+
self.linear_reset_r2 = ComplexLinear(self.hidden_length, self.hidden_length)
565+
566+
# update gate components
567+
self.linear_gate_w3 = ComplexLinear(self.input_length, self.hidden_length)
568+
self.linear_gate_r3 = ComplexLinear(self.hidden_length, self.hidden_length)
569+
570+
self.activation_gate = ComplexSigmoid()
571+
self.activation_candidate = ComplexTanh()
572+
573+
self.bn = ComplexBatchNorm2d(1)
574+
575+
def reset_gate(self, x, h):
576+
x_1 = self.linear_reset_w1(x)
577+
h_1 = self.linear_reset_r1(h)
578+
# gate update
579+
reset = self.activation_gate(self.bn(x_1) + self.bn(h_1))
580+
return reset
581+
582+
def update_gate(self, x, h):
583+
x_2 = self.linear_reset_w2(x)
584+
h_2 = self.linear_reset_r2(h)
585+
z = self.activation_gate(self.bn(h_2) + self.bn(x_2))
586+
return z
587+
588+
def update_component(self, x, h, r):
589+
x_3 = self.linear_gate_w3(x)
590+
h_3 = r * self.bn(self.linear_gate_r3(h)) # element-wise multiplication
591+
gate_update = self.activation_candidate(self.bn(self.bn(x_3) + h_3))
592+
return gate_update
593+
594+
def forward(self, x, h):
595+
# Equation 1. reset gate vector
596+
r = self.reset_gate(x, h)
597+
598+
# Equation 2: the update gate - the shared update gate vector z
599+
z = self.update_gate(x, h)
600+
601+
# Equation 3: The almost output component
602+
n = self.update_component(x, h, r)
603+
604+
# Equation 4: the new hidden state
605+
h_new = (1 + complex_opposite(z)) * n + z * h # element-wise multiplication
606+
607+
return h_new

0 commit comments

Comments
 (0)