Skip to content

Commit be9eeed

Browse files
committed
Merge branch 'develop' into feature/eigen-device
2 parents db95288 + 6366ff4 commit be9eeed

7 files changed

Lines changed: 1164 additions & 6 deletions

File tree

examples/mnist/mnist_cnn.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
#!/usr/bin/env python3
2+
3+
# Python example of Convolutional Neural Network.
4+
# Please refer to the primitiv repository for more details.
5+
#
6+
# Usage:
7+
# $ ./download_data.sh
8+
# $ python3 ./mnist_cnn.py
9+
10+
import random
11+
12+
import numpy as np
13+
14+
from primitiv import functions as F
15+
from primitiv import initializers as I
16+
from primitiv import optimizers as O
17+
from primitiv import devices as D
18+
from primitiv import Device, Graph, Parameter, Shape
19+
20+
# Dataset geometry: the standard MNIST split sizes.
NUM_TRAIN_SAMPLES = 60000
NUM_TEST_SAMPLES = 10000
BATCH_SIZE = 200
NUM_TRAIN_BATCHES = NUM_TRAIN_SAMPLES // BATCH_SIZE
NUM_TEST_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE
MAX_EPOCH = 100

# MNIST images are 28x28 grayscale.
IMAGE_HEIGHT = 28
IMAGE_WIDTH = 28

# Convolution layer hyperparameters.
KERNEL_SIZE1 = 5  # should be an odd number
KERNEL_SIZE2 = 5  # ditto
NUM_CHANNELS1 = 8
NUM_CHANNELS2 = 16
# Half-kernel padding keeps the spatial size unchanged through each conv.
PADDING1 = KERNEL_SIZE1 // 2
PADDING2 = KERNEL_SIZE2 // 2

