Skip to content

Commit b97e7f7

Browse files
author
John Canny
committed
update network scripts
1 parent 7fc7f26 commit b97e7f7

4 files changed

Lines changed: 138 additions & 15 deletions

File tree

scripts/networks/getImageNet.ssc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
2+
val tt = "train";
3+
4+
val imagenetroot = "/data/ImageNet/2012resized/"+tt+"/";
5+
val dataroot = "../../data/ImageNet/";
6+
val savefname = tt+"/part%04d.fmat.lz4";
7+
val labelfname = tt+"/label%04d.imat.lz4";
8+
val namesfname = tt+"/names%04d.csmat.txt";
9+
val loadtable = loadCSMat(dataroot+tt+".txt");
10+
11+
val bsize = 1024;
12+
13+
val nimgs = loadtable.nrows;
14+
15+
val fnames = loadtable(?,0);
16+
val alllabels = loadtable(?,1).toIMat;
17+
18+
val perm = randperm(nimgs);
19+
val mat = zeros(4 \ 256 \ 256 \ bsize);
20+
val labels = izeros(1, bsize);
21+
val names = CSMat(bsize,1);
22+
var i = 0;
23+
var jin = 0;
24+
while (jin < nimgs) {
25+
val todo = math.min(bsize, nimgs - jin);
26+
var j = 0;
27+
while (j < todo && jin < nimgs) {
28+
val indx = perm(jin);
29+
try {
30+
val im = loadImage(imagenetroot+fnames(indx));
31+
mat(?,?,?,j) = im.toFMat;
32+
labels(0, j) = alllabels(indx);
33+
names(j) = fnames(indx);
34+
j += 1;
35+
} catch {
36+
case e:Exception => println("\nProblem reading %s, continuing" format fnames(indx));
37+
}
38+
jin += 1;
39+
}
40+
if (j == bsize) {
41+
saveFMat(dataroot+savefname format i, mat);
42+
saveIMat(dataroot+labelfname format i, labels);
43+
saveCSMat(dataroot+namesfname format i, names);
44+
} else {
45+
saveFMat(dataroot+savefname format i, mat.colslice(0,j));
46+
saveIMat(dataroot+labelfname format i, labels.colslice(0,j));
47+
saveCSMat(dataroot+namesfname format i, names(0->j,0));
48+
}
49+
i += 1;
50+
print(".");
51+
}
52+
println("");
53+
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
:silent
2+
3+
val tt = "train";
4+
5+
val dataroot = "../../data/ImageNet/";
6+
val labelfname = dataroot+tt+"/label%04d.imat.lz4";
7+
val labelsout = dataroot+tt+"/labels%04d.fmat.lz4";
8+
9+
val bsize = 1024;
10+
val nparts = 1252;
11+
12+
print("\nComputing one-hot labels");
13+
val omat = zeros(1000,bsize);
14+
val coln = irow(0->bsize) *@ 1000;
15+
for (i <- 0 until nparts) {
16+
val mat = loadIMat(labelfname format i);
17+
omat.clear;
18+
val inds = mat + coln(0,0->mat.ncols);
19+
omat(inds) = 1f;
20+
if (mat.ncols == bsize) {
21+
saveFMat(labelsout format i, omat);
22+
} else {
23+
saveFMat(labelsout format i, omat.colslice(0,mat.ncols));
24+
}
25+
print(".");
26+
}
27+
println("");
28+
29+
:silent
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
:silent
2+
3+
val tt = "train";
4+
5+
val dataroot = "../../data/ImageNet/";
6+
val datafname = dataroot+tt+"/part%04d.fmat.lz4";
7+
val labelfname = dataroot+tt+"/label%04d.imat.lz4";
8+
val namesfname = dataroot+tt+"/names%04d.csmat.txt";
9+
10+
val bsize = 1024;
11+
val nparts = 1252;
12+
13+
val triminds = irow(16 -> 240);
14+
val trimcolors = irow(0,1,2);
15+
16+
var nimgs = 0L;
17+
val msum = dzeros(3\224\224\1);
18+
19+
print("\nComputing mean");
20+
val times = zeros(1,4)
21+
for (i <- 0 until nparts) {
22+
tic;
23+
val mat = loadFMat(datafname format i);
24+
val t1 = toc;
25+
val trim = mat(trimcolors, triminds, triminds, ?);
26+
val t2 = toc;
27+
val tmpsum = trim.sum(irow(3));
28+
val t3 = toc;
29+
msum ~ msum + DMat(tmpsum);
30+
val t4 = toc;
31+
times ~ times + row(t1,t2-t1,t3-t2,t4-t3);
32+
nimgs = nimgs + trim.ncols;
33+
print(".");
34+
}
35+
println("");
36+
37+
msum ~ msum / nimgs.toDouble;
38+
val means = FMat(msum);
39+
saveFMat(dataroot+tt+"/means.fmat.lz4", means);
40+
41+
:silent

