Skip to content

Commit d54e793

Browse files
author
Your Name
committed
add support for complex64 tensors
1 parent 6ed32b0 commit d54e793

4 files changed

Lines changed: 237 additions & 103 deletions

File tree

Example.ipynb

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import torch\n",
10+
"import torch.nn as nn\n",
11+
"import torch.nn.functional as F\n",
12+
"from torchvision import datasets, transforms\n",
13+
"from complexLayers import ComplexBatchNorm2d, ComplexConv2d, ComplexLinear, NaiveComplexBatchNorm2d\n",
14+
"from complexFunctions import complex_relu, complex_max_pool2d"
15+
]
16+
},
17+
{
18+
"cell_type": "code",
19+
"execution_count": 2,
20+
"metadata": {},
21+
"outputs": [],
22+
"source": [
23+
"batch_size = 64\n",
24+
"trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])\n",
25+
"train_set = datasets.MNIST('../data', train=True, transform=trans, download=True)\n",
26+
"test_set = datasets.MNIST('../data', train=False, transform=trans, download=True)\n",
27+
"\n",
28+
"train_loader = torch.utils.data.DataLoader(train_set, batch_size= batch_size, shuffle=True)\n",
29+
"test_loader = torch.utils.data.DataLoader(test_set, batch_size= batch_size, shuffle=True)"
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": 8,
35+
"metadata": {},
36+
"outputs": [],
37+
"source": [
38+
"class ComplexNet(nn.Module):\n",
39+
" \n",
40+
" def __init__(self):\n",
41+
" super(ComplexNet, self).__init__()\n",
42+
" self.conv1 = ComplexConv2d(1, 10, 5, 1)\n",
43+
" self.bn = ComplexBatchNorm2d(10)\n",
44+
" self.conv2 = ComplexConv2d(10, 20, 5, 1)\n",
45+
" self.fc1 = ComplexLinear(4*4*20, 500)\n",
46+
" self.fc2 = ComplexLinear(500, 10)\n",
47+
" \n",
48+
" def forward(self,x):\n",
49+
" x = self.conv1(x)\n",
50+
" x = complex_relu(x)\n",
51+
" x = complex_max_pool2d(x, 2, 2)\n",
52+
" x = self.bn(x)\n",
53+
" x = self.conv2(x)\n",
54+
" x = complex_relu(x)\n",
55+
" x = complex_max_pool2d(x, 2, 2)\n",
56+
" x = x.view(-1,4*4*20)\n",
57+
" x = self.fc1(x)\n",
58+
" x = complex_relu(x)\n",
59+
" x = self.fc2(x)\n",
60+
" x = x.abs()\n",
61+
" x = F.log_softmax(x, dim=1)\n",
62+
" return x\n",
63+
" \n",
64+
"device = device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
65+
"model = ComplexNet().to(device)\n",
66+
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
67+
"\n",
68+
"def train(model, device, train_loader, optimizer, epoch):\n",
69+
" model.train()\n",
70+
" for batch_idx, (data, target) in enumerate(train_loader):\n",
71+
" data, target =data.to(device).type(torch.complex64), target.to(device)\n",
72+
" optimizer.zero_grad()\n",
73+
" output = model(data)\n",
74+
" loss = F.nll_loss(output, target)\n",
75+
" loss.backward()\n",
76+
" optimizer.step()\n",
77+
" if batch_idx % 100 == 0:\n",
78+
" print('Train Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\\tLoss: {:.6f}'.format(\n",
79+
" epoch,\n",
80+
" batch_idx * len(data), \n",
81+
" len(train_loader.dataset),\n",
82+
" 100. * batch_idx / len(train_loader), \n",
83+
" loss.item())\n",
84+
" )"
85+
]
86+
},
87+
{
88+
"cell_type": "code",
89+
"execution_count": null,
90+
"metadata": {},
91+
"outputs": [
92+
{
93+
"name": "stdout",
94+
"output_type": "stream",
95+
"text": [
96+
"Train Epoch: 0 [ 0/ 60000 ( 0%)]\tLoss: 2.349018\n",
97+
"Train Epoch: 0 [ 6400/ 60000 ( 11%)]\tLoss: 0.252006\n",
98+
"Train Epoch: 0 [ 12800/ 60000 ( 21%)]\tLoss: 0.094634\n",
99+
"Train Epoch: 0 [ 19200/ 60000 ( 32%)]\tLoss: 0.096171\n",
100+
"Train Epoch: 0 [ 25600/ 60000 ( 43%)]\tLoss: 0.039067\n",
101+
"Train Epoch: 0 [ 32000/ 60000 ( 53%)]\tLoss: 0.062306\n",
102+
"Train Epoch: 0 [ 38400/ 60000 ( 64%)]\tLoss: 0.091644\n",
103+
"Train Epoch: 0 [ 44800/ 60000 ( 75%)]\tLoss: 0.154324\n",
104+
"Train Epoch: 0 [ 51200/ 60000 ( 85%)]\tLoss: 0.015835\n",
105+
"Train Epoch: 0 [ 57600/ 60000 ( 96%)]\tLoss: 0.005899\n",
106+
"Train Epoch: 1 [ 0/ 60000 ( 0%)]\tLoss: 0.013530\n",
107+
"Train Epoch: 1 [ 6400/ 60000 ( 11%)]\tLoss: 0.031689\n",
108+
"Train Epoch: 1 [ 12800/ 60000 ( 21%)]\tLoss: 0.025631\n",
109+
"Train Epoch: 1 [ 19200/ 60000 ( 32%)]\tLoss: 0.031679\n",
110+
"Train Epoch: 1 [ 25600/ 60000 ( 43%)]\tLoss: 0.021937\n",
111+
"Train Epoch: 1 [ 32000/ 60000 ( 53%)]\tLoss: 0.095149\n",
112+
"Train Epoch: 1 [ 38400/ 60000 ( 64%)]\tLoss: 0.008647\n",
113+
"Train Epoch: 1 [ 44800/ 60000 ( 75%)]\tLoss: 0.088300\n",
114+
"Train Epoch: 1 [ 51200/ 60000 ( 85%)]\tLoss: 0.003999\n",
115+
"Train Epoch: 1 [ 57600/ 60000 ( 96%)]\tLoss: 0.004459\n",
116+
"Train Epoch: 2 [ 0/ 60000 ( 0%)]\tLoss: 0.003121\n",
117+
"Train Epoch: 2 [ 6400/ 60000 ( 11%)]\tLoss: 0.003100\n",
118+
"Train Epoch: 2 [ 12800/ 60000 ( 21%)]\tLoss: 0.001305\n",
119+
"Train Epoch: 2 [ 19200/ 60000 ( 32%)]\tLoss: 0.017995\n"
120+
]
121+
}
122+
],
123+
"source": [
124+
"# Run training on 4 epochs\n",
125+
"for epoch in range(4):\n",
126+
" train(model, device, train_loader, optimizer, epoch)"
127+
]
128+
}
129+
],
130+
"metadata": {
131+
"kernelspec": {
132+
"display_name": "Python 3",
133+
"language": "python",
134+
"name": "python3"
135+
},
136+
"language_info": {
137+
"codemirror_mode": {
138+
"name": "ipython",
139+
"version": 3
140+
},
141+
"file_extension": ".py",
142+
"mimetype": "text/x-python",
143+
"name": "python",
144+
"nbconvert_exporter": "python",
145+
"pygments_lexer": "ipython3",
146+
"version": "3.8.0"
147+
},
148+
"toc": {
149+
"base_numbering": 1,
150+
"nav_menu": {},
151+
"number_sections": true,
152+
"sideBar": true,
153+
"skip_h1_title": true,
154+
"title_cell": "Table of Contents",
155+
"title_sidebar": "Contents",
156+
"toc_cell": false,
157+
"toc_position": {},
158+
"toc_section_display": true,
159+
"toc_window_display": false
160+
}
161+
},
162+
"nbformat": 4,
163+
"nbformat_minor": 4
164+
}