# Two 2x2 max-poolings shrink each spatial dimension by a factor of 4, so
# the flattened feature size entering the FC layers is (H/4)*(W/4)*C2.
NUM_INPUT_UNITS = (IMAGE_HEIGHT // 4) * (IMAGE_WIDTH // 4) * NUM_CHANNELS2
NUM_HIDDEN_UNITS = 256
NUM_OUTPUT_UNITS = 10
40+
41+
42+
def load_images(filename, n):
    """Load MNIST images from an IDX (idx3-ubyte) file.

    :param filename: path to an idx3-ubyte image file.
    :param n: number of images to read.
    :return: float32 array of shape (n, IMAGE_HEIGHT, IMAGE_WIDTH) with
        pixel values scaled into [0, 1].
    """
    with open(filename, "rb") as ifs:
        ifs.seek(16)  # skip the IDX header (magic, count, rows, cols)
        # NOTE: this previously read n*NUM_INPUT_UNITS bytes, which only
        # worked because NUM_INPUT_UNITS happens to equal
        # IMAGE_HEIGHT*IMAGE_WIDTH (both 784); use the image size explicitly.
        pixels = np.fromfile(
            ifs, dtype=np.uint8, count=n * IMAGE_HEIGHT * IMAGE_WIDTH)
        return (pixels / 255) \
            .astype(np.float32) \
            .reshape((n, IMAGE_HEIGHT, IMAGE_WIDTH))
48+
49+
50+
def load_labels(filename, n):
    """Load MNIST labels from an IDX (idx1-ubyte) file.

    :param filename: path to an idx1-ubyte label file.
    :param n: number of labels to read.
    :return: uint32 array of length n.
    """
    with open(filename, "rb") as ifs:
        ifs.seek(8)  # skip the IDX header (magic, count)
        raw = np.fromfile(ifs, dtype=np.uint8, count=n)
    return raw.astype(np.uint32)
55+
56+
def main():
    """Train a small CNN on MNIST and print test accuracy after each epoch.

    Expects the four MNIST IDX files under ./data (see download_data.sh).
    Uses the first CUDA device. Returns 0 on completion.
    """
    # Loads data.
    train_inputs = load_images("data/train-images-idx3-ubyte", NUM_TRAIN_SAMPLES)
    train_labels = load_labels("data/train-labels-idx1-ubyte", NUM_TRAIN_SAMPLES)
    test_inputs = load_images("data/t10k-images-idx3-ubyte", NUM_TEST_SAMPLES)
    test_labels = load_labels("data/t10k-labels-idx1-ubyte", NUM_TEST_SAMPLES)

    # All subsequent graphs and parameters live on this device/graph.
    dev = D.CUDA(0)
    Device.set_default(dev)
    g = Graph()
    Graph.set_default(g)

    # Parameters of CNNs.
    # Shape: {kernel_height, kernel_width, in_channels, out_channels}
    pw_cnn1 = Parameter(
        Shape([KERNEL_SIZE1, KERNEL_SIZE1, 1, NUM_CHANNELS1]),
        I.XavierUniformConv2D())
    pw_cnn2 = Parameter(
        Shape([KERNEL_SIZE2, KERNEL_SIZE2, NUM_CHANNELS1, NUM_CHANNELS2]),
        I.XavierUniformConv2D())

    # Parameters of FC layers.
    pw_fc1 = Parameter(Shape([NUM_HIDDEN_UNITS, NUM_INPUT_UNITS]), I.XavierUniform())
    pw_fc2 = Parameter(Shape([NUM_OUTPUT_UNITS, NUM_HIDDEN_UNITS]), I.XavierUniform())
    pb_fc1 = Parameter(Shape([NUM_HIDDEN_UNITS]), I.Constant(0))
    pb_fc2 = Parameter(Shape([NUM_OUTPUT_UNITS]), I.Constant(0))

    # Optimizer: plain SGD over all parameters.
    optimizer = O.SGD(.1)
    optimizer.add(pw_cnn1, pw_cnn2, pw_fc1, pw_fc2, pb_fc1, pb_fc2)

    # Helper function to construct the predictor network.
    def make_graph(inputs, train):
        """Build the forward graph; `train` toggles dropout."""
        # Input and parameters.
        x = F.input(inputs)
        w_cnn1 = F.parameter(pw_cnn1)
        w_cnn2 = F.parameter(pw_cnn2)
        w_fc1 = F.parameter(pw_fc1)
        w_fc2 = F.parameter(pw_fc2)
        b_fc1 = F.parameter(pb_fc1)
        b_fc2 = F.parameter(pb_fc2)
        # CNNs: conv -> ReLU -> 2x2 max-pooling, twice.
        h_cnn1 = F.relu(F.conv2d(x, w_cnn1, PADDING1, PADDING1, 1, 1, 1, 1))
        h_pool1 = F.max_pool2d(h_cnn1, 2, 2, 0, 0, 2, 2)
        h_cnn2 = F.relu(F.conv2d(h_pool1, w_cnn2, PADDING2, PADDING2, 1, 1, 1, 1))
        h_pool2 = F.max_pool2d(h_cnn2, 2, 2, 0, 0, 2, 2)
        # FC layers, with dropout on each FC input while training.
        x_fc = F.dropout(F.flatten(h_pool2), .5, train)
        h_fc = F.dropout(
            F.relu(F.matmul(w_fc1, x_fc) + b_fc1), .5, train)
        return F.matmul(w_fc2, h_fc) + b_fc2

    # Batch randomizer.
    ids = list(range(NUM_TRAIN_SAMPLES))

    for epoch in range(MAX_EPOCH):
        # Shuffles sample IDs.
        random.shuffle(ids)

        # Training loop
        for batch in range(NUM_TRAIN_BATCHES):
            print("\rTraining... %d / %d" % (batch + 1, NUM_TRAIN_BATCHES), end="")
            # Makes a minibatch for training.
            inputs = [train_inputs[ids[batch * BATCH_SIZE + i]] for i in range(BATCH_SIZE)]
            labels = [train_labels[ids[batch * BATCH_SIZE + i]] for i in range(BATCH_SIZE)]

            # Constructs the graph.
            g.clear()
            y = make_graph(inputs, True)
            loss = F.softmax_cross_entropy(y, labels, 0)
            avg_loss = F.batch.mean(loss)

            # Uncomment to dump the computation graph once:
            # if epoch == 0 and batch == 0:
            #     print(g.dump("dot"))

            # Implicit forward, backward, and parameter update.
            optimizer.reset_gradients()
            avg_loss.backward()
            optimizer.update()

        print()

        match = 0

        # Test loop
        for batch in range(NUM_TEST_BATCHES):
            print("\rTesting... %d / %d" % (batch + 1, NUM_TEST_BATCHES), end="")
            # Makes a test minibatch.
            inputs = [test_inputs[batch * BATCH_SIZE + i] for i in range(BATCH_SIZE)]

            # Constructs the graph (dropout disabled at test time).
            g.clear()
            y = make_graph(inputs, False)

            # Gets outputs, takes the argmax, and compares with the label.
            # (The previous hand-rolled loop returned -1 whenever every score
            # was below -1e10; max() over the index range has no such hole.)
            y_val = y.to_list()
            for i in range(BATCH_SIZE):
                offset = i * NUM_OUTPUT_UNITS
                argmax = max(range(NUM_OUTPUT_UNITS), key=lambda j: y_val[offset + j])
                if argmax == test_labels[i + batch * BATCH_SIZE]:
                    match += 1

        accuracy = 100.0 * match / NUM_TEST_SAMPLES
        print("epoch %d: accuracy: %.2f%%" % (epoch, accuracy))

    return 0
170+
171+
172+
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()

primitiv/_function.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ cdef extern from "primitiv/functions.h":
5353
Var func_softmax_cross_entropy "primitiv::functions::softmax_cross_entropy" [Var](const Var &x, const Var &t, unsigned dim) except +
5454
Var func_softmax_cross_entropy "primitiv::functions::softmax_cross_entropy" [Var](const Var &x, const vector[unsigned] &ids, unsigned dim) except +
5555
Var func_stop_gradient "primitiv::functions::stop_gradient" [Var](const Var &x) except +
56+
Var func_conv2d "primitiv::functions::conv2d" [Var](const Var &x, const Var &w, unsigned padding0, unsigned padding1, unsigned stride0, unsigned stride1, unsigned dilation0, unsigned dilation1) except +
57+
Var func_max_pool2d "primitiv::functions::max_pool2d" [Var](const Var &x, unsigned window0, unsigned window1, unsigned padding0, unsigned padding1, unsigned stride0, unsigned stride1) except +
5658

5759
CppTensor func_constant_tensor "primitiv::functions::constant_tensor" (const CppShape &shape, float k, CppDevice *dev) except +
5860
CppNode func_constant_node "primitiv::functions::constant_node" (const CppShape &shape, float k, CppDevice *dev, CppGraph *g) except +

primitiv/_function.pyx

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,26 @@ class functions:
205205
def stop_gradient(Node x):
206206
return wrapNode(func_stop_gradient(x.wrapped))
207207

208+
    @staticmethod
    def conv2d(Node x, Node w,
               unsigned padding0, unsigned padding1,
               unsigned stride0, unsigned stride1,
               unsigned dilation0, unsigned dilation1):
        """Applies a 2D convolution of filter bank `w` over input node `x`.

        Thin wrapper around `primitiv::functions::conv2d`; padding, stride
        and dilation are given separately per spatial dimension.
        """
        return wrapNode(func_conv2d(x.wrapped, w.wrapped,
                                    padding0, padding1,
                                    stride0, stride1,
                                    dilation0, dilation1))
217+
218+
    @staticmethod
    def max_pool2d(Node x,
                   unsigned window0, unsigned window1,
                   unsigned padding0, unsigned padding1,
                   unsigned stride0, unsigned stride1):
        """Applies 2D max-pooling over input node `x`.

        Thin wrapper around `primitiv::functions::max_pool2d`; window,
        padding and stride are given separately per spatial dimension.
        """
        return wrapNode(func_max_pool2d(x.wrapped,
                                        window0, window1,
                                        padding0, padding1,
                                        stride0, stride1))
227+
208228
@staticmethod
209229
def constant(shape, float k, Device device = None, Graph graph = None):
210230
return wrapNode(func_constant_node(normShape(shape).wrapped, k,
@@ -459,6 +479,26 @@ class tensor_functions:
459479
def stop_gradient(Tensor x):
460480
return Tensor.get_wrapper_with_new(new CppTensor(func_stop_gradient(x.wrapped[0])))
461481

482+
    @staticmethod
    def conv2d(Tensor x, Tensor w,
               unsigned padding0, unsigned padding1,
               unsigned stride0, unsigned stride1,
               unsigned dilation0, unsigned dilation1):
        """Eagerly applies a 2D convolution of filter bank `w` over `x`.

        Tensor counterpart of the Node-based conv2d; wraps
        `primitiv::functions::conv2d` and returns a new owned Tensor.
        """
        return Tensor.get_wrapper_with_new(new CppTensor(func_conv2d(x.wrapped[0], w.wrapped[0],
                                                                     padding0, padding1,
                                                                     stride0, stride1,
                                                                     dilation0, dilation1)))
491+
492+
    @staticmethod
    def max_pool2d(Tensor x,
                   unsigned window0, unsigned window1,
                   unsigned padding0, unsigned padding1,
                   unsigned stride0, unsigned stride1):
        """Eagerly applies 2D max-pooling over tensor `x`.

        Tensor counterpart of the Node-based max_pool2d; wraps
        `primitiv::functions::max_pool2d` and returns a new owned Tensor.
        """
        return Tensor.get_wrapper_with_new(new CppTensor(func_max_pool2d(x.wrapped[0],
                                                                         window0, window1,
                                                                         padding0, padding1,
                                                                         stride0, stride1)))
501+
462502
@staticmethod
463503
def constant(shape, float k, Device device = None):
464504
return Tensor.get_wrapper_with_new(new CppTensor(func_constant_tensor(normShape(shape).wrapped, k,

primitiv/initializers/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from primitiv.initializers._initializer_impl import Identity
55
from primitiv.initializers._initializer_impl import XavierUniform
66
from primitiv.initializers._initializer_impl import XavierNormal
7+
from primitiv.initializers._initializer_impl import XavierUniformConv2D
8+
from primitiv.initializers._initializer_impl import XavierNormalConv2D
9+
710

811
__all__ = [
912
"Constant",
@@ -12,4 +15,6 @@
1215
"Identity",
1316
"XavierUniform",
1417
"XavierNormal",
18+
"XavierUniformConv2D",
19+
"XavierNormalConv2D",
1520
]

primitiv/initializers/_initializer_impl.pxd

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ cdef extern from "primitiv/initializer_impl.h":
2020
cdef cppclass CppXavierNormal "primitiv::initializers::XavierNormal" (CppInitializer):
2121
CppXavierNormal(float scale)
2222

23+
cdef cppclass CppXavierUniformConv2D "primitiv::initializers::XavierUniformConv2D" (CppInitializer):
24+
CppXavierUniformConv2D(float scale)
25+
26+
cdef cppclass CppXavierNormalConv2D "primitiv::initializers::XavierNormalConv2D" (CppInitializer):
27+
CppXavierNormalConv2D(float scale)
28+
2329

2430
cdef class Constant(Initializer):
2531
pass
@@ -38,3 +44,9 @@ cdef class XavierUniform(Initializer):
3844

3945
cdef class XavierNormal(Initializer):
4046
pass
47+
48+
cdef class XavierUniformConv2D(Initializer):
    # Declaration only; the implementation lives in _initializer_impl.pyx.
    pass
50+
51+
cdef class XavierNormalConv2D(Initializer):
    # Declaration only; the implementation lives in _initializer_impl.pyx.
    pass

primitiv/initializers/_initializer_impl.pyx

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ cdef class XavierUniform(Initializer):
107107
def __init__(self, scale = 1.0):
108108
"""Crates a new initializer object.
109109
110-
:param scale: Scale of the distribusion.
110+
:param scale: Additional scaling factor of the uniform distribution.
111111
:type scale: float
112112
113113
"""
@@ -147,3 +147,53 @@ cdef class XavierNormal(Initializer):
147147
temp = <CppXavierNormal*> self.wrapped_newed
148148
del temp
149149
self.wrapped_newed = NULL
150+
151+
152+
cdef class XavierUniformConv2D(Initializer):
    """The Xavier initialization with the uniform distribution for conv2d filters.

    """

    def __init__(self, scale = 1.0):
        """Creates a new `XavierUniformConv2D` initializer.

        :param scale: Additional scaling factor of the uniform distribution.
        :type scale: float

        """
        # Guard against constructing the underlying C++ object twice.
        if self.wrapped_newed is not NULL:
            raise TypeError("__init__() has already been called.")
        self.wrapped_newed = new CppXavierUniformConv2D(scale)
        self.wrapped = self.wrapped_newed

    def __dealloc__(self):
        # Cast back to the concrete type so the correct C++ destructor runs.
        cdef CppXavierUniformConv2D *temp
        if self.wrapped_newed is not NULL:
            temp = <CppXavierUniformConv2D*> self.wrapped_newed
            del temp
            self.wrapped_newed = NULL
175+
176+
177+
cdef class XavierNormalConv2D(Initializer):
    """The Xavier initialization with the normal distribution for conv2d filters.

    """

    def __init__(self, scale = 1.0):
        """Creates a new `XavierNormalConv2D` initializer.

        :param scale: Additional scaling factor of the normal distribution.
        :type scale: float

        """
        # Guard against constructing the underlying C++ object twice.
        if self.wrapped_newed is not NULL:
            raise TypeError("__init__() has already been called.")
        self.wrapped_newed = new CppXavierNormalConv2D(scale)
        self.wrapped = self.wrapped_newed

    def __dealloc__(self):
        # Cast back to the concrete type so the correct C++ destructor runs.
        cdef CppXavierNormalConv2D *temp
        if self.wrapped_newed is not NULL:
            temp = <CppXavierNormalConv2D*> self.wrapped_newed
            del temp
            self.wrapped_newed = NULL

0 commit comments

Comments
 (0)