|
| 1 | +#!/usr/bin/env python3 |
| 2 | + |
# Python example of a Convolutional Neural Network.
# Please refer to the primitiv repository for more details.
| 5 | +# |
| 6 | +# Usage: |
| 7 | +# $ ./download_data.sh |
| 8 | +# $ python3 ./mnist_cnn.py |
| 9 | + |
| 10 | +import random |
| 11 | + |
| 12 | +import numpy as np |
| 13 | + |
| 14 | +from primitiv import functions as F |
| 15 | +from primitiv import initializers as I |
| 16 | +from primitiv import optimizers as O |
| 17 | +from primitiv import devices as D |
| 18 | +from primitiv import Device, Graph, Parameter, Shape |
| 19 | + |
| 20 | +NUM_TRAIN_SAMPLES = 60000 |
| 21 | +NUM_TEST_SAMPLES = 10000 |
| 22 | +BATCH_SIZE = 200 |
| 23 | +NUM_TRAIN_BATCHES = NUM_TRAIN_SAMPLES // BATCH_SIZE |
| 24 | +NUM_TEST_BATCHES = NUM_TEST_SAMPLES // BATCH_SIZE |
| 25 | +MAX_EPOCH = 100 |
| 26 | + |
| 27 | +IMAGE_HEIGHT = 28 |
| 28 | +IMAGE_WIDTH = 28 |
| 29 | + |
| 30 | +KERNEL_SIZE1 = 5 # should be an odd number |
| 31 | +KERNEL_SIZE2 = 5 # ditto |
| 32 | +NUM_CHANNELS1 = 8 |
| 33 | +NUM_CHANNELS2 = 16 |
| 34 | +PADDING1 = KERNEL_SIZE1 // 2 |
| 35 | +PADDING2 = KERNEL_SIZE2 // 2 |
| 36 | + |
| 37 | +NUM_INPUT_UNITS = (IMAGE_HEIGHT // 4) * (IMAGE_WIDTH // 4) * NUM_CHANNELS2 |
| 38 | +NUM_HIDDEN_UNITS = 256 |
| 39 | +NUM_OUTPUT_UNITS = 10 |
| 40 | + |
| 41 | + |
def load_images(filename, n):
    """Load `n` MNIST images from an idx3-ubyte file.

    Args:
        filename: Path to an idx3-ubyte image file.
        n: Number of images to read.

    Returns:
        np.ndarray of shape (n, IMAGE_HEIGHT, IMAGE_WIDTH), dtype float32,
        with pixel values scaled to [0, 1].
    """
    with open(filename, "rb") as ifs:
        ifs.seek(16)  # Skip the idx3 header (magic, count, rows, cols).
        # BUGFIX: the byte count must be pixels-per-image (HEIGHT*WIDTH),
        # not NUM_INPUT_UNITS (the CNN's flattened feature size). The two
        # only coincide (784) for the current kernel/channel constants, so
        # the old code silently depended on that accident.
        count = n * IMAGE_HEIGHT * IMAGE_WIDTH
        return (np.fromfile(ifs, dtype=np.uint8, count=count) / 255) \
            .astype(np.float32) \
            .reshape((n, IMAGE_HEIGHT, IMAGE_WIDTH))
| 48 | + |
| 49 | + |
def load_labels(filename, n):
    """Load `n` MNIST labels from an idx1-ubyte file as a uint32 array."""
    with open(filename, "rb") as ifs:
        ifs.seek(8)  # Skip the idx1 header (magic number and item count).
        raw = np.fromfile(ifs, dtype=np.uint8, count=n)
        return raw.astype(np.uint32)
| 55 | + |
def main():
    """Train a 2-conv + 2-FC network on MNIST, printing test accuracy per epoch.

    Expects the four MNIST idx files under data/ (see download_data.sh).
    Returns 0 on completion.
    """
    # Loads data.
    train_inputs = load_images("data/train-images-idx3-ubyte", NUM_TRAIN_SAMPLES)
    train_labels = load_labels("data/train-labels-idx1-ubyte", NUM_TRAIN_SAMPLES)
    test_inputs = load_images("data/t10k-images-idx3-ubyte", NUM_TEST_SAMPLES)
    test_labels = load_labels("data/t10k-labels-idx1-ubyte", NUM_TEST_SAMPLES)

    # All primitiv objects below are created on the first CUDA device.
    dev = D.CUDA(0)
    Device.set_default(dev)
    g = Graph()
    Graph.set_default(g)

    # Parameters of CNNs.
    # Shape: [kernel_height, kernel_width, in_channels, out_channels]
    pw_cnn1 = Parameter(
        Shape([KERNEL_SIZE1, KERNEL_SIZE1, 1, NUM_CHANNELS1]),
        I.XavierUniformConv2D())
    pw_cnn2 = Parameter(
        Shape([KERNEL_SIZE2, KERNEL_SIZE2, NUM_CHANNELS1, NUM_CHANNELS2]),
        I.XavierUniformConv2D())

    # Parameters of fully-connected layers.
    pw_fc1 = Parameter(Shape([NUM_HIDDEN_UNITS, NUM_INPUT_UNITS]), I.XavierUniform())
    pw_fc2 = Parameter(Shape([NUM_OUTPUT_UNITS, NUM_HIDDEN_UNITS]), I.XavierUniform())
    pb_fc1 = Parameter(Shape([NUM_HIDDEN_UNITS]), I.Constant(0))
    pb_fc2 = Parameter(Shape([NUM_OUTPUT_UNITS]), I.Constant(0))

    # Optimizer over every trainable parameter.
    optimizer = O.SGD(.1)
    optimizer.add(pw_cnn1, pw_cnn2, pw_fc1, pw_fc2, pb_fc1, pb_fc2)

    def make_graph(inputs, train):
        """Construct the predictor network; `train` toggles dropout."""
        # Input and parameters.
        x = F.input(inputs)
        w_cnn1 = F.parameter(pw_cnn1)
        w_cnn2 = F.parameter(pw_cnn2)
        w_fc1 = F.parameter(pw_fc1)
        w_fc2 = F.parameter(pw_fc2)
        b_fc1 = F.parameter(pb_fc1)
        b_fc2 = F.parameter(pb_fc2)
        # Two conv+relu+2x2-max-pool stages; "same" padding keeps the
        # spatial size through each conv, each pool halves it (28->14->7).
        h_cnn1 = F.relu(F.conv2d(x, w_cnn1, PADDING1, PADDING1, 1, 1, 1, 1))
        h_pool1 = F.max_pool2d(h_cnn1, 2, 2, 0, 0, 2, 2)
        h_cnn2 = F.relu(F.conv2d(h_pool1, w_cnn2, PADDING2, PADDING2, 1, 1, 1, 1))
        h_pool2 = F.max_pool2d(h_cnn2, 2, 2, 0, 0, 2, 2)
        # FC layers with dropout (active only while training).
        x_fc = F.dropout(F.flatten(h_pool2), .5, train)
        h_fc = F.dropout(F.relu(F.matmul(w_fc1, x_fc) + b_fc1), .5, train)
        return F.matmul(w_fc2, h_fc) + b_fc2

    # Sample IDs, reshuffled every epoch to randomize minibatches.
    ids = list(range(NUM_TRAIN_SAMPLES))

    for epoch in range(MAX_EPOCH):
        random.shuffle(ids)

        # Training loop.
        for batch in range(NUM_TRAIN_BATCHES):
            print("\rTraining... %d / %d" % (batch + 1, NUM_TRAIN_BATCHES), end="")
            # Makes a minibatch for training; a single slice of the
            # shuffled IDs replaces the repeated index arithmetic.
            batch_ids = ids[batch * BATCH_SIZE:(batch + 1) * BATCH_SIZE]
            inputs = [train_inputs[i] for i in batch_ids]
            labels = [train_labels[i] for i in batch_ids]

            # Constructs the graph.
            g.clear()
            y = make_graph(inputs, True)
            loss = F.softmax_cross_entropy(y, labels, 0)
            avg_loss = F.batch.mean(loss)

            # Forward is implicit; backward then updates parameters.
            optimizer.reset_gradients()
            avg_loss.backward()
            optimizer.update()

        print()

        match = 0

        # Test loop (dropout disabled).
        for batch in range(NUM_TEST_BATCHES):
            print("\rTesting... %d / %d" % (batch + 1, NUM_TEST_BATCHES), end="")
            # Makes a test minibatch.
            offset = batch * BATCH_SIZE
            inputs = [test_inputs[offset + i] for i in range(BATCH_SIZE)]

            # Constructs the graph.
            g.clear()
            y = make_graph(inputs, False)

            # Argmax over each sample's logits, compared with the label.
            # max() returns the first maximal index, matching the old
            # strict-greater manual scan.
            y_val = y.to_list()
            for i in range(BATCH_SIZE):
                logits = y_val[i * NUM_OUTPUT_UNITS:(i + 1) * NUM_OUTPUT_UNITS]
                argmax = max(range(NUM_OUTPUT_UNITS), key=logits.__getitem__)
                if argmax == test_labels[offset + i]:
                    match += 1

        accuracy = 100.0 * match / NUM_TEST_SAMPLES
        print("epoch %d: accuracy: %.2f%%" % (epoch, accuracy))

    return 0
| 170 | + |
| 171 | + |
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()
0 commit comments