Skip to content

Commit 257d86f

Browse files
committed
Added:MNIST script
1 parent 2271a3e commit 257d86f

4 files changed

Lines changed: 107 additions & 16 deletions

File tree

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,9 @@
11

22
{
33

4-
"comment" : "layers :: RBM layer configuration (No: of Nodes)",
54
"hidden_layers": [1000, 1000, 1000],
6-
7-
"comment" : "activation :: sigmoid or tanh",
85
"activation" : "sigmoid",
9-
10-
"comment" : "pretrained_layers:number of layers to be pre-trained",
116
"pretrained_layers" : 3,
12-
13-
"comment" : "first_layer_type::type for the first layer; either 'bb' (Bernoulli-Bernoulli) or 'gb' (Gaussian-Bernoulli)",
147
"first_layer_type" : "gb",
15-
16-
"comment" : "random_seed::",
178
"random_seed" : 89677
189
}

sample_config/MNIST/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
The MNIST database of handwritten digits, available from [this page](http://yann.lecun.com/exdb/mnist/), has a training set of 60,000 examples, and a test set of 10,000 examples.
2+
3+
* Download [MNIST dataset](http://yann.lecun.com/exdb/mnist/).
4+
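* Convert the downloaded files to the NP format using [mnist.py](mnist.py). Run it from the directory containing the four downloaded `.gz` files; it writes `train.dat`, `val.dat` and `test.dat`.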
5+
* Try it with any of the provided recipes (or with your own or a modified one).
6+
Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
{
22

3-
"comment" : "hidden_layers :: RBM layer configuration (No: of Nodes)",
43
"hidden_layers": [1000, 1000, 1000],
5-
6-
"comment":"",
74
"corruption_levels": [0.1, 0.2, 0.3],
8-
9-
"comment" : "activation :: sigmoid or tanh",
105
"activation" : "tanh",
11-
12-
"comment" : "random_seed::",
136
"random_seed" : 89677
147
}

sample_config/MNIST/mnist.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from struct import unpack
2+
import gzip,os,json,numpy
3+
from numpy import zeros, uint8
4+
from pylab import imshow, show, cm
5+
import cPickle as pickle
6+
7+
def get_labeled_data(imagefile, labelfile):
    """
    Read a gzip-compressed image file and its matching label file.

    Parameters:
        imagefile -- path of the gzip-compressed image file. After a 4-byte
                     magic number it holds three big-endian unsigned 32-bit
                     ints (image count, rows, cols) followed by one byte
                     per pixel.
        labelfile -- path of the gzip-compressed label file. After a 4-byte
                     magic number it holds one big-endian unsigned 32-bit
                     int (label count) followed by one byte per label.

    Returns a tuple (x, y, header):
        x      -- float array of shape (N, rows*cols); every row is one
                  image flattened row by row, with each pixel byte divided
                  by 255.
        y      -- uint8 array of shape (N, 1) holding the label bytes.
        header -- dict with 'featdim' (rows*cols) and 'input_shape'
                  ([rows, cols, 1]).

    Raises:
        Exception -- if the two files declare different item counts, or if
                     either file ends before all declared items were read.
    """
    # The context managers close both archives even when parsing fails.
    with gzip.open(imagefile, 'rb') as images, \
            gzip.open(labelfile, 'rb') as labels:
        # Header fields are big-endian unsigned ints, hence '>I'.
        images.read(4)  # skip the magic_number
        number_of_images, rows, cols = unpack('>III', images.read(12))

        labels.read(4)  # skip the magic_number
        N = unpack('>I', labels.read(4))[0]

        if number_of_images != N:
            raise Exception('The number of labels did not match '
                            'the number of images.')

        featdim = rows * cols
        # One bulk read per file instead of one read() call per byte.
        pixel_bytes = images.read(N * featdim)
        label_bytes = labels.read(N)

    if len(pixel_bytes) != N * featdim or len(label_bytes) != N:
        raise Exception('Unexpected end of file: expected %d pixel bytes and '
                        '%d label bytes, got %d and %d.'
                        % (N * featdim, N, len(pixel_bytes), len(label_bytes)))

    # Freshly allocated (writable) arrays: callers shuffle them in place,
    # which a read-only frombuffer view would not allow.
    x = zeros((N, featdim), dtype=float)
    y = zeros((N, 1), dtype=uint8)
    if N * featdim:
        raw_pixels = numpy.frombuffer(pixel_bytes, dtype=uint8)
        # Dividing by a float promotes to float64, scaling 0..255 to 0..1.
        x[:] = raw_pixels.reshape(N, featdim) / 255.0
    if N:
        y[:, 0] = numpy.frombuffer(label_bytes, dtype=uint8)

    header = {}
    header['featdim'] = featdim
    header['input_shape'] = [rows, cols, 1]

    return x, y, header
60+
61+
62+
def saveData(name,x,y,header):
    """
    Write a header line followed by packed (vector, label) records.

    Parameters:
        name   -- path of the output file; an existing file is replaced.
        x      -- sequence of feature vectors, each of length
                  header['featdim'].
        y      -- sequence of labels, one per vector.
        header -- JSON-serialisable dict; must contain 'featdim'.

    File layout: one line holding the JSON-encoded header, then one record
    per (vector, label) pair made of header['featdim'] big-endian 16-bit
    floats ('d') followed by one big-endian 16-bit int ('l').
    If x and y differ in length, only the common prefix is written.
    """
    dt = {'names': ['d', 'l'], 'formats': [('>f2', header['featdim']), '>i2']}
    record_type = numpy.dtype(dt)

    # zip() in the per-record loop this replaces stopped at the shorter input.
    count = min(len(x), len(y))
    data = numpy.zeros(count, dtype=record_type)
    if count:
        # Reshape to the field's own shape so any label layout, e.g. (N, 1)
        # or (N,), and any featdim is accepted.
        data['d'] = numpy.asarray(x[:count]).reshape(data['d'].shape)
        data['l'] = numpy.asarray(y[:count]).reshape(data['l'].shape)

    # 'wb' rather than 'ab': every call writes a header, so appending to an
    # existing file would bury a second header inside the record stream.
    with open(name, 'wb') as filehandle:
        # json.dumps emits ASCII by default; encode so the write is valid on
        # a binary handle.
        filehandle.write((json.dumps(header) + '\n').encode('ascii'))
        # All records go out in a single call.
        data.tofile(filehandle)
75+
76+
if __name__ == '__main__':
    # Convert the gzip-compressed test files found in the current directory
    # into test.dat.
    print("Get testset")
    (x, y, h) = get_labeled_data('t10k-images-idx3-ubyte.gz',
                                 't10k-labels-idx1-ubyte.gz')
    print("Got %i testing datasets." % len(x))
    saveData('test.dat', x, y, h)

    # Convert the training files, then split them into train.dat / val.dat.
    print("Get trainingset")
    (x, y, h) = get_labeled_data('train-images-idx3-ubyte.gz',
                                 'train-labels-idx1-ubyte.gz')
    print("Got %i training datasets." % len(x))

    # Re-seeding before each in-place shuffle replays the same random draws,
    # so x and y (which have the same length) receive the same row
    # permutation and every image stays paired with its label.
    seed = 9090
    numpy.random.seed(seed)
    numpy.random.shuffle(x)
    numpy.random.seed(seed)
    numpy.random.shuffle(y)

    # The first 75% of the shuffled rows become the training set, the rest
    # the validation set. Both slices use the same index so that no sample
    # is dropped between them.
    N = len(x)
    split = int(N * 0.75)
    xtrain = x[:split]
    ytrain = y[:split]
    xval = x[split:]
    yval = y[split:]

    saveData('train.dat', xtrain, ytrain, h)
    saveData('val.dat', xval, yval, h)
101+

0 commit comments

Comments
 (0)