README.md

100644100755
Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
A high-level toolbox for using complex valued neural networks in PyTorch.
44

5+
Before version 1.7 of PyTorch, complex tensors were not supported.
6+
The initial version of **complexPyTorch** represented a complex tensor using two tensors, one for the real and one for the imaginary part.
7+
Since version 1.7, complex tensors of type `torch.complex64` are allowed, but only a limited number of operations are supported.
8+
The current version of **complexPyTorch** uses complex tensors (hence requires PyTorch version >= 1.7) and adds support for various operations and layers.
9+
510
## Complex Valued Networks with PyTorch
611

712
Artificial neural networks are mainly used for treating data encoded in real values, such as digitized images or sounds.
@@ -17,7 +22,6 @@ Following [[C. Trabelsi et al., International Conference on Learning Representat
1722
* BatchNorm2d (Naive and Covariance approach)
1823

1924

20-
2125
## Syntax and usage
2226

2327
The syntax is supposed to copy the one of the standard real functions and modules from PyTorch.
@@ -58,49 +62,42 @@ class ComplexNet(nn.Module):
5862

5963
def __init__(self):
6064
super(ComplexNet, self).__init__()
61-
self.conv1 = ComplexConv2d(1, 20, 5, 1)
62-
self.bn = ComplexBatchNorm2d(20)
63-
self.conv2 = ComplexConv2d(20, 50, 5, 1)
64-
self.fc1 = ComplexLinear(4*4*50, 500)
65+
self.conv1 = ComplexConv2d(1, 10, 5, 1)
66+
self.bn = ComplexBatchNorm2d(10)
67+
self.conv2 = ComplexConv2d(10, 20, 5, 1)
68+
self.fc1 = ComplexLinear(4*4*20, 500)
6569
self.fc2 = ComplexLinear(500, 10)
6670

6771
def forward(self,x):
68-
xr = x
69-
# imaginary part to zero
70-
xi = torch.zeros(xr.shape, dtype = xr.dtype, device = xr.device)
71-
xr,xi = self.conv1(xr,xi)
72-
xr,xi = complex_relu(xr,xi)
73-
xr,xi = complex_max_pool2d(xr,xi, 2, 2)
74-
75-
76-
xr,xi = self.bn(xr,xi)
77-
xr,xi = self.conv2(xr,xi)
78-
xr,xi = complex_relu(xr,xi)
79-
xr,xi = complex_max_pool2d(xr,xi, 2, 2)
80-
81-
xr = xr.view(-1, 4*4*50)
82-
xi = xi.view(-1, 4*4*50)
83-
xr,xi = self.fc1(xr,xi)
84-
xr,xi = complex_relu(xr,xi)
85-
xr,xi = self.fc2(xr,xi)
86-
# take the absolute value as output
87-
x = torch.sqrt(torch.pow(xr,2)+torch.pow(xi,2))
88-
return F.log_softmax(x, dim=1)
72+
x = self.conv1(x)
73+
x = complex_relu(x)
74+
x = complex_max_pool2d(x, 2, 2)
75+
x = self.bn(x)
76+
x = self.conv2(x)
77+
x = complex_relu(x)
78+
x = complex_max_pool2d(x, 2, 2)
79+
x = x.view(-1,4*4*20)
80+
x = self.fc1(x)
81+
x = complex_relu(x)
82+
x = self.fc2(x)
83+
x = x.abs()
84+
x = F.log_softmax(x, dim=1)
85+
return x
8986

90-
device = torch.device("cuda:0" )
87+
device = device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
9188
model = ComplexNet().to(device)
9289
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
9390

9491
def train(model, device, train_loader, optimizer, epoch):
9592
model.train()
9693
for batch_idx, (data, target) in enumerate(train_loader):
97-
data, target = data.to(device), target.to(device)
94+
data, target = data.to(device).type(torch.complex64), target.to(device)
9895
optimizer.zero_grad()
9996
output = model(data)
10097
loss = F.nll_loss(output, target)
10198
loss.backward()
10299
optimizer.step()
103-
if batch_idx % 1000 == 0:
100+
if batch_idx % 100 == 0:
104101
print('Train Epoch: {:3} [{:6}/{:6} ({:3.0f}%)]\tLoss: {:.6f}'.format(
105102
epoch,
106103
batch_idx * len(data),
@@ -113,11 +110,7 @@ def train(model, device, train_loader, optimizer, epoch):
113110
for epoch in range(50):
114111
train(model, device, train_loader, optimizer, epoch)
115112
```
116-
117-
## Todo
118-
* Script ComplexBatchNorm for improved efficiency ([jit doc](https://pytorch.org/docs/stable/jit.html))
119-
* Add more layers (Conv1D, Upsample, ConvTranspose...)
120-
* Add complex cost functions and usual functions (e.g. Pearson correlation)
113+
121114

122115
## Acknowledgments
123116

complexFunctions.py

100644100755
Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,24 @@
66
"""
77

88
from torch.nn.functional import relu, max_pool2d, dropout, dropout2d
9+
import torch
910

10-
def complex_relu(input_r,input_i):
11-
return relu(input_r), relu(input_i)
11+
def complex_relu(input):
12+
return relu(input.real).type(torch.complex64)+1j*relu(input.imag).type(torch.complex64)
1213

13-
def complex_max_pool2d(input_r,input_i,kernel_size, stride=None, padding=0,
14+
def complex_max_pool2d(input,kernel_size, stride=None, padding=0,
1415
dilation=1, ceil_mode=False, return_indices=False):
1516

16-
return max_pool2d(input_r, kernel_size, stride, padding, dilation,
17-
ceil_mode, return_indices), \
18-
max_pool2d(input_i, kernel_size, stride, padding, dilation,
19-
ceil_mode, return_indices)
17+
return max_pool2d(input.real, kernel_size, stride, padding, dilation,
18+
ceil_mode, return_indices).type(torch.complex64) \
19+
+ 1j*max_pool2d(input.imag, kernel_size, stride, padding, dilation,
20+
ceil_mode, return_indices).type(torch.complex64)
2021

2122
def complex_dropout(input_r,input_i, p=0.5, training=True, inplace=False):
22-
return dropout(input_r, p, training, inplace), \
23-
dropout(input_i, p, training, inplace)
23+
return dropout(input_r, p, training, inplace).type(torch.complex64) \
24+
+1j*dropout(input_i, p, training, inplace).type(torch.complex64)
2425

2526

2627
def complex_dropout2d(input_r,input_i, p=0.5, training=True, inplace=False):
27-
return dropout2d(input_r, p, training, inplace), \
28-
dropout2d(input_i, p, training, inplace)
28+
return dropout2d(input_r, p, training, inplace).type(torch.complex64) \
29+
+1j*dropout2d(input_i, p, training, inplace).type(torch.complex64)

0 commit comments

Comments
 (0)