A simple demonstration of the EasyConvNet code
This code demonstrates the EasyConvNet package on the MNIST data set.
Contents
  Data preparation
  Decide whether you want the GPU or the CPU implementation
  Define the architecture of a network (the LeNet architecture)
  Initialize a network class
  For debugging purposes, show some images
  Train using SGD with Nesterov's momentum
  Calculate test error
Data preparation
We need four files:
trainImages = 'data/train.images.bin';
trainLabels = 'data/train.labels.bin';
testImages = 'data/test.images.bin';
testLabels = 'data/test.labels.bin';
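The demo assumes MNIST has already been converted to these raw binary files. If you need to create them, the sketch below shows one possible converter from the standard MNIST IDX files. The target layout (a raw uint8 pixel stream for the images, a bit-packed 10 x m one-hot matrix written with 'ubit1' for the labels) is an assumption inferred from the input-layer definitions further down, so treat this as a starting point rather than the package's official converter.

% Hypothetical converter: MNIST IDX files -> assumed raw binary layout.
f = fopen('train-images-idx3-ubyte','rb','ieee-be');
fread(f,4,'int32');                               % skip IDX header: magic, count, rows, cols
X = fread(f,inf,'uint8=>uint8'); fclose(f);       % all pixels as one byte stream
f = fopen('train-labels-idx1-ubyte','rb','ieee-be');
fread(f,2,'int32');                               % skip IDX header: magic, count
y = fread(f,inf,'uint8=>double'); fclose(f);      % class labels 0..9
Y = full(sparse(y+1,(1:numel(y))',1,10,numel(y))); % 10 x m one-hot matrix
f = fopen('data/train.images.bin','wb'); fwrite(f,X,'uint8'); fclose(f);
f = fopen('data/train.labels.bin','wb'); fwrite(f,Y,'ubit1'); fclose(f);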
Decide whether you want the GPU or the CPU implementation
atGPU = false;
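If you are unsure whether a usable GPU is present, one option (assuming the Parallel Computing Toolbox is installed) is to set the flag automatically:

% Hypothetical alternative: enable the GPU path only when a CUDA device is visible.
atGPU = (gpuDeviceCount > 0);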
Define the architecture of a network (the LeNet architecture)
The architecture is a cell array of structs; each layer is one struct. The first layer must be of type 'input' and the last layer must be of type 'loss'. Each layer has fields 'inInd' and 'outInd': these are the indices of the blobs the layer reads its input from and writes its output to. A smaller illustrative example follows the LeNet definition below.
lenet = { ...
  struct('type','input','inInd',0,'outInd',1,'blobSize',[28 28 1 100],'fName',trainImages,'scale',1/256,'dataType','uint8'), ...
  struct('type','input','inInd',0,'outInd',2,'blobSize',[10 100],'fName',trainLabels,'scale',1,'dataType','ubit1'), ...
  struct('type','conv','inInd',1,'outInd',3,'kernelsize',5,'stride',1,'nOutChannels',20,'bias_filler',0),...
  struct('type','maxpool','inInd',3,'outInd',4,'kernelsize',2,'stride',2), ...
  struct('type','conv','inInd',4,'outInd',5,'kernelsize',5,'stride',1,'nOutChannels',50,'bias_filler',0),...
  struct('type','maxpool','inInd',5,'outInd',6,'kernelsize',2,'stride',2), ...
  struct('type','flatten','inInd',6,'outInd',6), ...
  struct('type','affine','inInd',6,'outInd',7,'nOutChannels',500,'bias_filler',0), ...
  struct('type','relu','inInd',7,'outInd',7), ...
  struct('type','affine','inInd',7,'outInd',8,'nOutChannels',10,'bias_filler',0), ...
  struct('type','loss','inInd',[8 2],'outInd',10,'lossType','MCLogLoss') };
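As a smaller illustration of the same inInd/outInd conventions, here is a hypothetical softmax-regression network on the same data. It reuses only layer types and fields that appear in the LeNet definition above, but it is an untested sketch, not part of the original demo:

softmaxNet = { ...
  struct('type','input','inInd',0,'outInd',1,'blobSize',[28 28 1 100],'fName',trainImages,'scale',1/256,'dataType','uint8'), ...
  struct('type','input','inInd',0,'outInd',2,'blobSize',[10 100],'fName',trainLabels,'scale',1,'dataType','ubit1'), ...
  struct('type','flatten','inInd',1,'outInd',1), ...
  struct('type','affine','inInd',1,'outInd',3,'nOutChannels',10,'bias_filler',0), ...
  struct('type','loss','inInd',[3 2],'outInd',4,'lossType','MCLogLoss') };

Blob 1 holds the images and blob 2 the one-hot labels; the affine layer maps the flattened 784-dimensional input to 10 class scores, and the loss layer consumes blobs 3 and 2.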
Initialize a network class
cnn = ConvNet(lenet,atGPU,'Orthogonal');
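The third constructor argument selects the weight-initialization scheme. For reference, a generic way to draw an orthogonal weight matrix is sketched below; this is the textbook recipe (QR of a Gaussian matrix), not necessarily what ConvNet does internally, and nOut/nIn are hypothetical layer dimensions:

% Generic orthogonal initializer (a sketch; ConvNet's internal scheme may differ).
nOut = 500; nIn = 800;              % hypothetical layer dimensions
A = randn(nOut, nIn, 'single');
if nOut >= nIn
    [Q,~] = qr(A, 0);               % Q has orthonormal columns, size nOut x nIn
    W = Q;
else
    [Q,~] = qr(A', 0);              % Q has orthonormal columns, size nIn x nOut
    W = Q';                         % so W has orthonormal rows
end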
For debugging purposes, show some images
x = cnn.net{1}.data.get(1); y = cnn.net{2}.data.get(1); [~,bla] = max(y);
figure; for i=1:25, subplot(5,5,i); imagesc(squeeze(x(:,:,:,i))'); colormap gray; axis off; title(sprintf('%d',bla(i)-1)); end
Train using SGD with Nesterov's momentum
mandatory fields
T = 1000; mu = single(0.9); printIter = 100; lam = single(0.0005);
learning_rate = @(t)(0.01*(1+0.0001*t)^(-0.75));

param.snapshotFile = '/tmp/snapshot'; param.printIter = 100; param.printDecay = 0.9;

cnn.Nesterov(T,learning_rate,mu,lam,param);
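For intuition, the update performed by SGD with Nesterov's momentum looks roughly like the sketch below. Here theta is the parameter vector and g(.) is a placeholder for a minibatch gradient of the loss; both are hypothetical names, and the actual ConvNet.Nesterov implementation may differ in details such as exactly where the weight-decay term lam enters.

% Textbook Nesterov-momentum loop (a sketch, not ConvNet.Nesterov itself).
v = zeros(size(theta),'single');
for t = 1:T
    eta   = learning_rate(t);              % decaying step size defined above
    grad  = g(theta + mu*v) + lam*theta;   % gradient at the look-ahead point, plus L2 term
    v     = mu*v - eta*grad;               % update the velocity
    theta = theta + v;                     % take the step
end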
figure; plot(cnn.AllLoss'); title('Nesterov'); legend('Train loss','Train 0-1 error');
Iter: 100: 0.265366 0.075039
Iter: 200: 0.171004 0.047117
Iter: 300: 0.119551 0.037094
Iter: 400: 0.091825 0.024946
Iter: 500: 0.080097 0.024165
Iter: 600: 0.067927 0.019668
Iter: 700: 0.077423 0.025204
Iter: 800: 0.063537 0.019559
Iter: 900: 0.065999 0.020782
Iter: 1000: 0.045626 0.012953
Calculate test error
testlenet = lenet;
testlenet{1}.fName = testImages;
testlenet{2}.fName = testLabels;
testNet = ConvNet(testlenet,atGPU);
testNet.setTheta(cnn.theta);
testNet.calcLossAndErr();
fprintf(1,'Test loss= %f, Test accuracy = %f\n',testNet.Loss(1),1-testNet.Loss(2));
Test loss= 0.054117, Test accuracy = 0.982300
--------------------------------------------------------------------------------
/html/demoMNIST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaisha/EasyConvNet/ec15fa5c0d5cf99d3c9441585435d977f06cb174/html/demoMNIST.png
--------------------------------------------------------------------------------
/html/demoMNIST_01.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaisha/EasyConvNet/ec15fa5c0d5cf99d3c9441585435d977f06cb174/html/demoMNIST_01.png
--------------------------------------------------------------------------------
/html/demoMNIST_02.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/shaisha/EasyConvNet/ec15fa5c0d5cf99d3c9441585435d977f06cb174/html/demoMNIST_02.png
--------------------------------------------------------------------------------
/lossClass.m:
--------------------------------------------------------------------------------
classdef lossClass < handle
% Loss helper: computes the loss value, the 0-1 error (where applicable),
% and the gradient for the loss types supported by the package.

    properties (SetAccess = private)
        type;
    end

    methods
        function this = lossClass(lossType)
            switch lossType
                case 'MCLogLoss'
                    this.type = 1;
                case 'SquaredLoss'
                    this.type = 2;
                case 'BinLogLoss'
                    this.type = 3;
                otherwise
                    assert(false,'Unknown loss type')
            end
        end

        function loss=LossAndErr(this,input)
            % input{1}: k x m matrix of predictions (k outputs, m examples)
            % input{2}: k x m matrix of targets (one-hot labels for MCLogLoss)
            switch this.type
                case 1 % multiclass logistic loss
                    [k,m] = size(input{1});
                    [loss,ind]=max(input{1});
                    pred = input{1}-repmat(loss,k,1);  % subtract the max for numerical stability
                    [~,y]=max(input{2});
                    err = sum(ind~=y)/m;               % 0-1 error of the argmax prediction

                    valY = sum(pred.*input{2});        % (shifted) score of the correct class per example
                    loss = pred - repmat(valY,k,1);
                    loss = sum(log(sum(exp(loss))))/m; % average multiclass log loss via log-sum-exp
                    loss = [loss;err];
                case 2 % Squared Loss (returns only the loss, no error term)
                    loss = 0.5*mean(sum((input{1}-input{2}).^2));

                case 3 % binary log loss
                    loss = -input{1}.*input{2};
                    err = mean(mean(loss>=0));         % fraction of sign disagreements
                    loss(loss>0) = loss(loss>0) + log(1+exp(-loss(loss>0)));
                    loss(loss<=0) = log(1+exp(loss(loss<=0)));
                    loss = [mean(mean(loss)) ; err];
            end
        end

        function delta=Grad(this,input)
            % Gradient of the loss with respect to the predictions input{1}.
            switch this.type
                case 1 % multiclass logistic loss: softmax minus one-hot labels
                    bla = input{1}-repmat(max(input{1}),size(input{1},1),1);
                    bla = exp(bla);
                    bla=bla./repmat(sum(bla),size(bla,1),1);
                    delta = (bla - input{2})/size(bla,2);

                case 2 % SquaredLoss
                    delta = (input{1}-input{2})/size(input{2},2);

                case 3 % binary log loss
                    delta = -input{2}./(1+exp(input{1}.*input{2}))/numel(input{2});

            end
        end
    end
end
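For reference, a minimal standalone check of the MCLogLoss path. This usage is inferred from the code above (input{1} holds a k x m score matrix and input{2} a matching one-hot label matrix); the variable names are illustrative only.

L = lossClass('MCLogLoss');
scores = randn(10,5,'single');               % 10 classes, 5 examples
[~,y]  = max(rand(10,5));                    % random class indices 1..10
Y      = full(sparse(y,1:5,1,10,5));         % one-hot labels, 10 x 5
lossErr = L.LossAndErr({scores,Y});          % [mean log loss ; 0-1 error]
delta   = L.Grad({scores,Y});                % gradient w.r.t. the scores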
--------------------------------------------------------------------------------
/visualizeNetwork.m:
--------------------------------------------------------------------------------
function visualizeNetwork(net,fName)
% visualize network using latex tikz
% Input:
%   net   - network definition (cell array of layer structs)
%   fName - name of a latex output file
% The function creates the file fName, containing latex code that
% generates a visualization of the network net

lenO = 0;
for i=1:length(net)
    lenO = max(lenO,max(net{i}.outInd));   % highest blob index used by the network
end

paperheight = lenO*3;

fid = fopen(fName,'wt');

prelatex = {...
    '\documentclass[8pt]{article}' , ...
    sprintf('\\usepackage[paperwidth=6in, paperheight=%dcm]{geometry}',paperheight) , ...
    '\usepackage{tikz}' , ...
    '\usetikzlibrary{positioning}' , ...
    '\begin{document}' ,...
    ' ' , ...
    '\begin{tikzpicture}' , ...
    ' [nodestyle/.style={rectangle,draw=blue!50,fill=blue!20,thick,' ,...
    ' inner sep=2pt,minimum width=1cm},' , ...
    ' ostyle/.style={rectangle,draw=black!50,fill=black!20,thick,' ,...
    ' inner sep=2pt,minimum width=1cm}]' };

fprintf(fid,'%s\n',prelatex{:});

% one node per blob (O1..OlenO)
fprintf(fid,'\\node[ostyle] (O1) at (0,0) {1};\n');
for i=2:lenO
    fprintf(fid,'\\node[ostyle] (O%d) [above=of O%d] {%d};\n',i,i-1,i);
end
% one node per layer (L1..), with edges to its input and output blobs
fprintf(fid,'\\node[nodestyle] (L1) at (10,0) {%s}',net{1}.type);
for j=1:length(net{1}.outInd)
    fprintf(fid,'\n edge[->,very thick,blue] (O%d)',net{1}.outInd(j));
end
fprintf(fid,';\n');
for i=2:length(net)
    fprintf(fid,'\\node[nodestyle] (L%d) [above=of L%d] {%s}',i,i-1,net{i}.type);
    if ~strcmp(net{i}.type,'input')
        for j=1:length(net{i}.inInd)
            fprintf(fid,'\n edge[<-,very thick,red] (O%d)',net{i}.inInd(j));
        end
    end
    for j=1:length(net{i}.outInd)
        fprintf(fid,'\n edge[->,very thick,blue] (O%d)',net{i}.outInd(j));
    end
    fprintf(fid,';\n');
end

% closure
fprintf(fid,'\\end{tikzpicture}\n\\end{document}\n');

fclose(fid);

I = find(fName == '/',1,'last'); coreName = fName(1:end-4); if ~isempty(I), coreName=coreName((I+1):end); end
fprintf(1,'Done. Now run:\n\t unix(''/usr/texbin/pdflatex -output-directory=/tmp %s''); unix(''open /tmp/%s.pdf'');\n',fName,coreName);

end
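A typical call, using the lenet cell array from the demo above (the output path is just an example):

visualizeNetwork(lenet,'/tmp/lenet.tex');   % writes a standalone TikZ document
% then compile it, e.g.:
% unix('pdflatex -output-directory=/tmp /tmp/lenet.tex');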
--------------------------------------------------------------------------------