scripts/networks/testConv.ssc

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,48 +20,48 @@ opts.lookahead = 0;
2020
import BIDMach.networks.layers.Node._;
2121

2222
val in = input;
23-
val crop1 = crop(in);
23+
val crop1 = crop(in)(sizes=irow(3,224,224,0),offsets=irow(0,-1,-1,-1));
2424
val means1 = constant(means);
2525

2626
val diff2 = crop1 - means1;
27-
val fmt2 = format(diff2);
27+
val fmt2 = format(diff2)();
2828

2929
val conv3 = conv(fmt2)(w=7,h=7,nch=64,stride=4,pad=3);
30-
val bns3 = batchNormScale(conv3);
30+
val bns3 = batchNormScale(conv3)();
3131
val relu3 = relu(bns3);
3232

3333
val conv4 = conv(relu3)(w=5,h=5,nch=256,stride=2,pad=2);
34-
val bns4 = batchNormScale(conv4);
34+
val bns4 = batchNormScale(conv4)();
3535
val relu4 = relu(bns4);
3636

3737
val conv5 = conv(relu4)(w=3,h=3,nch=256,stride=1,pad=1);
38-
val bns5 = batchNormScale(conv5);
38+
val bns5 = batchNormScale(conv5)();
3939
val relu5 = relu(bns5);
4040

4141
val conv6 = conv(relu5)(w=5,h=5,nch=1024,stride=2,pad=2);
42-
val bns6 = batchNormScale(conv6);
42+
val bns6 = batchNormScale(conv6)();
4343
val relu6 = relu(bns6);
4444

4545
val conv7 = conv(relu6)(w=3,h=3,nch=1024,stride=1,pad=1);
46-
val bns7 = batchNormScale(conv7);
46+
val bns7 = batchNormScale(conv7)();
4747
val relu7 = relu(bns7);
4848

4949
val conv8 = conv(relu7)(w=3,h=3,nch=1024,stride=1,pad=1);
50-
val bns8 = batchNormScale(conv8);
50+
val bns8 = batchNormScale(conv8)();
5151
val relu8 = relu(bns8);
5252

53-
val fc9 = lin(relu8)(outdim=4096);
54-
val relu9 = relu(fc8);
53+
val fc9 = linear(relu8)(outdim=4096);
54+
val relu9 = relu(fc9);
5555

56-
val fc10 = lin(relu9)(outdim=4096);
56+
val fc10 = linear(relu9)(outdim=4096);
5757
val relu10 = relu(fc10);
5858

59-
val fc11 = lin(relu10)(outdim=1000);
59+
val fc11 = linear(relu10)(outdim=1000);
6060
val out = glm(fc11)(opts.links);
6161

62-
val nodes = in \ diff2 \ conv3 \ conv4 \ conv5 \ conv6 \ conv7 \ conv8 \ fc9 \ fc10 \ fc11 on
63-
crop1 \ fmt2 \ bns3 \ bns4 \ bns5 \ bns6 \ bns7 \ bns8 \ relu9 \ relu10 \ out on
64-
means1 \ null \ relu3 \ relu4 \ relu5 \ relu6 \ relu7 \ relu8 \ null \ null \ null;
62+
val nodes = (in \ diff2 \ conv3 \ conv4 \ conv5 \ conv6 \ conv7 \ conv8 \ fc9 \ fc10 \ fc11 on
63+
crop1 \ fmt2 \ bns3 \ bns4 \ bns5 \ bns6 \ bns7 \ bns8 \ relu9 \ relu10 \ out on
64+
means1 \ null \ relu3 \ relu4 \ relu5 \ relu6 \ relu7 \ relu8 \ null \ null \ null );
6565

6666

6767
opts.nodemat = nodes;

0 commit comments

Comments
 (0)