├── .gitignore ├── LICENSE ├── README.md ├── code ├── ConfusionMatrix.lua ├── MyConfusionMatrix.lua ├── ProFi.lua ├── conv-functions.lua ├── dataset-from-tensor.lua ├── dataset-mnist.lua ├── rbm-grads.lua ├── rbm-helpers.lua ├── rbm-regularization.lua ├── rbm-util.lua ├── rbm-visualisation.lua └── rbm.lua ├── examples ├── examples.txt ├── rbm_tests.lua ├── runrbm.lua ├── stackrbms.lua ├── test_image.lua ├── testconv.lua └── testconv3d.lua └── run_docker.sh /.gitignore: -------------------------------------------------------------------------------- 1 | sigp-data/ 2 | 20news-data/ 3 | examples/runRBMmy.lua 4 | code/dataset-sigp.lua 5 | 6 | # Ignore the saved RBM files by default. 7 | *.asc 8 | 9 | # Ignore the mnist-th7 dataset. 10 | examples/mnist-th7/ 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2014, Søren Kaae Sønderby 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | * Neither the name of the {organization} nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | RBM Toolbox for Torch 2 | =============== 3 | 4 | RBM Toolbox is a Torch7 toolbox for online training of RBMs. A MATLAB version exists at 5 | [LINK](https://github.com/skaae/rbm_toolbox). 6 | 7 | The following is supported: 8 | * Training of RBMs with class labels, including: 9 | * Generative training objective [2,7] 10 | * Discriminative training objective [2,7] 11 | * Hybrid training objective [2,7] 12 | * Semi-supervised learning [2,7] (untested) 13 | * CD-k (contrastive divergence with k Gibbs steps) [5] 14 | * PCD (persistent contrastive divergence) [6] 15 | * RBM classification support [2,7] 16 | * Regularization: L1, L2, sparsity, early stopping, dropout [1], momentum [3] 17 |
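As a rough guide to the objectives (an inference from the examples below, not a statement of the exact implementation): the `-alpha` flag appears to mix the generative and discriminative log-likelihoods, roughly `loss = alpha * loss_gen + (1 - alpha) * loss_disc`, so `-alpha 0` trains purely discriminatively, `-alpha 1` purely generatively, and values in between give the hybrid objective of [2,7].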
18 | # Installation 19 | 20 | 1. Install torch7: follow [these](https://github.com/torch/torch7/wiki/Cheatsheet#installing-and-running-torch) instructions 21 | 2. Download this repository: `git clone https://github.com/skaae/rbm_toolbox_lua.git` 22 | 3. To run the examples, install wget (e.g. with Homebrew on OS X) 23 | 4. Run the example RBMs with examples/runrbm.lua 24 | 25 | # Examples 26 | Run from the /examples folder: 27 | 28 | 29 | 1) th runrbm.lua -eta 0.05 -alpha 0 -nhidden 500 -folder test_discriminative 30 | 31 | 2) th runrbm.lua -eta 0.05 -alpha 0 -nhidden 500 -folder test_discriminative_dropout -dropout 0.5 32 | 33 | 3) th runrbm.lua -eta 0.05 -alpha 1 -nhidden 500 -folder test_generative_pcd -traintype PCD 34 | 4) th runrbm.lua -eta 0.05 -alpha 0.01 -nhidden 1500 -folder test_hybrid 35 | 36 | 5) th runrbm.lua -eta 0.05 -alpha 0.01 -nhidden 1500 -folder test_hybrid_dropout -dropout 0.5 37 | 38 | 6) th runrbm.lua -eta 0.05 -alpha 0.01 -nhidden 3000 -folder test_hybrid_sparsity -sparsity 0.0001 39 | 40 | 7) th runrbm.lua -eta 0.05 -alpha 0.01 -nhidden 3000 -folder test_hybrid_sparsity_dropout -sparsity 0.0001 -dropout 0.5 41 | 42 | 8) th runrbm.lua -eta 0.05 -alpha 1 -nhidden 1000 -folder test_generative 43 | 44 | 9) th runrbm.lua -eta 0.05 -alpha 1 -nhidden 2000 -folder test_generative -dropout 0.5 45 | 46 | # Using your own data 47 | You can create your own datasets with the functions in 48 | code/dataset-from-tensor.lua: 49 | 50 | ```LUA 51 | codeFolder = '../code/' 52 | require('torch') 53 | require(codeFolder..'rbm') 54 | require(codeFolder..'dataset-from-tensor') 55 | require 'paths' 56 | geometry = {1,100} -- dimensions of your training data 57 | nclasses = 3 58 | nSamples = 5 59 | trainTensor = torch.rand(nSamples,geometry[1],geometry[2]) 60 | trainLabels = torch.Tensor({1,2,3,1,2}) 61 | classes = {'ClassA','ClassB','ClassC'} 62 | trainData = datatensor.createDataset(trainTensor, 63 | oneOfK(nclasses,trainLabels), 64 | classes, 65 | geometry) 66 | print(trainData:next()) 67 | print(trainData[2]) 68 | print(trainData:classnames()) 69 | 70 | ```
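The snippet above assumes a one-of-K encoder named `oneOfK` is in scope (the toolbox code may already provide one). If your setup lacks it, a minimal sketch of such a helper could look like this:

```LUA
-- Hypothetical helper: encode a 1-d tensor of class indices as one-hot rows.
function oneOfK(nclasses, labels)
   local encoded = torch.zeros(labels:size(1), nclasses)
   for i = 1, labels:size(1) do
      encoded[{i, labels[i]}] = 1 -- set the column of the i-th sample's class
   end
   return encoded
end
```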
71 | # TODO 72 | 73 | 1. Try discriminative dropout training with sparsity? 74 | 2. Use momentum to smooth gradients? + decrease learning rate 75 | 3. Generative training example + samples drawn from the model 76 | 4. Hybrid training example 77 | 5. Semi-supervised example 78 | 6. Implement stacking of RBMs 79 | 80 | # References 81 | 82 | [1] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. R. Salakhutdinov, “Dropout: A simple way to prevent neural networks from overfitting,” J. Mach. Learn. Res., vol. 15, no. 1, pp. 1929–1958, 2014. 83 | [2] H. Larochelle and Y. Bengio, “Classification using discriminative restricted Boltzmann machines,” in Proceedings of the 25th International Conference on Machine Learning. ACM, 2008. 84 | [3] G. Hinton, “A practical guide to training restricted Boltzmann machines,” Momentum, vol. 9, no. 1, p. 926, 2010. 85 | [4] G. Hinton, N. Srivastava, A. Krizhevsky, I. Sutskever, and R. R. Salakhutdinov, “Improving neural networks by preventing co-adaptation of feature detectors,” arXiv preprint arXiv:1207.0580, Jul. 2012. 86 | [5] G. Hinton, “Training products of experts by minimizing contrastive divergence,” Neural Comput., vol. 14, no. 8, pp. 1771–1800, 2002. 87 | [6] T. Tieleman, “Training restricted Boltzmann machines using approximations to the likelihood gradient,” in Proceedings of the 25th International Conference on Machine Learning. ACM, 2008. 88 | [7] H. Larochelle, M. Mandel, R. Pascanu, and Y. Bengio, “Learning algorithms for the classification restricted Boltzmann machine,” J. Mach. Learn. Res., vol. 13, no. 1, pp. 643–669, 2012. 89 | [8] R. Salakhutdinov and I. Murray, “On the quantitative analysis of deep belief networks,” in Proceedings of the 25th International Conference on Machine Learning. ACM, 2008. 90 | [9] Y. Tang and I. Sutskever, “Data normalization in the learning of restricted Boltzmann machines,” Dept. Comput. Sci., Univ. of Toronto, Tech. Rep. UTML-TR-11, 2011. 91 | [10] L. Wan, M. Zeiler, S. Zhang, Y. LeCun, and R. Fergus, “Regularization of neural networks using DropConnect,” in Proceedings of the 30th International Conference on Machine Learning, 2013, pp. 1058–1066. 92 | 93 | Copyright (c) 2014, Søren Kaae Sønderby (skaaesonderby@gmail.com). All rights reserved. 94 | -------------------------------------------------------------------------------- /code/ConfusionMatrix.lua: -------------------------------------------------------------------------------- 1 | --[[ A Confusion Matrix class 2 | 3 | Example: 4 | 5 | conf = optim.ConfusionMatrix( {'cat','dog','person'} ) -- new matrix 6 | conf:zero() -- reset matrix 7 | for i = 1,N do 8 | conf:add( neuralnet:forward(sample), label ) -- accumulate errors 9 | end 10 | print(conf) -- print matrix 11 | image.display(conf:render()) -- render matrix 12 | ]] 13 | --local ConfusionMatrix = torch.class('optim.MyConfusionMatrix') 14 | 15 | local ConfusionMatrix = torch.class('ConfusionMatrix') 16 | 17 | function ConfusionMatrix:__init(nclasses, classes) 18 | if type(nclasses) == 'table' then 19 | classes = nclasses 20 | nclasses = #classes 21 | end 22 | self.mat = torch.FloatTensor(nclasses,nclasses):zero() 23 | self.valids = torch.FloatTensor(nclasses):zero() 24 | self.unionvalids = torch.FloatTensor(nclasses):zero() 25 | self.nclasses = nclasses 26 | self.totalValid = 0 27 | self.averageValid = 0 28 | self.classes = classes or {} 29 | end 30 | 31 | function ConfusionMatrix:add(prediction, target) 32 | if type(prediction) == 'number' then 33 | -- comparing numbers 34 | self.mat[target][prediction] = self.mat[target][prediction] + 1 35 | elseif type(target) == 'number' then 36 | -- prediction is a vector; target is assumed to be a class index 37 | self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses) 38 | self.prediction_1d:copy(prediction) 39 | local _,prediction = self.prediction_1d:max(1) 40 | self.mat[target][prediction[1]] = self.mat[target][prediction[1]] + 1 41 | else 42 | -- both prediction and target are vectors 43 | self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses) 44 | self.prediction_1d:copy(prediction) 45 | self.target_1d = self.target_1d or torch.FloatTensor(self.nclasses) 46 | self.target_1d:copy(target) 47 | local _,prediction = self.prediction_1d:max(1) 48 | local _,target = self.target_1d:max(1) 49 | self.mat[target[1]][prediction[1]] = self.mat[target[1]][prediction[1]] + 1 50 | end 51 | end 52 |
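-- Usage sketch (annotation, not part of the original file): add() accumulates a
-- single example while batchAdd() below takes a whole minibatch; both accept
-- class indices or score vectors / one-hot rows, e.g.
--   local conf = ConfusionMatrix({'cat','dog','person'})
--   conf:add(torch.FloatTensor({0.1,0.7,0.2}), 2)                -- score vector + target index
--   conf:batchAdd(torch.Tensor({1,2,3}), torch.Tensor({1,3,3}))  -- predicted and true indices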
53 | function ConfusionMatrix:batchAdd(predictions, targets) 54 | local preds, targs, __ 55 | if predictions:dim() == 1 then 56 | -- predictions is a vector of classes 57 | preds = predictions 58 | elseif predictions:dim() == 2 then 59 | -- predictions is a matrix of class likelihoods 60 | if predictions:size(2) == 1 then 61 | -- or predictions just need flattening 62 | preds = predictions:select(2,1) 63 | else 64 | __,preds = predictions:max(2) 65 | preds:resize(preds:size(1)) 66 | end 67 | else 68 | error("predictions has invalid number of dimensions") 69 | end 70 | 71 | if targets:dim() == 1 then 72 | -- targets is a vector of classes 73 | targs = targets 74 | elseif targets:dim() == 2 then 75 | -- targets is a matrix of one-hot rows 76 | if targets:size(2) == 1 then 77 | -- or targets just need flattening 78 | targs = targets:select(2,1) 79 | else 80 | __,targs = targets:max(2) 81 | targs:resize(targs:size(1)) 82 | end 83 | else 84 | error("targets has invalid number of dimensions") 85 | end 86 | -- loop over each pair of indices 87 | for i = 1,preds:size(1) do 88 | self.mat[targs[i]][preds[i]] = self.mat[targs[i]][preds[i]] + 1 89 | end 90 | end 91 | 92 | function ConfusionMatrix:zero() 93 | self.mat:zero() 94 | self.valids:zero() 95 | self.unionvalids:zero() 96 | self.totalValid = 0 97 | self.averageValid = 0 98 | end 99 | 100 | local function isNaN(number) 101 | return number ~= number 102 | end 103 | 104 | local function remNaN(x,self) 105 | for i = 1, self.nclasses do 106 | if isNaN(x[{1,i}]) then 107 | x[{1,i}] = 0 108 | end 109 | end 110 | return x 111 | end 112 | 113 | 114 | local function getErrors(self) 115 | local tp, fn, fp, tn 116 | tp = torch.diag(self.mat):resize(1,self.nclasses ) 117 | fn = (torch.sum(self.mat,2)-torch.diag(self.mat)):t() 118 | fp = torch.sum(self.mat,1)-torch.diag(self.mat) 119 | tn = torch.Tensor(1,self.nclasses):fill(torch.sum(self.mat)):typeAs(tp) - tp - fn - fp 120 | 121 | return tp, tn, fp, fn 122 | end 123 | 124 | 125 | function ConfusionMatrix:getConfusion() 126 | return getErrors(self) 127 | end 128 | 129 | function ConfusionMatrix:matthewsCorrelation() 130 | local mcc,numerator,denominator 131 | local tp, tn, fp, fn = getErrors(self) 132 | numerator = torch.cmul(tp,tn) - torch.cmul(fp,fn) 133 | denominator = torch.sqrt((tp+fp):cmul(tp+fn):cmul(tn+fp):cmul(tn+fn)) 134 | mcc = torch.cdiv(numerator,denominator) -- MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN)) 135 | mcc = remNaN(mcc,self) 136 | return mcc 137 | end 138 | 139 | function ConfusionMatrix:sensitivity() 140 | local tp, tn, fp, fn = getErrors(self) 141 | local res = torch.cdiv(tp, tp + fn ) -- TP / (TP + FN) 142 | res = remNaN(res,self) 143 | return res 144 | end 145 | 146 | function ConfusionMatrix:specificity() 147 | local tp, tn, fp, fn = getErrors(self) 148 | local res = torch.cdiv(tn, tn + fp) -- TN / (TN + FP) 149 | res = remNaN(res,self) 150 | return res 151 | 152 | end 153 | 154 | function ConfusionMatrix:positivePredictiveValue() 155 | local tp, tn, fp, fn = getErrors(self) 156 | local res = torch.cdiv(tp, tp + fp ) -- TP / (TP + FP) 157 | res = remNaN(res,self) 158 | return res 159 | end 160 | 161 | function ConfusionMatrix:negativePredictiveValue() 162 | local tp, tn, fp, fn = getErrors(self) 163 | local res = torch.cdiv(tn, tn + fn ) -- TN / (TN + FN) 164 | res = remNaN(res,self) 165 | return res 166 | end 167 | 168 | function ConfusionMatrix:falsePositiveRate() 169 | local tp, tn, fp, fn = getErrors(self) 170 | local res = torch.cdiv(fp, fp + tn) -- FP / (FP + TN) 171 | res = remNaN(res,self) 172 | return res 173 | end 174 | 175 | function ConfusionMatrix:falseDiscoveryRate() 176 | local tp, tn, fp, fn = getErrors(self) 177 | local res = torch.cdiv(fp, tp + fp) -- FP / (TP + FP) 178 | res = remNaN(res,self) 179 | return res 180 | end 181 |
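-- Note (annotation added for clarity): the metrics above are one-vs-rest.
-- getErrors() reduces the nclasses x nclasses matrix to per-class TP/TN/FP/FN
-- counts, so each getter returns a 1 x nclasses tensor; NaNs arising from
-- empty classes are zeroed by remNaN().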
182 | function ConfusionMatrix:classAccuracy() 183 | local tp, tn, fp, fn = getErrors(self) 184 | local res = torch.cdiv(tp + tn, tp + tn + fp + fn) -- (TP + TN) / (TP + TN + FP + FN) 185 | res = remNaN(res,self) 186 | return res 187 | end 188 | 189 | function ConfusionMatrix:F1() 190 | local tp, tn, fp, fn = getErrors(self) 191 | local res = torch.cdiv(tp * 2, tp * 2 + fp + fn) -- 2*TP / (2*TP + FP + FN) 192 | res = remNaN(res,self) 193 | return res 194 | end 195 | 196 | function ConfusionMatrix:updateValids() 197 | local total = 0 198 | for t = 1,self.nclasses do 199 | self.valids[t] = self.mat[t][t] / self.mat:select(1,t):sum() 200 | self.unionvalids[t] = self.mat[t][t] / (self.mat:select(1,t):sum()+self.mat:select(2,t):sum()-self.mat[t][t]) 201 | total = total + self.mat[t][t] 202 | end 203 | self.totalValid = total / self.mat:sum() 204 | self.averageValid = 0 205 | self.averageUnionValid = 0 206 | local nvalids = 0 207 | local nunionvalids = 0 208 | for t = 1,self.nclasses do 209 | if not isNaN(self.valids[t]) then 210 | self.averageValid = self.averageValid + self.valids[t] 211 | nvalids = nvalids + 1 212 | end 213 | if not isNaN(self.valids[t]) and not isNaN(self.unionvalids[t]) then 214 | self.averageUnionValid = self.averageUnionValid + self.unionvalids[t] 215 | nunionvalids = nunionvalids + 1 216 | end 217 | end 218 | self.averageValid = self.averageValid / nvalids 219 | self.averageUnionValid = self.averageUnionValid / nunionvalids 220 | end 221 | 222 | function ConfusionMatrix:__tostring__() 223 | self:updateValids() 224 | local str = {'ConfusionMatrix:\n'} 225 | local nclasses = self.nclasses 226 | table.insert(str, '[') 227 | for t = 1,nclasses do 228 | local pclass = self.valids[t] * 100 229 | pclass = string.format('%2.3f', pclass) 230 | if t == 1 then 231 | table.insert(str, '[') 232 | else 233 | table.insert(str, ' [') 234 | end 235 | for p = 1,nclasses do 236 | table.insert(str, string.format('%8d', self.mat[t][p])) 237 | end 238 | if self.classes and self.classes[1] then 239 | if t == nclasses then 240 | table.insert(str, ']] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n') 241 | else 242 | table.insert(str, '] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n') 243 | end 244 | else 245 | if t == nclasses then 246 | table.insert(str, ']] ' .. pclass .. '% \n') 247 | else 248 | table.insert(str, '] ' .. pclass .. '% \n') 249 | end 250 | end 251 | end 252 | table.insert(str, ' + average row correct: ' .. (self.averageValid*100) .. '% \n') 253 | table.insert(str, ' + average rowUcol correct (VOC measure): ' .. (self.averageUnionValid*100) .. '% \n') 254 | table.insert(str, ' + global correct: ' .. (self.totalValid*100) ..
'%') 255 | return table.concat(str) 256 | end 257 | 258 | function ConfusionMatrix:render(sortmode, display, block, legendwidth) 259 | -- args 260 | local confusion = self.mat 261 | local classes = self.classes 262 | local sortmode = sortmode or 'score' -- 'score' or 'occurrence' 263 | local block = block or 25 264 | local legendwidth = legendwidth or 200 265 | local display = display or false 266 | 267 | -- legends 268 | local legend = { 269 | ['score'] = 'Confusion matrix [sorted by scores, global accuracy = %0.3f%%, per-class accuracy = %0.3f%%]', 270 | ['occurrence'] = 'Confusiong matrix [sorted by occurences, accuracy = %0.3f%%, per-class accuracy = %0.3f%%]' 271 | } 272 | 273 | -- parse matrix / normalize / count scores 274 | local diag = torch.FloatTensor(#classes) 275 | local freqs = torch.FloatTensor(#classes) 276 | local unconf = confusion 277 | local confusion = confusion:clone() 278 | local corrects = 0 279 | local total = 0 280 | for target = 1,#classes do 281 | freqs[target] = confusion[target]:sum() 282 | corrects = corrects + confusion[target][target] 283 | total = total + freqs[target] 284 | confusion[target]:div( math.max(confusion[target]:sum(),1) ) 285 | diag[target] = confusion[target][target] 286 | end 287 | 288 | -- accuracies 289 | local accuracy = corrects / total * 100 290 | local perclass = 0 291 | local total = 0 292 | for target = 1,#classes do 293 | if confusion[target]:sum() > 0 then 294 | perclass = perclass + diag[target] 295 | total = total + 1 296 | end 297 | end 298 | perclass = perclass / total * 100 299 | freqs:div(unconf:sum()) 300 | 301 | -- sort matrix 302 | if sortmode == 'score' then 303 | _,order = torch.sort(diag,1,true) 304 | elseif sortmode == 'occurrence' then 305 | _,order = torch.sort(freqs,1,true) 306 | else 307 | error('sort mode must be one of: score | occurrence') 308 | end 309 | 310 | -- render matrix 311 | local render = torch.zeros(#classes*block, #classes*block) 312 | for target = 1,#classes do 313 | for prediction = 1,#classes do 314 | render[{ { (target-1)*block+1,target*block }, { (prediction-1)*block+1,prediction*block } }] = confusion[order[target]][order[prediction]] 315 | end 316 | end 317 | 318 | -- add grid 319 | for target = 1,#classes do 320 | render[{ {target*block},{} }] = 0.1 321 | render[{ {},{target*block} }] = 0.1 322 | end 323 | 324 | -- create rendering 325 | require 'image' 326 | require 'qtwidget' 327 | require 'qttorch' 328 | local win1 = qtwidget.newimage( (#render)[2]+legendwidth, (#render)[1] ) 329 | image.display{image=render, win=win1} 330 | 331 | -- add legend 332 | for i in ipairs(classes) do 333 | -- background cell 334 | win1:setcolor{r=0,g=0,b=0} 335 | win1:rectangle((#render)[2],(i-1)*block,legendwidth,block) 336 | win1:fill() 337 | 338 | -- % 339 | win1:setfont(qt.QFont{serif=false, size=fontsize}) 340 | local gscale = freqs[order[i]]/freqs:max()*0.9+0.1 --3/4 341 | win1:setcolor{r=gscale*0.5+0.2,g=gscale*0.5+0.2,b=gscale*0.8+0.2} 342 | win1:moveto((#render)[2]+10,i*block-block/3) 343 | win1:show(string.format('[%2.2f%% labels]',math.floor(freqs[order[i]]*10000+0.5)/100)) 344 | 345 | -- legend 346 | win1:setfont(qt.QFont{serif=false, size=fontsize}) 347 | local gscale = diag[order[i]]*0.8+0.2 348 | win1:setcolor{r=gscale,g=gscale,b=gscale} 349 | win1:moveto(120+(#render)[2]+10,i*block-block/3) 350 | win1:show(classes[order[i]]) 351 | 352 | for j in ipairs(classes) do 353 | -- scores 354 | local score = confusion[order[j]][order[i]] 355 | local gscale = (1-score)*(score*0.8+0.2) 356 | 
win1:setcolor{r=gscale,g=gscale,b=gscale} 357 | win1:moveto((i-1)*block+block/5,(j-1)*block+block*2/3) 358 | win1:show(string.format('%02.0f',math.floor(score*100+0.5))) 359 | end 360 | end 361 | 362 | -- generate tensor 363 | local t = win1:image():toTensor() 364 | 365 | -- display 366 | if display then 367 | image.display{image=t, legend=string.format(legend[sortmode],accuracy,perclass)} 368 | end 369 | 370 | -- return rendering 371 | return t 372 | end 373 | -------------------------------------------------------------------------------- /code/MyConfusionMatrix.lua: -------------------------------------------------------------------------------- 1 | --[[ A Confusion Matrix class 2 | 3 | Example: 4 | 5 | conf = optim.ConfusionMatrix( {'cat','dog','person'} ) -- new matrix 6 | conf:zero() -- reset matrix 7 | for i = 1,N do 8 | conf:add( neuralnet:forward(sample), label ) -- accumulate errors 9 | end 10 | print(conf) -- print matrix 11 | image.display(conf:render()) -- render matrix 12 | ]] 13 | --local ConfusionMatrix = torch.class('optim.MyConfusionMatrix') 14 | 15 | local ConfusionMatrix = torch.class('ConfusionMatrix') 16 | 17 | function ConfusionMatrix:__init(nclasses, classes) 18 | if type(nclasses) == 'table' then 19 | classes = nclasses 20 | nclasses = #classes 21 | end 22 | self.mat = torch.FloatTensor(nclasses,nclasses):zero() 23 | self.valids = torch.FloatTensor(nclasses):zero() 24 | self.unionvalids = torch.FloatTensor(nclasses):zero() 25 | self.nclasses = nclasses 26 | self.totalValid = 0 27 | self.averageValid = 0 28 | self.classes = classes or {} 29 | end 30 | 31 | function ConfusionMatrix:add(prediction, target) 32 | if type(prediction) == 'number' then 33 | -- comparing numbers 34 | self.mat[target][prediction] = self.mat[target][prediction] + 1 35 | elseif type(target) == 'number' then 36 | -- prediction is a vector, then target assumed to be an index 37 | self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses) 38 | self.prediction_1d:copy(prediction) 39 | local _,prediction = self.prediction_1d:max(1) 40 | self.mat[target][prediction[1]] = self.mat[target][prediction[1]] + 1 41 | else 42 | -- both prediction and target are vectors 43 | self.prediction_1d = self.prediction_1d or torch.FloatTensor(self.nclasses) 44 | self.prediction_1d:copy(prediction) 45 | self.target_1d = self.target_1d or torch.FloatTensor(self.nclasses) 46 | self.target_1d:copy(target) 47 | local _,prediction = self.prediction_1d:max(1) 48 | local _,target = self.target_1d:max(1) 49 | self.mat[target[1]][prediction[1]] = self.mat[target[1]][prediction[1]] + 1 50 | end 51 | end 52 | 53 | function ConfusionMatrix:batchAdd(predictions, targets) 54 | local preds, targs, __ 55 | if predictions:dim() == 1 then 56 | -- predictions is a vector of classes 57 | preds = predictions 58 | elseif predictions:dim() == 2 then 59 | -- prediction is a matrix of class likelihoods 60 | if predictions:size(2) == 1 then 61 | -- or prediction just needs flattening 62 | preds = predictions:select(2,1) 63 | else 64 | __,preds = predictions:max(2) 65 | preds:resize(preds:size(1)) 66 | end 67 | else 68 | error("predictions has invalid number of dimensions") 69 | end 70 | 71 | if targets:dim() == 1 then 72 | -- targets is a vector of classes 73 | targs = targets 74 | elseif targets:dim() == 2 then 75 | -- targets is a matrix of one-hot rows 76 | if targets:size(2) == 1 then 77 | -- or targets just needs flattening 78 | targs = targets:select(2,1) 79 | else 80 | __,targs = targets:max(2) 81 | 
targs:resize(targs:size(1)) 82 | end 83 | else 84 | error("targets has invalid number of dimensions") 85 | end 86 | --loop over each pair of indices 87 | for i = 1,preds:size(1) do 88 | self.mat[targs[i]][preds[i]] = self.mat[targs[i]][preds[i]] + 1 89 | end 90 | end 91 | 92 | function ConfusionMatrix:zero() 93 | self.mat:zero() 94 | self.valids:zero() 95 | self.unionvalids:zero() 96 | self.totalValid = 0 97 | self.averageValid = 0 98 | end 99 | 100 | local function isNaN(number) 101 | return number ~= number 102 | end 103 | 104 | local function remNaN(x,self) 105 | for i = 1, self.nclasses do 106 | if isNaN(x[{1,i}]) then 107 | x[{1,i}] = 0 108 | end 109 | end 110 | return x 111 | end 112 | 113 | 114 | local function getErrors(self) 115 | local tp, fn, fp, tn 116 | tp = torch.diag(self.mat):resize(1,self.nclasses ) 117 | fn = (torch.sum(self.mat,2)-torch.diag(self.mat)):t() 118 | fp = torch.sum(self.mat,1)-torch.diag(self.mat) 119 | tn = torch.Tensor(1,self.nclasses):fill(torch.sum(self.mat)):typeAs(tp) - tp - fn - fp 120 | 121 | return tp, tn, fp, fn 122 | end 123 | 124 | 125 | function ConfusionMatrix:getConfusion() 126 | return getErrors(self) 127 | end 128 | 129 | function ConfusionMatrix:printscore(type,mytitle) 130 | local title,score,class_app,class,val 131 | 132 | if type == "mcc" then 133 | score = self:matthewsCorrelation() 134 | title = "Matthew Correlation" 135 | -- elseif type == "accuracy" then 136 | -- local score = self:accuracy() 137 | -- local title = "Matthew Correlation" 138 | elseif type == 'acc' then 139 | score = self:classAccuracy() 140 | title = "Class accuracies" 141 | else 142 | print("print funct not implemented") 143 | error() 144 | end 145 | 146 | if mytitle then 147 | title = mytitle..": "..title 148 | end 149 | 150 | 151 | local ln = "|" 152 | local ls = "|" 153 | for i = 1,self.nclasses do 154 | 155 | val = string.format("%.4f", score[{1,i}]) 156 | class = self.classes[i] 157 | class_app = math.max(1,4-math.floor(#class / 2)) 158 | class = string.rep(" ",class_app)..class..string.rep(" ",class_app+1-#class%2) 159 | 160 | ln = ln..class.."|" 161 | ls =ls.." "..val 162 | if (#ls+1) < #ln then 163 | ls = ls .. 
string.rep(" ",#ln-#ls-1) 164 | end 165 | ls = ls .."|" 166 | 167 | end 168 | local line = string.rep("-",#ln) 169 | ln = ln.."\n"..line.."\n"..ls 170 | print(line) 171 | print(string.rep(" ",math.min(0,math.floor(#ls/2)-math.floor(#title/2) ))..title) 172 | print(line) 173 | print(ln) 174 | print(line) 175 | end 176 | 177 | function ConfusionMatrix:accuracy() 178 | -- parse matrix / normalize / count scores 179 | local diag = torch.FloatTensor(self.nclasses) 180 | local freqs = torch.FloatTensor(self.nclasses) 181 | --local unconf = confusion 182 | local confusion = self.mat:clone() 183 | local corrects = 0 184 | local total = 0 185 | for target = 1,self.nclasses do 186 | freqs[target] = confusion[target]:sum() 187 | corrects = corrects + confusion[target][target] 188 | total = total + freqs[target] 189 | confusion[target]:div( math.max(confusion[target]:sum(),1) ) 190 | diag[target] = confusion[target][target] 191 | end 192 | 193 | -- accuracies 194 | local accuracy = corrects / total 195 | return accuracy 196 | end 197 | 198 | function ConfusionMatrix:matthewsCorrelation() 199 | local mcc,numerator, denominator 200 | tp, tn, fp, fn = getErrors(self) 201 | numerator = torch.cmul(tp,tn) - torch.cmul(fp,fn) 202 | denominator = torch.sqrt((tp+fp):cmul(tp+fn):cmul(tn+fp):cmul(tn+fn)) 203 | mcc = torch.cdiv(numerator,denominator) 204 | mcc = remNaN(mcc,self) 205 | return mcc 206 | end 207 | 208 | function ConfusionMatrix:sensitivity() 209 | tp, tn, fp, fn = getErrors(self) 210 | res = torch.cdiv(tp, tp + fn ) 211 | res = remNaN(res,self) 212 | return res -- TP / (TP + FN) 213 | end 214 | 215 | function ConfusionMatrix:specificity() 216 | tp, tn, fp, fn = getErrors(self) 217 | res = torch.cdiv(tn, tn + fp) -- TN / (TN + FP) 218 | res = remNaN(res,self) 219 | return res -- TP / (TP + FN) 220 | 221 | end 222 | 223 | function ConfusionMatrix:positivePredictiveValue() 224 | tp, tn, fp, fn = getErrors(self) 225 | res = torch.cdiv(tp, tp + fp ) -- TP / (TP + FP) 226 | res = remNaN(res,self) 227 | return res -- TP / (TP + FN) 228 | end 229 | 230 | function ConfusionMatrix:negativePredictiveValue() 231 | tp, tn, fp, fn = getErrors(self) 232 | res = torch.cdiv(tn, tn + fn ) -- TN / (TN + FN) 233 | res = remNaN(res,self) 234 | return res -- TP / (TP + FN) 235 | end 236 | 237 | function ConfusionMatrix:falsePositiveRate() 238 | tp, tn, fp, fn = getErrors(self) 239 | res = torch.cdiv(fp, fp + tn) -- FP / (FP + TN) 240 | res = remNaN(res,self) 241 | return res -- TP / (TP + FN) 242 | end 243 | 244 | function ConfusionMatrix:falseDiscoveryRate() 245 | tp, tn, fp, fn = getErrors(self) 246 | res = torch.cdiv(fp, tp + fp) -- FP / (TP + FP) 247 | res = remNaN(res,self) 248 | return res -- TP / (TP + FN) 249 | end 250 | 251 | function ConfusionMatrix:classAccuracy() 252 | tp, tn, fp, fn = getErrors(self) 253 | res = torch.cdiv(tp + tn, tp + tn + fp + fn) -- (TP + FN) / (TN + TP + FN + FP) 254 | res = remNaN(res,self) 255 | return res -- TP / (TP + FN) 256 | end 257 | 258 | function ConfusionMatrix:F1() 259 | tp, tn, fp, fn = getErrors(self) 260 | res = torch.cdiv(tp * 2, tp * 2 + fp + fn) -- (2*TP)/(TP*2+fp+fn) 261 | res = remNaN(res,self) 262 | return res -- TP / (TP + FN) 263 | end 264 | 265 | function ConfusionMatrix:updateValids() 266 | local total = 0 267 | for t = 1,self.nclasses do 268 | self.valids[t] = self.mat[t][t] / self.mat:select(1,t):sum() 269 | self.unionvalids[t] = self.mat[t][t] / (self.mat:select(1,t):sum()+self.mat:select(2,t):sum()-self.mat[t][t]) 270 | total = total + self.mat[t][t] 271 | 
end 272 | self.totalValid = total / self.mat:sum() 273 | self.averageValid = 0 274 | self.averageUnionValid = 0 275 | local nvalids = 0 276 | local nunionvalids = 0 277 | for t = 1,self.nclasses do 278 | if not isNaN(self.valids[t]) then 279 | self.averageValid = self.averageValid + self.valids[t] 280 | nvalids = nvalids + 1 281 | end 282 | if not isNaN(self.valids[t]) and not isNaN(self.unionvalids[t]) then 283 | self.averageUnionValid = self.averageUnionValid + self.unionvalids[t] 284 | nunionvalids = nunionvalids + 1 285 | end 286 | end 287 | self.averageValid = self.averageValid / nvalids 288 | self.averageUnionValid = self.averageUnionValid / nunionvalids 289 | end 290 | 291 | function ConfusionMatrix:__tostring__() 292 | self:updateValids() 293 | local str = {'ConfusionMatrix:\n'} 294 | local nclasses = self.nclasses 295 | table.insert(str, '[') 296 | for t = 1,nclasses do 297 | local pclass = self.valids[t] * 100 298 | pclass = string.format('%2.3f', pclass) 299 | if t == 1 then 300 | table.insert(str, '[') 301 | else 302 | table.insert(str, ' [') 303 | end 304 | for p = 1,nclasses do 305 | table.insert(str, string.format('%8d', self.mat[t][p])) 306 | end 307 | if self.classes and self.classes[1] then 308 | if t == nclasses then 309 | table.insert(str, ']] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n') 310 | else 311 | table.insert(str, '] ' .. pclass .. '% \t[class: ' .. (self.classes[t] or '') .. ']\n') 312 | end 313 | else 314 | if t == nclasses then 315 | table.insert(str, ']] ' .. pclass .. '% \n') 316 | else 317 | table.insert(str, '] ' .. pclass .. '% \n') 318 | end 319 | end 320 | end 321 | table.insert(str, ' + average row correct: ' .. (self.averageValid*100) .. '% \n') 322 | table.insert(str, ' + average rowUcol correct (VOC measure): ' .. (self.averageUnionValid*100) .. '% \n') 323 | table.insert(str, ' + global correct: ' .. (self.totalValid*100) .. 
'%') 324 | return table.concat(str) 325 | end 326 | 327 | function ConfusionMatrix:render(sortmode, display, block, legendwidth) 328 | -- args 329 | local confusion = self.mat 330 | local classes = self.classes 331 | local sortmode = sortmode or 'score' -- 'score' or 'occurrence' 332 | local block = block or 25 333 | local legendwidth = legendwidth or 200 334 | local display = display or false 335 | 336 | -- legends 337 | local legend = { 338 | ['score'] = 'Confusion matrix [sorted by scores, global accuracy = %0.3f%%, per-class accuracy = %0.3f%%]', 339 | ['occurrence'] = 'Confusiong matrix [sorted by occurences, accuracy = %0.3f%%, per-class accuracy = %0.3f%%]' 340 | } 341 | 342 | -- parse matrix / normalize / count scores 343 | local diag = torch.FloatTensor(#classes) 344 | local freqs = torch.FloatTensor(#classes) 345 | local unconf = confusion 346 | local confusion = confusion:clone() 347 | local corrects = 0 348 | local total = 0 349 | for target = 1,#classes do 350 | freqs[target] = confusion[target]:sum() 351 | corrects = corrects + confusion[target][target] 352 | total = total + freqs[target] 353 | confusion[target]:div( math.max(confusion[target]:sum(),1) ) 354 | diag[target] = confusion[target][target] 355 | end 356 | 357 | -- accuracies 358 | local accuracy = corrects / total * 100 359 | local perclass = 0 360 | local total = 0 361 | for target = 1,#classes do 362 | if confusion[target]:sum() > 0 then 363 | perclass = perclass + diag[target] 364 | total = total + 1 365 | end 366 | end 367 | perclass = perclass / total * 100 368 | freqs:div(unconf:sum()) 369 | 370 | -- sort matrix 371 | if sortmode == 'score' then 372 | _,order = torch.sort(diag,1,true) 373 | elseif sortmode == 'occurrence' then 374 | _,order = torch.sort(freqs,1,true) 375 | else 376 | error('sort mode must be one of: score | occurrence') 377 | end 378 | 379 | -- render matrix 380 | local render = torch.zeros(#classes*block, #classes*block) 381 | for target = 1,#classes do 382 | for prediction = 1,#classes do 383 | render[{ { (target-1)*block+1,target*block }, { (prediction-1)*block+1,prediction*block } }] = confusion[order[target]][order[prediction]] 384 | end 385 | end 386 | 387 | -- add grid 388 | for target = 1,#classes do 389 | render[{ {target*block},{} }] = 0.1 390 | render[{ {},{target*block} }] = 0.1 391 | end 392 | 393 | -- create rendering 394 | require 'image' 395 | require 'qtwidget' 396 | require 'qttorch' 397 | local win1 = qtwidget.newimage( (#render)[2]+legendwidth, (#render)[1] ) 398 | image.display{image=render, win=win1} 399 | 400 | -- add legend 401 | for i in ipairs(classes) do 402 | -- background cell 403 | win1:setcolor{r=0,g=0,b=0} 404 | win1:rectangle((#render)[2],(i-1)*block,legendwidth,block) 405 | win1:fill() 406 | 407 | -- % 408 | win1:setfont(qt.QFont{serif=false, size=fontsize}) 409 | local gscale = freqs[order[i]]/freqs:max()*0.9+0.1 --3/4 410 | win1:setcolor{r=gscale*0.5+0.2,g=gscale*0.5+0.2,b=gscale*0.8+0.2} 411 | win1:moveto((#render)[2]+10,i*block-block/3) 412 | win1:show(string.format('[%2.2f%% labels]',math.floor(freqs[order[i]]*10000+0.5)/100)) 413 | 414 | -- legend 415 | win1:setfont(qt.QFont{serif=false, size=fontsize}) 416 | local gscale = diag[order[i]]*0.8+0.2 417 | win1:setcolor{r=gscale,g=gscale,b=gscale} 418 | win1:moveto(120+(#render)[2]+10,i*block-block/3) 419 | win1:show(classes[order[i]]) 420 | 421 | for j in ipairs(classes) do 422 | -- scores 423 | local score = confusion[order[j]][order[i]] 424 | local gscale = (1-score)*(score*0.8+0.2) 425 | 
win1:setcolor{r=gscale,g=gscale,b=gscale} 426 | win1:moveto((i-1)*block+block/5,(j-1)*block+block*2/3) 427 | win1:show(string.format('%02.0f',math.floor(score*100+0.5))) 428 | end 429 | end 430 | 431 | -- generate tensor 432 | local t = win1:image():toTensor() 433 | 434 | -- display 435 | if display then 436 | image.display{image=t, legend=string.format(legend[sortmode],accuracy,perclass)} 437 | end 438 | 439 | -- return rendering 440 | return t 441 | end 442 | -------------------------------------------------------------------------------- /code/ProFi.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | ProFi v1.3, by Luke Perkin 2012. MIT Licence http://www.opensource.org/licenses/mit-license.php. 3 | 4 | Example: 5 | ProFi = require 'ProFi' 6 | ProFi:start() 7 | some_function() 8 | another_function() 9 | coroutine.resume( some_coroutine ) 10 | ProFi:stop() 11 | ProFi:writeReport( 'MyProfilingReport.txt' ) 12 | 13 | API: 14 | *Arguments are specified as: type/name/default. 15 | ProFi:start( string/once/nil ) 16 | ProFi:stop() 17 | ProFi:checkMemory( number/interval/0, string/note/'' ) 18 | ProFi:writeReport( string/filename/'ProFi.txt' ) 19 | ProFi:reset() 20 | ProFi:setHookCount( number/hookCount/0 ) 21 | ProFi:setGetTimeMethod( function/getTimeMethod/os.clock ) 22 | ProFi:setInspect( string/methodName, number/levels/1 ) 23 | ]] 24 | 25 | ----------------------- 26 | -- Locals: 27 | ----------------------- 28 | 29 | local ProFi = {} 30 | local onDebugHook, sortByDurationDesc, sortByCallCount, getTime 31 | local DEFAULT_DEBUG_HOOK_COUNT = 0 32 | local FORMAT_HEADER_LINE = "| %-50s: %-40s: %-20s: %-12s: %-12s: %-12s|\n" 33 | local FORMAT_OUTPUT_LINE = "| %s: %-12s: %-12s: %-12s|\n" 34 | local FORMAT_INSPECTION_LINE = "> %s: %-12s\n" 35 | local FORMAT_TOTALTIME_LINE = "| TOTAL TIME = %f\n" 36 | local FORMAT_MEMORY_LINE = "| %-20s: %-16s: %-16s| %s\n" 37 | local FORMAT_HIGH_MEMORY_LINE = "H %-20s: %-16s: %-16sH %s\n" 38 | local FORMAT_LOW_MEMORY_LINE = "L %-20s: %-16s: %-16sL %s\n" 39 | local FORMAT_TITLE = "%-50.50s: %-40.40s: %-20s" 40 | local FORMAT_LINENUM = "%4i" 41 | local FORMAT_TIME = "%04.3f" 42 | local FORMAT_RELATIVE = "%03.2f%%" 43 | local FORMAT_COUNT = "%7i" 44 | local FORMAT_KBYTES = "%7i Kbytes" 45 | local FORMAT_MBYTES = "%7.1f Mbytes" 46 | local FORMAT_MEMORY_HEADER1 = "\n=== HIGH & LOW MEMORY USAGE ===============================\n" 47 | local FORMAT_MEMORY_HEADER2 = "=== MEMORY USAGE ==========================================\n" 48 | local FORMAT_BANNER = [[ 49 | ############################################################################################################### 50 | ##### ProFi, a lua profiler. This profile was generated on: %s 51 | ##### ProFi is created by Luke Perkin 2012 under the MIT Licence, www.locofilm.co.uk 52 | ##### Version 1.3. Get the most recent version at this gist: https://gist.github.com/2838755 53 | ############################################################################################################### 54 | 55 | ]] 56 | 57 | ----------------------- 58 | -- Public Methods: 59 | ----------------------- 60 | 61 | --[[ 62 | Starts profiling any method that is called between this and ProFi:stop(). 63 | Pass the parameter 'once' to so that this methodis only run once. 
64 | Example: 65 | ProFi:start( 'once' ) 66 | ]] 67 | function ProFi:start( param ) 68 | if param == 'once' then 69 | if self:shouldReturn() then 70 | return 71 | else 72 | self.should_run_once = true 73 | end 74 | end 75 | self.has_started = true 76 | self.has_finished = false 77 | self:resetReports( self.reports ) 78 | self:startHooks() 79 | self.startTime = getTime() 80 | end 81 | 82 | --[[ 83 | Stops profiling. 84 | ]] 85 | function ProFi:stop() 86 | if self:shouldReturn() then 87 | return 88 | end 89 | self.stopTime = getTime() 90 | self:stopHooks() 91 | self.has_finished = true 92 | end 93 | 94 | function ProFi:checkMemory( interval, note ) 95 | local time = getTime() 96 | local interval = interval or 0 97 | if self.lastCheckMemoryTime and time < self.lastCheckMemoryTime + interval then 98 | return 99 | end 100 | self.lastCheckMemoryTime = time 101 | local memoryReport = { 102 | ['time'] = time; 103 | ['memory'] = collectgarbage('count'); 104 | ['note'] = note or ''; 105 | } 106 | table.insert( self.memoryReports, memoryReport ) 107 | self:setHighestMemoryReport( memoryReport ) 108 | self:setLowestMemoryReport( memoryReport ) 109 | end 110 | 111 | --[[ 112 | Writes the profile report to a file. 113 | Param: [filename:string:optional] defaults to 'ProFi.txt' if not specified. 114 | ]] 115 | function ProFi:writeReport( filename ) 116 | if #self.reports > 0 or #self.memoryReports > 0 then 117 | filename = filename or 'ProFi.txt' 118 | self:sortReportsWithSortMethod( self.reports, self.sortMethod ) 119 | self:writeReportsToFilename( filename ) 120 | print( string.format("[ProFi]\t Report written to %s", filename) ) 121 | end 122 | end 123 | 124 | --[[ 125 | Resets any profile information stored. 126 | ]] 127 | function ProFi:reset() 128 | self.reports = {} 129 | self.reportsByTitle = {} 130 | self.memoryReports = {} 131 | self.highestMemoryReport = nil 132 | self.lowestMemoryReport = nil 133 | self.has_started = false 134 | self.has_finished = false 135 | self.should_run_once = false 136 | self.lastCheckMemoryTime = nil 137 | self.hookCount = self.hookCount or DEFAULT_DEBUG_HOOK_COUNT 138 | self.sortMethod = self.sortMethod or sortByDurationDesc 139 | self.inspect = nil 140 | end 141 | 142 | --[[ 143 | Set how often a hook is called. 144 | See http://pgl.yoyo.org/luai/i/debug.sethook for information. 145 | Param: [hookCount:number] if 0 ProFi counts every time a function is called. 146 | if 2 ProFi counts every other 2 function calls. 147 | ]] 148 | function ProFi:setHookCount( hookCount ) 149 | self.hookCount = hookCount 150 | end 151 | 152 | --[[ 153 | Set how the report is sorted when written to file. 154 | Param: [sortType:string] either 'duration' or 'count'. 155 | 'duration' sorts by the time a method took to run. 156 | 'count' sorts by the number of times a method was called. 157 | ]] 158 | function ProFi:setSortMethod( sortType ) 159 | if sortType == 'duration' then 160 | self.sortMethod = sortByDurationDesc 161 | elseif sortType == 'count' then 162 | self.sortMethod = sortByCallCount 163 | end 164 | end 165 | 166 | --[[ 167 | By default the getTime method is os.clock (CPU time), 168 | If you wish to use other time methods pass it to this function. 169 | Param: [getTimeMethod:function] 170 | ]] 171 | function ProFi:setGetTimeMethod( getTimeMethod ) 172 | getTime = getTimeMethod 173 | end 174 | 175 | --[[ 176 | Allows you to inspect a specific method. 
177 | Will write to the report a list of methods that 178 | call this method you're inspecting, you can optionally 179 | provide a levels parameter to traceback a number of levels. 180 | Params: [methodName:string] the name of the method you wish to inspect. 181 | [levels:number:optional] the amount of levels you wish to traceback, defaults to 1. 182 | ]] 183 | function ProFi:setInspect( methodName, levels ) 184 | if self.inspect then 185 | self.inspect.methodName = methodName 186 | self.inspect.levels = levels or 1 187 | else 188 | self.inspect = { 189 | ['methodName'] = methodName; 190 | ['levels'] = levels or 1; 191 | } 192 | end 193 | end 194 | 195 | ----------------------- 196 | -- Implementations methods: 197 | ----------------------- 198 | 199 | function ProFi:shouldReturn( ) 200 | return self.should_run_once and self.has_finished 201 | end 202 | 203 | function ProFi:getFuncReport( funcInfo ) 204 | local title = self:getTitleFromFuncInfo( funcInfo ) 205 | local funcReport = self.reportsByTitle[ title ] 206 | if not funcReport then 207 | funcReport = self:createFuncReport( funcInfo ) 208 | self.reportsByTitle[ title ] = funcReport 209 | table.insert( self.reports, funcReport ) 210 | end 211 | return funcReport 212 | end 213 | 214 | function ProFi:getTitleFromFuncInfo( funcInfo ) 215 | local name = funcInfo.name or 'anonymous' 216 | local source = funcInfo.short_src or 'C_FUNC' 217 | local linedefined = funcInfo.linedefined or 0 218 | linedefined = string.format( FORMAT_LINENUM, linedefined ) 219 | return string.format(FORMAT_TITLE, source, name, linedefined) 220 | end 221 | 222 | function ProFi:createFuncReport( funcInfo ) 223 | local name = funcInfo.name or 'anonymous' 224 | local source = funcInfo.source or 'C Func' 225 | local linedefined = funcInfo.linedefined or 0 226 | local funcReport = { 227 | ['title'] = self:getTitleFromFuncInfo( funcInfo ); 228 | ['count'] = 0; 229 | ['timer'] = 0; 230 | } 231 | return funcReport 232 | end 233 | 234 | function ProFi:startHooks() 235 | debug.sethook( onDebugHook, 'cr', self.hookCount ) 236 | end 237 | 238 | function ProFi:stopHooks() 239 | debug.sethook() 240 | end 241 | 242 | function ProFi:sortReportsWithSortMethod( reports, sortMethod ) 243 | if reports then 244 | table.sort( reports, sortMethod ) 245 | end 246 | end 247 | 248 | function ProFi:writeReportsToFilename( filename ) 249 | local file, err = io.open( filename, 'w' ) 250 | assert( file, err ) 251 | self:writeBannerToFile( file ) 252 | if #self.reports > 0 then 253 | self:writeProfilingReportsToFile( self.reports, file ) 254 | end 255 | if #self.memoryReports > 0 then 256 | self:writeMemoryReportsToFile( self.memoryReports, file ) 257 | end 258 | file:close() 259 | end 260 | 261 | function ProFi:writeProfilingReportsToFile( reports, file ) 262 | local totalTime = self.stopTime - self.startTime 263 | local totalTimeOutput = string.format(FORMAT_TOTALTIME_LINE, totalTime) 264 | file:write( totalTimeOutput ) 265 | local header = string.format( FORMAT_HEADER_LINE, "FILE", "FUNCTION", "LINE", "TIME", "RELATIVE", "CALLED" ) 266 | file:write( header ) 267 | for i, funcReport in ipairs( reports ) do 268 | local timer = string.format(FORMAT_TIME, funcReport.timer) 269 | local count = string.format(FORMAT_COUNT, funcReport.count) 270 | local relTime = string.format(FORMAT_RELATIVE, (funcReport.timer / totalTime) * 100 ) 271 | local outputLine = string.format(FORMAT_OUTPUT_LINE, funcReport.title, timer, relTime, count ) 272 | file:write( outputLine ) 273 | if funcReport.inspections then 274 
| self:writeInpsectionsToFile( funcReport.inspections, file ) 275 | end 276 | end 277 | end 278 | 279 | function ProFi:writeMemoryReportsToFile( reports, file ) 280 | file:write( FORMAT_MEMORY_HEADER1 ) 281 | self:writeHighestMemoryReportToFile( file ) 282 | self:writeLowestMemoryReportToFile( file ) 283 | file:write( FORMAT_MEMORY_HEADER2 ) 284 | for i, memoryReport in ipairs( reports ) do 285 | local outputLine = self:formatMemoryReportWithFormatter( memoryReport, FORMAT_MEMORY_LINE ) 286 | file:write( outputLine ) 287 | end 288 | end 289 | 290 | function ProFi:writeHighestMemoryReportToFile( file ) 291 | local memoryReport = self.highestMemoryReport 292 | local outputLine = self:formatMemoryReportWithFormatter( memoryReport, FORMAT_HIGH_MEMORY_LINE ) 293 | file:write( outputLine ) 294 | end 295 | 296 | function ProFi:writeLowestMemoryReportToFile( file ) 297 | local memoryReport = self.lowestMemoryReport 298 | local outputLine = self:formatMemoryReportWithFormatter( memoryReport, FORMAT_LOW_MEMORY_LINE ) 299 | file:write( outputLine ) 300 | end 301 | 302 | function ProFi:formatMemoryReportWithFormatter( memoryReport, formatter ) 303 | local time = string.format(FORMAT_TIME, memoryReport.time) 304 | local kbytes = string.format(FORMAT_KBYTES, memoryReport.memory) 305 | local mbytes = string.format(FORMAT_MBYTES, memoryReport.memory/1024) 306 | local outputLine = string.format(formatter, time, kbytes, mbytes, memoryReport.note) 307 | return outputLine 308 | end 309 | 310 | function ProFi:writeBannerToFile( file ) 311 | local banner = string.format(FORMAT_BANNER, os.date()) 312 | file:write( banner ) 313 | end 314 | 315 | function ProFi:writeInpsectionsToFile( inspections, file ) 316 | local inspectionsList = self:sortInspectionsIntoList( inspections ) 317 | file:write('\n==^ INSPECT ^======================================================================================================== COUNT ===\n') 318 | for i, inspection in ipairs( inspectionsList ) do 319 | local line = string.format(FORMAT_LINENUM, inspection.line) 320 | local title = string.format(FORMAT_TITLE, inspection.source, inspection.name, line) 321 | local count = string.format(FORMAT_COUNT, inspection.count) 322 | local outputLine = string.format(FORMAT_INSPECTION_LINE, title, count ) 323 | file:write( outputLine ) 324 | end 325 | file:write('===============================================================================================================================\n\n') 326 | end 327 | 328 | function ProFi:sortInspectionsIntoList( inspections ) 329 | local inspectionsList = {} 330 | for k, inspection in pairs(inspections) do 331 | inspectionsList[#inspectionsList+1] = inspection 332 | end 333 | table.sort( inspectionsList, sortByCallCount ) 334 | return inspectionsList 335 | end 336 | 337 | function ProFi:resetReports( reports ) 338 | for i, report in ipairs( reports ) do 339 | report.timer = 0 340 | report.count = 0 341 | report.inspections = nil 342 | end 343 | end 344 | 345 | function ProFi:shouldInspect( funcInfo ) 346 | return self.inspect and self.inspect.methodName == funcInfo.name 347 | end 348 | 349 | function ProFi:getInspectionsFromReport( funcReport ) 350 | local inspections = funcReport.inspections 351 | if not inspections then 352 | inspections = {} 353 | funcReport.inspections = inspections 354 | end 355 | return inspections 356 | end 357 | 358 | function ProFi:getInspectionWithKeyFromInspections( key, inspections ) 359 | local inspection = inspections[key] 360 | if not inspection then 361 | inspection = 
{ 362 | ['count'] = 0; 363 | } 364 | inspections[key] = inspection 365 | end 366 | return inspection 367 | end 368 | 369 | function ProFi:doInspection( inspect, funcReport ) 370 | local inspections = self:getInspectionsFromReport( funcReport ) 371 | local levels = 5 + inspect.levels 372 | local currentLevel = 5 373 | while currentLevel < levels do 374 | local funcInfo = debug.getinfo( currentLevel, 'nS' ) 375 | if funcInfo then 376 | local source = funcInfo.short_src or '[C]' 377 | local name = funcInfo.name or 'anonymous' 378 | local line = funcInfo.linedefined 379 | local key = source..name..line 380 | local inspection = self:getInspectionWithKeyFromInspections( key, inspections ) 381 | inspection.source = source 382 | inspection.name = name 383 | inspection.line = line 384 | inspection.count = inspection.count + 1 385 | currentLevel = currentLevel + 1 386 | else 387 | break 388 | end 389 | end 390 | end 391 | 392 | function ProFi:onFunctionCall( funcInfo ) 393 | local funcReport = ProFi:getFuncReport( funcInfo ) 394 | funcReport.callTime = getTime() 395 | funcReport.count = funcReport.count + 1 396 | if self:shouldInspect( funcInfo ) then 397 | self:doInspection( self.inspect, funcReport ) 398 | end 399 | end 400 | 401 | function ProFi:onFunctionReturn( funcInfo ) 402 | local funcReport = ProFi:getFuncReport( funcInfo ) 403 | if funcReport.callTime then 404 | funcReport.timer = funcReport.timer + (getTime() - funcReport.callTime) 405 | end 406 | end 407 | 408 | function ProFi:setHighestMemoryReport( memoryReport ) 409 | if not self.highestMemoryReport then 410 | self.highestMemoryReport = memoryReport 411 | else 412 | if memoryReport.memory > self.highestMemoryReport.memory then 413 | self.highestMemoryReport = memoryReport 414 | end 415 | end 416 | end 417 | 418 | function ProFi:setLowestMemoryReport( memoryReport ) 419 | if not self.lowestMemoryReport then 420 | self.lowestMemoryReport = memoryReport 421 | else 422 | if memoryReport.memory < self.lowestMemoryReport.memory then 423 | self.lowestMemoryReport = memoryReport 424 | end 425 | end 426 | end 427 | 428 | ----------------------- 429 | -- Local Functions: 430 | ----------------------- 431 | 432 | getTime = os.clock 433 | 434 | onDebugHook = function( hookType ) 435 | local funcInfo = debug.getinfo( 2, 'nS' ) 436 | if hookType == "call" then 437 | ProFi:onFunctionCall( funcInfo ) 438 | elseif hookType == "return" then 439 | ProFi:onFunctionReturn( funcInfo ) 440 | end 441 | end 442 | 443 | sortByDurationDesc = function( a, b ) 444 | return a.timer > b.timer 445 | end 446 | 447 | sortByCallCount = function( a, b ) 448 | return a.count > b.count 449 | end 450 | 451 | ----------------------- 452 | -- Return Module: 453 | ----------------------- 454 | 455 | ProFi:reset() 456 | return ProFi -------------------------------------------------------------------------------- /code/conv-functions.lua: -------------------------------------------------------------------------------- 1 | local conv = {} 2 | 3 | function conv.generativestatistics(rbm,x,y,tcwx) 4 | assert(isRowVec(x)) 5 | assert(x:size(2) == rbm.n_visible) 6 | 7 | if rbm.toprbm then 8 | assert(isRowVec(y)) 9 | assert(y:size(2) == rbm.n_classes) 10 | end 11 | -- tcwx will be nil 12 | local h0,h0_rnd,vkx,vky,vkx_rnd,vky_rnd,hk,drop -- drop marks that dropout was applied on the up pass 13 | h0 = rbm.up(rbm,x,y,drop) 14 | 15 | if rbm.dropout > 0 then 16 | h0:cmul(rbm.dropout_mask) 17 | drop = 1 18 | end 19 | 20 | h0_rnd = rbm.hidsampler(h0,rbm.rand) 21 | 22 | if rbm.dropout > 0 then 23 | h0_rnd:cmul(rbm.dropout_mask) -- Apply dropout on p(h|v) 24 | end 25 | 26 | vkx = rbm.downx(rbm,h0_rnd) 27 | 28 | if rbm.toprbm then 29 | vky = rbm.downy(rbm,h0_rnd) 30 | vky_rnd = samplevec( vky, rbm.rand) 31 | else 32 | vky,vky_rnd = {},{} 33 | end 34 | vkx_rnd = rbm.visxsampler(vkx,rbm.rand) 35 | hk = rbm.up(rbm,vkx,vky_rnd,drop) -- Why not vkx_RND????? 36 | 37 | local stat = {} 38 | stat.h0 = h0 39 | stat.h0_rnd = h0_rnd 40 | stat.hk = hk 41 | stat.vkx = vkx 42 | stat.vkx_rnd = vkx_rnd 43 | 44 | if rbm.toprbm then 45 | stat.vky = vky 46 | stat.vky_rnd = vky_rnd 47 | end 48 | 49 | return stat 50 | end 51 |
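-- Note on conv.generativestatistics (annotation added for clarity): it performs
-- one CD-1 style sweep -- up from (x,y) to h0, sample h0_rnd, down to the
-- visible reconstructions vkx/vky, then up again to hk -- and returns all the
-- intermediate statistics needed by the gradient computation below.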
52 | function conv.creategenerativegrads(sizes) 53 | local fw = sizes.filter_size * sizes.filter_size 54 | local fs = sizes.filter_size 55 | local is = sizes.input_size 56 | local nf = sizes.n_filters 57 | local ni = sizes.n_input 58 | local ns = sizes.n_visible 59 | local nc = sizes.n_classes 60 | 61 | -- Setup networks 62 | local nn_dW_pos, nn_dW_neg 63 | nn_dW_pos = nn.SpatialConvolutionMM(ni, 1, sizes.hid_w, sizes.hid_h) 64 | nn_dW_neg = nn.SpatialConvolutionMM(ni, 1, sizes.hid_w, sizes.hid_h) 65 | nn_dW_pos.bias = torch.zeros(nn_dW_pos.bias:size()) 66 | nn_dW_neg.bias = torch.zeros(nn_dW_neg.bias:size()) 67 | 68 | local debug = {} 69 | debug.nn_dW_pos = nn_dW_pos 70 | debug.nn_dW_neg = nn_dW_neg 71 | 72 | -- preallocate dW 73 | --local dW = torch.Tensor(sizes.n_filters,fs,fs) 74 | 75 | local conv_generative_grads = function(rbm,x,y,stat) 76 | local grads = {} 77 | assert(isRowVec(x) and x:size(2) == sizes.n_visible and x:isSameSizeAs(stat.vkx_rnd)) 78 | assert(isRowVec(stat.h0) and stat.h0:size(2) == sizes.n_hidden and stat.h0:isSameSizeAs(stat.hk)) 79 | 80 | if rbm.toprbm then 81 | assert(isRowVec(y) and y:size(2) == sizes.n_classes and y:isSameSizeAs(stat.vky_rnd)) 82 | grads.dd = torch.add(y, -stat.vky_rnd):t() 83 | grads.dU = torch.mm(stat.h0:t(),y):add(-torch.mm(stat.hk:t(),stat.vky_rnd)) 84 | 85 | assert(grads.dU:size(1) == sizes.n_hidden and grads.dU:size(2) == sizes.n_classes) 86 | assert(grads.dd:size(1) == sizes.n_classes and grads.dd:size(2) == 1) 87 | end 88 | 89 | -- reshape h0, hk, x and vkx into multidimensional form 90 | local h0_m, hk_m, x_m, vkx_m 91 | h0_m = stat.h0:view(nf, sizes.hid_h, sizes.hid_h) 92 | hk_m = stat.hk:view(nf, sizes.hid_h, sizes.hid_h) 93 | x_m = x:view(1,ni,is,is) 94 | vkx_m = stat.vkx:view(1, ni, is, is) 95 | 96 | 97 | grads.dc = ( h0_m:sum(3):sum(2) - hk_m:sum(3):sum(2) ):view(sizes.n_filters,1) 98 | grads.db = (x_m:sum(4):sum(3) - vkx_m:sum(4):sum(3) ):view(sizes.n_input,1) 99 | 100 | assert(ni == 1) -- Current implementation possibly breaks with more than one input channel, refactor nn_dW_*** weights 101 | local dw_pos,dw_neg,l 102 | grads.dW = torch.Tensor(sizes.n_filters,fs,fs) 103 | for l = 1,nf do 104 | nn_dW_pos.weight = h0_m[{l,{},{}}]:view(ni, sizes.hid_w*sizes.hid_h) 105 | nn_dW_neg.weight = hk_m[{l,{},{}}]:view(ni, sizes.hid_w*sizes.hid_h) 106 | 107 | dW_pos = nn_dW_pos:forward(x_m)--:clone() 108 | dW_neg = nn_dW_neg:forward(vkx_m)--:clone() 109 | grads.dW[{l,{},{}}] = dW_pos - dW_neg 110 | 111 | end 112 | 113 | -- normalize 114 | grads.db:mul(1/(is * is)) 115 | grads.dc:mul(1/(sizes.hid_w*sizes.hid_h)) 116 | grads.dW:mul(1/( (sizes.hid_h - 2 * fs + 2) * (sizes.hid_w - 2 * fs + 2) )) 117 | 118 | -- Reshape gradients 119 | grads.db = grads.db:view(-1,1) -- visbias 120 | grads.dc = grads.dc:view(-1,1) -- hidbias 121 | grads.dW = conv.toFlat(grads.dW) 122 | 123 | assert(isRowVec(grads.dW) and grads.dW:size(2) == sizes.n_W) 124 | assert(grads.db:size(1) == sizes.n_input and grads.db:size(2) == 1) 125 | assert(grads.dc:size(1) == sizes.n_filters and grads.dc:size(2) == 1)
126 | 127 | 128 | return grads 129 | end 130 | 131 | return conv_generative_grads,debug 132 | 133 | end 134 | 135 | 136 | conv.calcconvsizes = function(filter_size,n_filters,n_classes,input_size,pool_size,train) 137 | -- Calculates the sizes of the layers etc. in a conv RBM 138 | 139 | -- NUMBER OF HIDDEN UNITS 140 | -- The number of hidden units is the number of filters times 141 | -- the size of the convolution after the rbmup. 142 | local sizes = {} 143 | sizes.input_size = input_size 144 | sizes.n_input = train.data:size(2) 145 | sizes.n_hidden = math.pow(input_size - filter_size +1,2)*n_filters 146 | sizes.n_visible = train.data:size(3)*sizes.n_input 147 | 148 | sizes.hid_w = input_size - filter_size + 1 149 | sizes.hid_h = input_size - filter_size + 1 150 | 151 | -- store W as a row vector 152 | sizes.n_W = sizes.n_input * n_filters * math.pow(filter_size,2) 153 | sizes.n_U = sizes.n_hidden * n_classes -- for reference 154 | 155 | -- b,c and d are column vectors 156 | sizes.n_b = sizes.n_input -- bias of visible layer 157 | sizes.n_c = n_filters -- bias of hidden layer 158 | sizes.n_d = n_classes 159 | sizes.pad = filter_size -1 -- Zero padding size 160 | sizes.filter_size = filter_size 161 | sizes.n_filters = n_filters 162 | sizes.n_classes = n_classes 163 | sizes.pool_size = pool_size 164 | 165 | -- hidden filter 166 | --sizes.h_filter_w = sizes.hid_h - filter_size+1 - filter_size + 1 167 | --sizes.h_filter_h = sizes.hid_w - filter_size+1 - filter_size + 1 168 | print(sizes) 169 | assert(sizes.hid_w % pool_size == 0) -- otherwise maxpooling fails 170 | return sizes 171 | end 172 |
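-- Worked size example (hypothetical numbers, added for illustration): with
-- input_size = 28, filter_size = 5, n_filters = 8 and pool_size = 2, this gives
-- hid_w = hid_h = 28 - 5 + 1 = 24 (divisible by pool_size, so the assert above
-- holds) and n_hidden = 24^2 * 8 = 4608.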
modeldownx 225 | local debug = {} 226 | debug.modelup = debugupdown.modelup 227 | debug.modeldownx = debugupdown.modeldownx 228 | debug.nn_dW_pos = debuggen.nn_dW_pos 229 | debug.nn_dW_neg = debuggen.nn_dW_neg 230 | return debug 231 | 232 | end 233 | 234 | 235 | function conv.maxPool(x,pool_size) 236 | -- maxpool over several filters 237 | -- INPUTS 238 | -- x : should be a 3d matrix with dimensions 239 | -- [n_filters x filter_size x filter_size] 240 | -- pool_size : size of maxpool 241 | -- 242 | -- RETURN 243 | -- New matrix of x:size() with pooled result + sum of each maxpool 244 | -- 245 | -- SEE [1] H. Lee, R. Grosse, R. Ranganath, and A. Ng, “Convolutional deep 246 | -- belief networks for scalable unsupervised learning of hierarchical 247 | -- representations,” … Mach. Learn., 2009. 248 | local function maxpool(x) 249 | --Calculate exp(x) / [sum(exp(x)) +1] in numerically stable way 250 | local m = torch.max(x) 251 | local exp_x = torch.exp(x - m) 252 | -- normalizer = sum(exp(x)) + 1 in scaled domain 253 | local normalizer = torch.exp(-m) + exp_x:sum() 254 | exp_x:cdiv( torch.Tensor(exp_x:nElement()):fill(normalizer) ) 255 | 256 | 257 | return exp_x 258 | end 259 | 260 | local function maxpoollayer(h,h_pool_res,p_pool_res,pool_size) 261 | -- Performs probabilistic maxpooling. 262 | -- For each block of pool_size x pool_size calculate 263 | -- exp(h_i) / [sum_i(exp(h_i)) + 1] 264 | -- h should be a 2d matrix 265 | --print(hf:size()) 266 | local height = h:size(1) 267 | local width = h:size(2) 268 | --poshidprobs = torch.Tensor(height,width):typeAs(hf) 269 | -- notation h_(i,j) 270 | 271 | local i_pool,j_pool = 0,0 272 | local h_maxpool, p_maxpool 273 | --print("maxpoollayer: ", h_pool_res:size(),p_pool_res:size()) 274 | for i_start = 1,height,pool_size do 275 | j_pool = 0 276 | 277 | i_pool = i_pool + 1 278 | i_end = i_start+pool_size -1 279 | for j_start = 1,width,pool_size do -- columns 280 | j_end = j_start+pool_size -1 281 | j_pool = j_pool + 1 282 | 283 | h_maxpool = maxpool(h[{{i_start,i_end},{j_start,j_end}}]) 284 | p_maxpool = h_maxpool:sum() 285 | 286 | h_pool_res[{{i_start,i_end},{j_start,j_end}}] = h_maxpool 287 | p_pool_res[{i_pool,j_pool}] = p_maxpool 288 | end 289 | end 290 | end 291 | 292 | local n_filter = x:size(1) 293 | local h_filter = x:size(2) 294 | local w_filter = x:size(3) 295 | h_maxpooled = torch.Tensor(n_filter,h_filter,w_filter):typeAs(x) 296 | p_maxpooled = torch.Tensor(n_filters, 297 | h_filter / pool_size, 298 | w_filter / pool_size):typeAs(x) 299 | 300 | 301 | for i = 1, x:size(1) do 302 | maxpoollayer( x[{i,{},{}}], 303 | h_maxpooled[{ i,{},{} }], 304 | p_maxpooled[{ i,{},{} }], 305 | pool_size ) 306 | end 307 | return h_maxpooled,p_maxpooled 308 | end 309 | 310 | conv.flatToDownW = function(W,dest) 311 | -- Convert Weights from RBMUP NN to RBMDOWN format 312 | -- Converts between spatialConvolutionMM format and spatialConvolution 313 | -- format. Furthermore each filter is INVERTED 314 | -- INPUTS 315 | -- W : [1xn] row vector of weights 316 | -- dest : mem reference where result is stored. 
The size of dest 317 | -- is [n_input x n_filters x filter_size x filter_size] 318 | -- 319 | -- RETURN 320 | -- empty, result is stored in dest 321 | -- 322 | -- The dimensions of the returned 323 | -- to the format used by spatialConvolution 324 | -- INVERTS weights in kernels 325 | function invertWeights(x) 326 | -- invert a mxn matrix 327 | xc = torch.Tensor(x:size()) 328 | xc = xc:view(-1) -- view as vector 329 | idx = xc:nElement() 330 | for i = 1,x:size(1) do --rows 331 | for j = 1,x:size(2) do -- columns 332 | xc[idx] = x[{i,j}] 333 | idx = idx-1 334 | end 335 | end 336 | xc = xc:viewAs(x) 337 | return xc 338 | end 339 | 340 | 341 | -- Change the weiw of the flat mtrix to correct format 342 | -- Wf is [n_input X n_filters X filter_size x filter_size] 343 | local fs = dest:size(4) --filter_size 344 | local ni = dest:size(1) --n_input 345 | local nf = dest:size(2) -- n_filters 346 | 347 | 348 | local Wf = conv.flatTo4D(W,dest) 349 | for input_dim = 1,ni do 350 | for filter_num = 1,nf do 351 | dest[{input_dim,filter_num,{},{}}] = invertWeights(Wf[{input_dim,filter_num,{},{}}]) 352 | end 353 | end 354 | end 355 | 356 | conv.flatTo4D = function(x,sizeas) 357 | -- Convert x to sizeas. Used to convert rowvector x to spatialConv weights 358 | -- 359 | -- INPUTS 360 | -- x : [1xn] row vector of weights 361 | -- sizeas : [n_input x n_filters x filter_size x filter_size] matrix 362 | -- 363 | -- RETURNS 364 | -- X in same view as sizeas. No mem copy 365 | local fs = sizeas:size(4) --filter_size 366 | local ni = sizeas:size(1) 367 | local s = torch.LongStorage({fs*fs,fs*fs*ni,fs,1}) 368 | local sz= sizeas:size() 369 | return torch.Tensor():set(x:storage(), 1, sz,s) 370 | end 371 | 372 | conv.toFlat = function(x) 373 | -- Convert x to row vector (1xn) 374 | local sz = torch.LongStorage({1,x:nElement()}) 375 | local s = torch.LongStorage({x:nElement(),1}) 376 | return torch.Tensor():set(x:storage(), 1, sz,s) 377 | end 378 | 379 | conv.flatTo2D = function(x,sizeas) 380 | -- Convert x to sizeas. Used to convert rowvector x to spatialConvMM weights 381 | -- 382 | -- INPUTS 383 | -- x : [1xn] row vector of weights 384 | -- sizeas : [n_filters x (n_input*filter_size*filter_size)] matrix 385 | -- 386 | -- RETURNS 387 | -- X in same view as sizeas. No mem copy 388 | local sz = sizeas:size() 389 | local s = sizeas:stride() 390 | return torch.Tensor():set(x:storage(), 1, sz,s) 391 | end 392 | 393 | conv.flatToUpW = function(W,dest) 394 | -- Converts to format used by spatialConvolutionMM 395 | -- 396 | -- INPUTS 397 | -- W : row vector of weights 398 | -- dest : mem reference to weights. 
Should be a 399 | -- [n_filters x (n_input*filter_size*filter_size)] matrix 400 | -- 401 | -- Returns 402 | -- empty, results are stored in dest 403 | -- 404 | W2d = conv.flatTo2D(W,dest) 405 | dest:set(W2d:storage(),1,W2d:size(),W2d:stride()) 406 | 407 | end 408 | 409 | 410 | conv.createupdownpygivenx = function(rbm,sizes,usemaxpool) 411 | local debug = {} 412 | local pad = filter_size - 1 413 | local modelup = nn.Sequential() 414 | modelup:add(nn.Reshape(sizes.n_input,sizes.input_size,sizes.input_size)) 415 | modelup:add(nn.SpatialConvolutionMM(sizes.n_input,sizes.n_filters, 416 | sizes.filter_size,sizes.filter_size)) 417 | --modelup:add(nn.Sigmoid()) 418 | 419 | if usemaxpool == nil then 420 | usemaxpool = true 421 | end 422 | 423 | 424 | local modeldownx = nn.Sequential() 425 | modeldownx:add(nn.Reshape(sizes.n_filters,sizes.hid_h,sizes.hid_w)) 426 | modeldownx:add(nn.SpatialZeroPadding(sizes.pad, sizes.pad, sizes.pad, sizes.pad)) -- pad (filterwidth -1) 427 | modeldownx:add(nn.SpatialConvolution(sizes.n_filters,sizes.n_input, 428 | sizes.filter_size,sizes.filter_size)) 429 | --modeldownx:add(nn.Sigmoid()) 430 | 431 | debug.modelup = modelup 432 | debug.modeldownx = modeldownx 433 | 434 | -- SET TESTING WEIGHTS OF UP MODEL 435 | conv.flatToUpW(rbm.W,modelup.modules[2].weight) 436 | modelup.modules[2].bias = rbm.c:view(-1)-- torch.zeros(n_filters) 437 | 438 | -- Test that the underlying storages are equal 439 | assert(rbm.W:storage() == modelup.modules[2].weight:storage()) 440 | assert(rbm.c:storage() == modelup.modules[2].bias:storage()) 441 | 442 | -- -- SET TESTING WEIGHTS OF DOWNX MODEL 443 | conv.flatToDownW(rbm.W,modeldownx.modules[3].weight) 444 | modeldownx.modules[3].bias = rbm.b:view(-1)--torch.zeros(n_input) 445 | 446 | assert(rbm.b:storage() == modeldownx.modules[3].bias:storage()) 447 | -- The memory between weighs in modeldownx and rbm.W are not shared 448 | -- because of weight inversion 449 | --assert(rbm.W:storage() == modeldownx.modules[3].weight:storage()) 450 | 451 | local pygivenx_conv = function(rbm,x,tcwx_pre_calc) 452 | assert(isRowVec(x)) 453 | assert(x:size(1)*x:size(2) == rbm.n_visible) 454 | --print("pygivenx_conv:I could implement tcwx reuse with a bit of work...") 455 | local F,pyx, mask_expanded 456 | 457 | -- check shared memory and sizes 458 | assert(rbm.W:storage() == modelup.modules[2].weight:storage()) 459 | assert(rbm.c:storage() == modelup.modules[2].bias:storage()) 460 | 461 | tcwx = tcwx_pre_calc or modelup:forward(x) 462 | tcwx = tcwx:view(1,-1) -- to flat representation 463 | 464 | 465 | F = torch.add( rbm.U, torch.mm(tcwx:t(), rbm.one_by_classes) ) 466 | pyx = softplus(F):sum(1) -- p(y|x) logprob 467 | pyx:add(-torch.max(pyx)) -- subtract max for numerical stability 468 | pyx:exp() -- convert to real domain 469 | pyx:mul( ( 1/pyx:sum() )) -- normalize probabilities 470 | 471 | assert(pyx:size(1) == 1 and pyx:size(2) == rbm.n_classes) 472 | return pyx,F 473 | end 474 | 475 | local pygivenxdropout_conv = function(rbm,x,tcwx_pre_calc) 476 | assert(isRowVec(x)) 477 | assert(x:size(1)*x:size(2) == rbm.n_visible) 478 | 479 | --print("pygivenx_conv:I could implement tcwx reuse with a bit of work...") 480 | local tcwx,F,F_softplus,pyx, mask_expanded 481 | assert(tcwx == nil) 482 | 483 | mask_expanded = torch.mm(rbm.dropout_mask:t(), rbm.one_by_classes) 484 | 485 | tcwx = tcwx_pre_calc or modelup:forward(x) 486 | tcwx = tcwx:view(1,-1) -- to flat representation 487 | 488 | F = torch.add( rbm.U, torch.mm(tcwx:t(), rbm.one_by_classes) ) 489 | 
F:cmul(mask_expanded) -- Apply dropout mask 490 | 491 | F_softplus = softplus(F) 492 | F_softplus:cmul(mask_expanded) -- Apply dropout mask 493 | 494 | pyx = F_softplus:sum(1) -- p(y|x) logprob 495 | pyx:add(-torch.max(pyx)) -- subtract max for numerical stability 496 | pyx:exp() -- convert to real domain 497 | pyx:mul( ( 1/pyx:sum() )) -- normalize probabilities 498 | 499 | assert(pyx:size(1) == 1 and pyx:size(2) == rbm.n_classes) 500 | return pyx,F,mask_expanded 501 | end 502 | 503 | 504 | print("Maxpool in RBMUP: ", usemaxpool) 505 | local rbmup_conv = function(rbm,x,y,drop) 506 | assert(isRowVec(x)) 507 | assert(x:size(1)*x:size(2) == rbm.n_visible) 508 | 509 | local hid_act 510 | 511 | --- calculate contribution from features X 512 | 513 | -- Makes sure that the storages are the same 514 | -- I.e no need to update weights. 515 | assert(rbm.W:storage() == modelup.modules[2].weight:storage()) 516 | assert(rbm.c:storage() == modelup.modules[2].bias:storage()) 517 | hid_act = modelup:forward(x):clone() 518 | -- The output from modelup is [1 x n_filters X hidh X hidw] 519 | -- remove first dimension 520 | 521 | 522 | 523 | -- Max pool calculates hidden act using eq on top p. 4 Lee 2009 524 | if usemaxpool then 525 | hid_act = torch.view(hid_act, hid_act:size(2),hid_act:size(3),hid_act:size(4)) 526 | hid_act,p_act = conv.maxPool(hid_act,sizes.pool_size) 527 | rbm.act_up = p_act:view(1,-1) 528 | else 529 | rbm.act_up = hid_act:view(1,-1) 530 | end 531 | hid_act = hid_act:view(1,-1) --flat view 532 | 533 | -- Calculate contribution from labels Y 534 | 535 | if rbm.toprbm then 536 | assert(isRowVec(y) and y:size(2) == rbm.n_classes) 537 | hid_act:add( torch.mm(y,rbm.U:t()) ) 538 | end 539 | 540 | if drop == 1 then 541 | hid_act:cmul(rbm.dropout_mask) 542 | end 543 | 544 | return hid_act 545 | end 546 | 547 | local rbmdownx_conv = function(rbm,hid_act) 548 | 549 | -- I Need to call flatToDownW because we use inverted weights 550 | assert(rbm.b:storage() == modeldownx.modules[3].bias:storage()) -- biases are shared 551 | conv.flatToDownW(rbm.W,modeldownx.modules[3].weight) 552 | local vis_act = modeldownx:forward(hid_act):clone() 553 | vis_act = vis_act:view(1,-1) -- to flat view 554 | return sigm(vis_act) 555 | end 556 | 557 | local rbmdownxgauss_conv = function(rbm,hid_act) 558 | 559 | -- I Need to call flatToDownW because we use inverted weights 560 | assert(rbm.b:storage() == modeldownx.modules[3].bias:storage()) -- biases are shared 561 | conv.flatToDownW(rbm.W,modeldownx.modules[3].weight) 562 | local vis_act = modeldownx:forward(hid_act):clone() 563 | vis_act = vis_act:view(1,-1) -- to flat view 564 | return vis_act 565 | end 566 | return rbmup_conv,rbmdownx_conv,rbmdownxgauss_conv,pygivenx_conv,pygivenxdropout_conv,debug 567 | end 568 | 569 | 570 | return conv -------------------------------------------------------------------------------- /code/dataset-from-tensor.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'paths' 3 | 4 | datatensor = {} 5 | function datatensor.createDataset(tensor,labels,classes,geometry) 6 | local cdataset = {} 7 | 8 | if tensor:dim() ~= 3 then 9 | print("Tensor dim must be batches X datadim1 X datadim2") 10 | error() 11 | end 12 | 13 | 14 | local tensor = tensor:clone() 15 | local labels = labels:clone() 16 | local dim = tensor:size(2) 17 | --cdataset.classes = classes 18 | 19 | function cdataset:normalize(mean_, std_) 20 | print('Not implemented - see dataset-mnist.lua') 21 | error() 22 | end 
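-- A hypothetical sketch (added; not part of the original file) of what a
-- per-feature cdataset:normalize could look like here, mirroring
-- dataset:normalize in dataset-mnist.lua: standardize every feature column
-- of the flattened data. The helper name is an assumption and the function
-- is inert unless called.
local function normalizeSketch(data2d, mean_, std_)
    -- data2d: an n_samples x n_features view of `tensor`
    local std = std_ or torch.std(data2d, 1, true)
    local mean = mean_ or torch.mean(data2d, 1)
    for i = 1, data2d:size(2) do
        data2d:select(2, i):add(-mean[1][i])
        if std[1][i] > 0 then
            data2d:select(2, i):mul(1/std[1][i])
        end
    end
    return mean, std
end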
23 | function cdataset:normalizeGlobal(mean_, std_)
24 |     local std = std_ or tensor:std()
25 |     local mean = mean_ or tensor:mean()
26 |     tensor:add(-mean)
27 |     tensor:mul(1/std)
28 |     return mean, std
29 | end
30 | function cdataset:size()
31 |     return tensor:size(1)
32 | end
33 | function cdataset:resize(nsamples,start)
34 |     start = start or 1
35 |     tensor = tensor[{{start,nsamples},{},{} }]
36 |     labels = labels[{{start,nsamples},{}}]
37 | end
38 |
39 | function cdataset:toProbability()
40 |     tensor:add( -tensor:min() )   -- shift minimum to 0
41 |     tensor:mul(1/tensor:max())    -- scale maximum to 1
42 | end
43 |
44 | function cdataset:classnames()
45 |     return classes
46 | end
47 |
48 | function cdataset:getTensor()
49 |     return tensor
50 | end
51 |
52 | function cdataset:getLabels()
53 |     return labels
54 | end
55 |
56 |
57 | local currentSample = 1
58 | local nSamples = tensor:size(1)
59 | local currentPerm = torch.randperm(nSamples)
60 |
61 | -- iterator over the dataset in a fresh random order each epoch
62 | function cdataset:next()
63 |     if currentSample > nSamples then
64 |         currentSample = 1
65 |         currentPerm = torch.randperm(nSamples)
66 |     end
67 |
68 |     currentSample = currentSample + 1
69 |     return self[ currentPerm[currentSample-1] ]
70 | end
71 |
72 | function cdataset:getcurrentsample()
73 |     return currentSample
74 | end
75 |
76 | function cdataset:geometry()
77 |     return geometry
78 | end
79 |
80 | setmetatable(cdataset, {__index = function(self,index)
81 |
82 |     local x = tensor[{ index,{},{} }]
83 |     x = x:view(1,geometry[1],geometry[2])
84 |     local y = labels[{index,{}}]:view(-1)
85 |     local example = {x, y}
86 |     return example
87 | end})
88 | return cdataset
89 | end
90 | -------------------------------------------------------------------------------- /code/dataset-mnist.lua: --------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'paths'
3 |
4 | mnist = {}
5 |
6 | mnist.path_remote = 'https://www.dropbox.com/s/zdq98audyn8j845/mnist_lua.tar.gz?dl=0'
7 | mnist.path_dataset = 'mnist-th7'
8 | mnist.path_trainset = paths.concat(mnist.path_dataset, 'train.th7')
9 | mnist.path_testset = paths.concat(mnist.path_dataset, 'test.th7')
10 |
11 | function mnist.download()
12 |     if not paths.filep(mnist.path_trainset) or not paths.filep(mnist.path_testset) then
13 |         local remote = mnist.path_remote
14 |         local tar = paths.basename(remote)
15 |         os.execute('wget ' .. remote .. '; ' .. 'tar xvf ' .. tar .. '; rm ' .. tar)
16 |     end
17 | end
18 |
19 | function mnist.loadTrainSet(maxLoad, geometry,boost)
20 |     boost = boost or 'none'
21 |     return mnist.loadConvDataset(mnist.path_trainset, maxLoad, geometry,nil,boost,'train')
22 | end
23 |
24 | function mnist.loadTrainAndValSet(geometry,boost)
25 |     boost = boost or 'none'
26 |     local train = mnist.loadConvDataset(mnist.path_trainset, 60000, geometry,'train',boost)
27 |     local val = mnist.loadConvDataset(mnist.path_trainset, 60000, geometry,'val')
28 |     return train,val
29 | end
30 |
31 | function mnist.loadTestSet(maxLoad, geometry)
32 |     return mnist.loadConvDataset(mnist.path_testset, maxLoad, geometry,nil,nil,'test')
33 | end
34 |
35 | function mnist.loadFlatDataset(fileName, maxLoad,trainOrVal)
36 |     mnist.download()
37 |
38 |     local f = torch.DiskFile(fileName, 'r')
39 |     f:binary()
40 |
41 |     local nExample = f:readInt()
42 |     local dim = f:readInt()
43 |     if maxLoad and maxLoad > 0 and maxLoad < nExample then
44 |         nExample = maxLoad
45 |         print(' loading only ' .. nExample .. ' examples')
46 |     end
47 |     print(' reading ' .. nExample ..
' examples with ' .. dim-1 .. '+1 dimensions...') 48 | local tensor = torch.Tensor(nExample, dim) 49 | tensor:storage():copy(f:readFloat(nExample*dim)) 50 | print(' done') 51 | 52 | if trainOrVal then 53 | if trainOrVal == 'train' then 54 | tensor = tensor[{{1,50000},{} }] 55 | elseif trainOrVal == 'val' then 56 | tensor = tensor[{{50001,60000},{} }] 57 | else 58 | print('trainOrVal must be train|val') 59 | error() 60 | end 61 | end 62 | local dataset = {} 63 | dataset.tensor = tensor 64 | 65 | 66 | 67 | function dataset:normalize(mean_, std_) 68 | local data = tensor:narrow(2, 1, dim-1) 69 | local std = std_ or torch.std(data, 1, true) 70 | local mean = mean_ or torch.mean(data, 1) 71 | for i=1,dim-1 do 72 | tensor:select(2, i):add(-mean[1][i]) 73 | if std[1][i] > 0 then 74 | tensor:select(2, i):mul(1/std[1][i]) 75 | end 76 | end 77 | return mean, std 78 | end 79 | 80 | function dataset:toProbability() 81 | local data = tensor:narrow(2, 1, dim-1) 82 | data:mul(1/255) 83 | end 84 | 85 | function dataset:resize(nsamples,start) 86 | start = start or 1 87 | tensor = tensor[{{start,nsamples},{} }] 88 | end 89 | 90 | function dataset:normalizeGlobal(mean_, std_) 91 | local data = tensor:narrow(2, 1, dim-1) 92 | local std = std_ or data:std() 93 | local mean = mean_ or data:mean() 94 | data:add(-mean) 95 | data:mul(1/std) 96 | return mean, std 97 | end 98 | 99 | dataset.dim = dim-1 100 | 101 | function dataset:size() 102 | return tensor:size(1) 103 | end 104 | 105 | local labelvector = torch.zeros(10) 106 | 107 | setmetatable(dataset, {__index = function(self, index) 108 | local input = tensor[index]:narrow(1, 1, dim-1) 109 | local class = tensor[index][dim]+1 110 | local label = labelvector:zero() 111 | label[class] = 1 112 | local example = {input, label} 113 | return example 114 | end}) 115 | 116 | return dataset 117 | end 118 | 119 | function mnist.loadConvDataset(fileName, maxLoad, geometry,trainOrVal,boost,name) 120 | local dataset = mnist.loadFlatDataset(fileName, maxLoad,trainOrVal) 121 | local cdataset = {} 122 | 123 | function cdataset:normalize(m,s) 124 | return dataset:normalize(m,s) 125 | end 126 | function cdataset:normalizeGlobal(m,s) 127 | return dataset:normalizeGlobal(m,s) 128 | end 129 | function cdataset:size() 130 | return dataset:size() 131 | end 132 | function cdataset:resize(nsamples,start) 133 | dataset:resize(nsamples,start) 134 | end 135 | 136 | function cdataset:toProbability() 137 | dataset:toProbability() 138 | end 139 | 140 | function cdataset:classnames() 141 | return {'1','2','3','4','5','6','7','8','9','10'} 142 | end 143 | 144 | 145 | local currentSample = 1 146 | local nSamples = dataset:size() 147 | local currentPerm = torch.randperm(nSamples) 148 | local skipped = 0 149 | 150 | --iterator over dataset 151 | function cdataset:nextboost() 152 | if boost == 'none' then 153 | print('boost is not enabled for this dataset') 154 | error() 155 | end 156 | --if boost == 'diff' then 157 | -- print('add some function that samples based on yprop') 158 | -- print('+add function to get number of indeces check') 159 | -- print('+in sigp set up some heuristic as in schmidhuber - here use som general 1-error prob or whatever') 160 | local sample 161 | if myprob then-- check if yprobs is set, which happens after the first epoch 162 | local choosen = false 163 | while not choosen do 164 | sample = self:next() 165 | if self:getcurrentsample() > self:size() then 166 | print(name,'NUMBER of skipped this pass: ',skipped) 167 | skipped = 0 168 | end 169 | --print(ex[2]) 170 | local 
_,idx = torch.max(sample[2],1) 171 | 172 | local pred_prob = myprob[{currentSample-1,idx[1]}] 173 | --print(currentSample,correct_prob) 174 | local sampling_prob = 1-pred_prob 175 | if torch.rand(1)[1] < sampling_prob then 176 | choosen = true 177 | else 178 | skipped = skipped +1 179 | end 180 | end 181 | else 182 | 183 | -- myprop or boost is not defined 184 | sample = self:next() 185 | end 186 | 187 | return sample,skipped 188 | 189 | end 190 | 191 | 192 | function cdataset:next() 193 | if currentSample > nSamples then 194 | currentSample = 1 195 | currentPerm = torch.randperm(nSamples) 196 | 197 | end 198 | 199 | currentSample = currentSample + 1 200 | return self[ currentPerm[currentSample-1] ] 201 | end 202 | 203 | function cdataset:getcurrentsample() 204 | return currentSample 205 | end 206 | 207 | function cdataset:geometry() 208 | return geometry 209 | end 210 | 211 | function cdataset:setyprobs(yprops) 212 | myprob = yprops 213 | --print(myprob) 214 | end 215 | 216 | 217 | local iheight = geometry[2] 218 | local iwidth = geometry[1] 219 | local inputpatch = torch.zeros(1, iheight, iwidth) 220 | 221 | setmetatable(cdataset, {__index = function(self,index) 222 | local ex = dataset[index] 223 | local input = ex[1] 224 | local label = ex[2] 225 | local w = math.sqrt(input:nElement()) 226 | local uinput = input:unfold(1,input:nElement(),input:nElement()) 227 | local cinput = uinput:unfold(2,w,w) 228 | local h = cinput:size(2) 229 | local w = cinput:size(3) 230 | local x = math.floor((iwidth-w)/2)+1 231 | local y = math.floor((iheight-h)/2)+1 232 | inputpatch:narrow(3,x,w):narrow(2,y,h):copy(cinput) 233 | local example = {inputpatch, label} 234 | return example 235 | end}) 236 | return cdataset 237 | end 238 | -------------------------------------------------------------------------------- /code/rbm-grads.lua: -------------------------------------------------------------------------------- 1 | grads = {} 2 | -- Calculate generative weights 3 | -- tcwx is tcwx = torch.mm( x,rbm.W:t() ):add( rbm.c:t() ) 4 | function grads.generativestatistics(rbm,x,y,tcwx) 5 | local visx, visy, h0,ch_idx,drop, vkx, vkx_rnd, vky_rnd,hk,vky 6 | local stat = {} 7 | if rbm.toprbm then 8 | h0 = sigm( torch.add(tcwx, torch.mm(y,rbm.U:t() ) ) ) -- UP 9 | else 10 | h0 = sigm( tcwx ) 11 | end 12 | 13 | if rbm.dropout > 0 then 14 | h0:cmul(rbm.dropout_mask) 15 | drop = 1 16 | end 17 | 18 | -- Switch between CD and PCD 19 | if rbm.traintype == 'CD' then -- CD 20 | -- Use training data as start for negative statistics 21 | hid = rbm.hidsampler(h0,rbm.rand) -- sample the hidden derived from training state 22 | elseif rbm.traintype == 'PCD' then 23 | -- use pcd chains as start for negative statistics 24 | ch_idx = math.floor( (torch.rand(1) * rbm.npcdchains)[1]) +1 25 | 26 | local chx = rbm.chx[ch_idx]:resize(1,x:size(2)) 27 | local chy 28 | if rbm.toprbm then 29 | chy = rbm.chy[ch_idx]:resize(1,y:size(2)) 30 | else 31 | chy = {} 32 | end 33 | hid = rbm.hidsampler( rbm.up(rbm, chx, chy, drop), rbm.rand) 34 | elseif rbm.traintype == 'meanfield' then 35 | hid = rbm.hidsampler(h0,rbm.rand) 36 | end 37 | 38 | 39 | 40 | -- If CDn > 1 update chians n-1 times 41 | for i = 1, (rbm.cdn - 1) do 42 | visx = rbm.downx( rbm, hid ) 43 | if rbm.traintype ~= 'meanfield' then 44 | visx = rbm.visxsampler( visx, rbm.rand) 45 | end 46 | 47 | 48 | if rbm.toprbm then 49 | visy = rbm.downy( rbm, hid) 50 | if rbm.traintype ~= 'meanfield' then 51 | visy = samplevec( visy, rbm.rand) 52 | end 53 | else 54 | visy = {} 55 | end 56 | 57 | hid = 
rbm.up(rbm,visx, visy, drop) 58 | if rbm.traintype ~= 'meanfield' then 59 | hid = rbm.hidsampler( hid, rbm.rand) 60 | end 61 | 62 | 63 | end 64 | 65 | 66 | -- Down-Up dont sample last hiddens, because it introduces noise 67 | -- for meanfield we do not sample 68 | vkx = rbm.downx(rbm,hid) 69 | stat.vkx_unsampled = vkx 70 | if rbm.traintype ~= 'meanfield' then 71 | vkx = rbm.visxsampler(vkx,rbm.rand) 72 | end 73 | if rbm.toprbm then 74 | vky = rbm.downy(rbm,hid) 75 | if rbm.traintype ~= 'meanfield' then 76 | vky = samplevec( vky, rbm.rand) 77 | end 78 | else 79 | vky = {} 80 | end 81 | hk = rbm.up(rbm,vkx,vky,drop) 82 | 83 | -- If PCD: Update status of selected PCD chains 84 | if rbm.traintype == 'PCD' then 85 | rbm.chx[{ ch_idx,{} }] = vkx 86 | 87 | if rbm.toprbm then 88 | rbm.chy[{ ch_idx,{} }] = vky 89 | end 90 | end 91 | 92 | 93 | stat.h0 = h0 94 | --stat.h0_rnd = h0_rnd 95 | stat.hk = hk 96 | stat.vkx = vkx 97 | --stat.vkx_rnd = vkx_rnd 98 | 99 | if rbm.toprbm then 100 | stat.vky = vky 101 | --stat.vky_rnd = vky_rnd 102 | end 103 | return stat 104 | end 105 | 106 | function grads.generativegrads(rbm,x,y,stat) 107 | assert(isRowVec(x)) 108 | local grads = {} 109 | -- Calculate generative gradients 110 | grads.dW = torch.mm(stat.h0:t(),x) :add(-torch.mm(stat.hk:t(),stat.vkx)) 111 | grads.db = torch.add(x, -stat.vkx):t() 112 | grads.dc = torch.add(stat.h0, -stat.hk ):t() 113 | 114 | if rbm.toprbm then 115 | assert(isRowVec(y)) 116 | grads.dU = torch.mm(stat.h0:t(),y):add(-torch.mm(stat.hk:t(),stat.vky)) 117 | grads.dd = torch.add(y, -stat.vky):t() 118 | end 119 | return grads 120 | 121 | end 122 | 123 | -- Calculate discriminative weights 124 | -- tcwx is tcwx = torch.mm( x,rbm.W:t() ):add( rbm.c:t() ) 125 | function grads.discriminativegrads(rbm,x,y,tcwx) 126 | assert(isRowVec(x)) 127 | assert(isRowVec(y)) 128 | --print("kakkakak") 129 | local p_y_given_x, F, mask_expanded,F_sigm, F_sigm_prob,F_sigm_prob_sum,F_sigm_dy 130 | local dW,dU,dc,dd 131 | 132 | -- Switch between dropout version and non dropout version of pygivenx 133 | if rbm.dropout > 0 then 134 | p_y_given_x, F,mask_expanded = rbm.pygivenxdropout(rbm,x,tcwx) 135 | else 136 | p_y_given_x, F = rbm.pygivenx(rbm,x,tcwx) 137 | end 138 | 139 | F_sigm = sigm(F) 140 | 141 | -- Apply dropout mask 142 | if rbm.dropout > 0 then 143 | F_sigm:cmul(mask_expanded) 144 | end 145 | 146 | F_sigm_prob = torch.cmul( F_sigm, torch.mm( rbm.hidden_by_one,p_y_given_x ) ) 147 | F_sigm_prob_sum = F_sigm_prob:sum(2) 148 | F_sigm_dy = torch.mm(F_sigm, y:t()) 149 | 150 | 151 | dW = torch.add( torch.mm(F_sigm_dy, x), -torch.mm(F_sigm_prob_sum,x) ) 152 | dU = torch.add( -F_sigm_prob, torch.cmul(F_sigm, torch.mm( torch.ones(F_sigm_prob:size(1),1),y ) ) ) 153 | dc = torch.add(-F_sigm_prob_sum, F_sigm_dy) 154 | dd = torch.add(y, -p_y_given_x):t() 155 | 156 | local grads = {} 157 | grads.dW = dW 158 | grads.dU = dU 159 | grads.dc = dc 160 | grads.dd = dd 161 | return grads,p_y_given_x 162 | 163 | end 164 | 165 | 166 | function grads.pygivenx(rbm,x,tcwx_pre_calc) 167 | assert(isRowVec(x)) 168 | 169 | local tcwx,F,pyx 170 | tcwx_pre_calc = tcwx_pre_calc or torch.mm( x,rbm.W:t() ):add( rbm.c:t() ) 171 | assert(isRowVec(tcwx_pre_calc)) -- 1xn_hidden 172 | 173 | F = torch.add( rbm.U, torch.mm(tcwx_pre_calc:t(), rbm.one_by_classes) ) 174 | pyx = softplus(F):sum(1) -- p(y|x) logprob 175 | pyx:add(-torch.max(pyx)) -- subtract max for numerical stability 176 | pyx:exp() -- convert to real domain 177 | pyx:mul( ( 1/pyx:sum() )) -- normalize probabilities 178 | 179 | 
assert(pyx:size(1) == 1 and pyx:size(2) == rbm.n_classes) 180 | return pyx,F 181 | end 182 | 183 | function grads.pygivenxdropout(rbm,x,tcwx_pre_calc) 184 | -- Dropout version of pygivenx 185 | assert(isRowVec(x)) 186 | local tcwx,F,F_softplus,pyx, mask_expanded 187 | mask_expanded = torch.mm(rbm.dropout_mask:t(), rbm.one_by_classes) 188 | tcwx_pre_calc = tcwx_pre_calc or torch.mm( x,rbm.W:t() ):add( rbm.c:t() ) 189 | assert(isRowVec(tcwx_pre_calc)) -- 1xn_hidden 190 | 191 | F = torch.add( rbm.U, torch.mm(tcwx_pre_calc:t(), rbm.one_by_classes) ) 192 | F:cmul(mask_expanded) -- Apply dropout mask 193 | 194 | F_softplus = softplus(F) 195 | F_softplus:cmul(mask_expanded) -- Apply dropout mask 196 | 197 | pyx = F_softplus:sum(1) -- p(y|x) logprob 198 | pyx:add(-torch.max(pyx)) -- subtract max for numerical stability 199 | pyx:exp() -- convert to real domain 200 | pyx:mul( ( 1/pyx:sum() )) -- normalize probabilities 201 | 202 | assert(pyx:size(1) == 1 and pyx:size(2) == rbm.n_classes) 203 | return pyx,F,mask_expanded 204 | end -------------------------------------------------------------------------------- /code/rbm-helpers.lua: -------------------------------------------------------------------------------- 1 | -- small math functions and rbmup/down etc 2 | function sigm(x) 3 | local o = torch.exp(-x):add(1):pow(-1) 4 | return(o) 5 | end 6 | 7 | function normalizeexprowvec(x) 8 | -- Calculate exp(x) / sum(exp(x)) in numerically stable way 9 | -- x is a row vector 10 | exp_x = torch.exp(x - torch.max(x)) 11 | normalizer = torch.mm(exp_x:sum(2), torch.ones(1,x:size(2))) 12 | return exp_x:cdiv( normalizer ) 13 | end 14 | 15 | function softplus(x) 16 | local o = torch.exp(x):add(1):log() 17 | --local o = nn.SoftPlus():forward(x) 18 | return(o) 19 | end 20 | 21 | 22 | function rbmup(rbm,x,y,drop) 23 | -- drop == 1 applies dropout to p(h|v) 24 | assert(isRowVec(x)) 25 | assert(x:size(1)*x:size(2) == rbm.n_visible) 26 | 27 | local act_hid 28 | act_hid = torch.mm(x,rbm.W:t()):add(rbm.c:t()) -- x * rbm.W' + rbm.c' 29 | 30 | if rbm.toprbm then 31 | assert(isRowVec(y) and y:size(2) == rbm.n_classes) 32 | act_hid:add( torch.mm(y,rbm.U:t()) ) 33 | end 34 | act_hid = sigm(act_hid) 35 | 36 | if drop == 1 then 37 | act_hid:cmul(rbm.dropout_mask) 38 | end 39 | 40 | rbm.act_up = act_hid 41 | return act_hid 42 | 43 | end 44 | 45 | 46 | 47 | function rbmdownx(rbm,act_hid) 48 | -- bernoulli units 49 | assert(isRowVec(act_hid)) 50 | 51 | local act_vis_x 52 | --act_vis_x = -- hid_act * rbm.W + rbm.b' 53 | act_vis_x = sigm(torch.mm(act_hid,rbm.W):add(rbm.b:t()) ); 54 | return act_vis_x 55 | end 56 | 57 | 58 | function rbmdowny(rbm,act_hid) 59 | local act_vis_y 60 | act_vis_y = torch.mm( act_hid,rbm.U ):add( rbm.d:t() ) 61 | 62 | act_vis_y = normalizeexprowvec(act_vis_y) 63 | return act_vis_y 64 | end 65 | 66 | -- ##########PRETRAIN FUNCTIONS ################# 67 | -- ## modified for pretraining see Deep boltzmann machines salakhutdinov 2009 sec 3.1 68 | -- ## basically doubles the input 69 | function rbmuppretrain(rbm,x,y,drop) 70 | -- drop == 1 applies dropout to p(h|v) 71 | -- in the code provided they do not double the biases 72 | assert(isRowVec(x)) 73 | assert(x:size(1)*x:size(2) == rbm.n_visible) 74 | 75 | local act_hid 76 | act_hid = torch.mm(x,rbm.W:t()):mul(2) --MODIFIED 77 | act_hid:add(rbm.c:t()) -- x * rbm.W' + rbm.c' 78 | 79 | if rbm.toprbm then 80 | --assert(isRowVec(y) and y:size(2) == rbm.n_classes) 81 | act_hid:add( torch.mm(y,rbm.U:t()):mul(2) ) --MODIFIED 82 | end 83 | 84 | act_hid = sigm(act_hid) 
85 | 86 | if drop == 1 then 87 | act_hid:cmul(rbm.dropout_mask) 88 | end 89 | 90 | rbm.act_up = act_hid 91 | return act_hid 92 | 93 | end 94 | 95 | function rbmdownxpretrain(rbm,act_hid) 96 | -- bernoulli units 97 | assert(isRowVec(act_hid)) 98 | 99 | local act_vis_x 100 | --act_vis_x = -- hid_act * rbm.W + rbm.b' 101 | act_vis_x = torch.mm(act_hid,rbm.W):mul(2) 102 | act_vis_x:add(rbm.b:t()) 103 | --act_vis_x:mul(2) -- MODIFICATION 104 | act_vis_x = sigm(act_vis_x ); 105 | return act_vis_x 106 | end 107 | 108 | 109 | function rbmdownypretrain(rbm,act_hid) 110 | local act_vis_y 111 | act_vis_y = torch.mm( act_hid,rbm.U ):mul(2) 112 | act_vis_y:add( rbm.d:t() ) 113 | 114 | act_vis_y = normalizeexprowvec(act_vis_y) 115 | return act_vis_y 116 | end 117 | --###### END PRETRAIN FUNCTIONS 118 | 119 | 120 | 121 | function samplevec(x,ran) 122 | assert(isRowVec(x)) 123 | local r,x_c,larger,sample 124 | r = ran(1,1):expand(x:size()) 125 | x_c = torch.cumsum(x,2) 126 | larger = torch.ge(x_c,r) 127 | sample = torch.eq(torch.cumsum(larger,2),1):typeAs(x) 128 | return sample 129 | end 130 | 131 | function bernoullisampler(dat,ran) 132 | local ret = torch.gt(dat, ran(1,dat:size(2))):typeAs(dat) 133 | return(ret) 134 | end 135 | 136 | function gausssampler(dat,ran) 137 | -- returns ~N(dat,1) 138 | return torch.randn(dat:size()):add(dat) 139 | end 140 | 141 | function classprobs(rbm,x) 142 | 143 | local probs,x_i,p_i 144 | 145 | -- Iter over examples and calculate the class probs 146 | probs = torch.Tensor(x:size(1),rbm.n_classes) 147 | for i = 1, x:size(1) do 148 | x_i =x[i]:resize(1,rbm.n_visible) 149 | 150 | p_i = rbm.pygivenx(rbm,x_i) 151 | probs[{i,{}}] = p_i 152 | end 153 | return(probs) 154 | end 155 | 156 | function predict(rbm,x) 157 | --print(x) 158 | --assert(x:dim() == 3) 159 | --assert(x:size(2)*x:size(3) == rbm.n_visible) 160 | 161 | local probs,_,pred 162 | probs = classprobs(rbm,x) 163 | --print(probs) 164 | -- probs is cases X n_classes 165 | assert(probs:size(1) == x:size(1) and probs:size(2) == rbm.n_classes) 166 | 167 | 168 | vec,pred=torch.max(probs,2) 169 | local n_samples = x:size(1) 170 | local labels_vec = torch.zeros(1,rbm.n_classes):float() 171 | for i =1,n_samples do 172 | pred_idx = pred[{i,1}] 173 | labels_vec[{1, pred_idx }] = labels_vec[{1, pred_idx }] + 1 174 | end 175 | 176 | pred = pred:view(-1) 177 | return pred:typeAs(x),probs 178 | end 179 | 180 | function geterror(rbm,data,errorfunction) 181 | rbm.conf:zero() 182 | local probs = torch.Tensor(data:size(),rbm.n_classes) 183 | for i = 1,data:size() do 184 | local sample = data[i] 185 | local x = sample[1]:view(1,-1) 186 | local _,y_index=torch.max(sample[2],1) 187 | 188 | local x_pred,x_probs 189 | x_pred,x_probs = predict(rbm,x) 190 | probs[{ i,{} }] = x_probs 191 | rbm.conf:add(x_pred[1], y_index[1]) 192 | end 193 | 194 | local err 195 | if errorfunction then 196 | err = errorfunction(rbm.conf) 197 | else 198 | err = rbm.errorfunction(rbm.conf) 199 | end 200 | return err,probs 201 | end 202 | 203 | function rbmuppass(rbm,data,returnLabels) 204 | -- takes a rbm and calculates the activation of hidden units for all samples 205 | -- if returnLabels is non nil the function also returns a array of the labels for the 206 | -- correpsonding class. This functionality can be used to construct datasets 207 | -- not for conv rbms? 
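-- Example (added sketch of the stacking usage described above, mirroring
-- trainstackconvtorbm in rbm-util.lua; rbm1, opts2 and the other variable
-- names are hypothetical):
--
--   local hid, labels = rbmuppass(rbm1, trainData, true)
--   local hidData = datatensor.createDataset(hid, labels,
--       trainData:classnames(), {1, hid:size(3)})
--   local rbm2 = rbmsetup(opts2, hidData)
--   rbm2 = rbmtrain(rbm2, hidData, valData)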
208 | 209 | local sample,x,y,hid,up_size,n_samples,labels 210 | n_samples = data:size() 211 | up_size = rbm.act_up:size(2) 212 | hid = torch.Tensor(n_samples,1,up_size) 213 | -- also return array of labels 214 | if returnLabels then 215 | local class_names = trainData:classnames() 216 | labels = torch.Tensor(n_samples,#class_names) 217 | end 218 | 219 | for i = 1,n_samples do 220 | sample = data:next() 221 | x = sample[1]:view(1,-1) 222 | if rbm.toprbm then 223 | y = sample[2]:view(1,-1) 224 | else 225 | y = {} 226 | end 227 | 228 | rbm.up(rbm,x,y,false) -- updates rbm.act_up in rbm false = no dropout 229 | act_hid = rbm.act_up:clone() 230 | hid[{ i,{},{} }] = act_hid:view(1,-1) 231 | 232 | if returnLabels then 233 | labels[{ i,{} }] = sample[2]:view(1,-1) 234 | end 235 | 236 | end 237 | return hid,labels 238 | end 239 | 240 | 241 | function oneOfK(nclasses,labels) 242 | -- If labels are numeric encodes the function returns the 243 | -- lables encoded as one-of-K 244 | local n_classes, n_samples, labels_vec,i 245 | n_samples = labels:size(1) 246 | labels_vec = torch.zeros(n_samples,nclasses) 247 | for i =1,n_samples do 248 | labels_vec[{i, labels[i] }] = 1 249 | end 250 | 251 | return labels_vec 252 | 253 | end -------------------------------------------------------------------------------- /code/rbm-regularization.lua: -------------------------------------------------------------------------------- 1 | regularization = {} 2 | 3 | function regularization.applyregularization(rbm) 4 | if rbm.sparsity > 0 then 5 | -- rbm.db:add(-rbm.sparsity) -- db is bias of visible layer 6 | rbm.dc:add(-rbm.sparsity) -- dc is bias of hidden layer 7 | -- rbm.dd:add(-rbm.sparsity) -- dd is bias of "label" layer 8 | end 9 | 10 | if rbm.L1 > 0 then 11 | rbm.dW:add( -torch.sign(rbm.dW):mul(rbm.L1) ) 12 | 13 | if rbm.toprbm then 14 | rbm.dU:add( -torch.sign(rbm.dU):mul(rbm.L1) ) 15 | end 16 | end 17 | 18 | if rbm.L2 > 0 then 19 | rbm.dW:add( -torch.mul(rbm.dW,rbm.L2 ) ) 20 | 21 | if rbm.toprbm then 22 | rbm.dU:add( -torch.mul(rbm.dU,rbm.L2 ) ) 23 | end 24 | end 25 | end 26 | 27 | 28 | function regularization.dropout(rbm) 29 | -- Create dropout mask for hidden units 30 | if rbm.dropout > 0 then 31 | rbm.dropout_mask = torch.lt( torch.rand(1,rbm.n_hidden),rbm.dropout ):typeAs(rbm.W) 32 | end 33 | end -------------------------------------------------------------------------------- /code/rbm-util.lua: -------------------------------------------------------------------------------- 1 | -- setup, printing functions, saving functions, 2 | 3 | 4 | function rbmconvsetup(settings,train,convopts) 5 | -- handles the ugly setup for a convRBM 6 | local sizes = conv.calcconvsizes(settings.filter_size,settings.n_filters, 7 | settings.n_classes, 8 | settings.input_size, 9 | settings.pool_size, 10 | train) 11 | local opts = {} 12 | 13 | conv.setupsettings(opts,sizes) 14 | 15 | 16 | 17 | 18 | 19 | opts.toprbm = settings.toprbm 20 | -- merge useropts and functions opts. 
-- convopts entries will be
21 |     -- overwritten with the values computed in opts
22 |     if convopts ~= nil then
23 |         for k,v in pairs(opts) do
24 |             if convopts[k] ~= nil then
25 |                 print("Overwriting setting: ", k, convopts[k], "with", v)
26 |             end
27 |             convopts[k] = v
28 |         end
29 |     else
30 |         convopts = opts
31 |     end
32 |
33 |     local rbm = rbmsetup(convopts,train)
34 |     local debug = conv.setupfunctions(rbm,sizes,settings.vistype,settings.usemaxpool)
35 |
36 |     return rbm,opts,debug
37 |
38 | end
39 |
40 | function trainstackconvtorbm(convsettings,convopts,toprbmopts,train,val)
41 |     -- convsettings.usemaxpool = true
42 |     -- convsettings.vistype = 'binary'
43 |     -- convsettings.filter_size = filter_size
44 |     -- convsettings.n_filters = n_filters
45 |     -- convsettings.n_classes = n_classes
46 |     -- convsettings.input_size = input_size
47 |     -- convsettings.pool_size = pool_size
48 |     -- convopts: any settings that apply to a normal RBM. Settings not used in
49 |     -- conv or with inferred values will be overwritten
50 |     -- toprbmopts: settings for the top RBM
51 |     -- (see the invocation sketch after this function)
52 |     print("WRITE SOME test where I assert that the correct values are overwritten etc")
53 |     print(" and check that the bottom convrbm is a generative model without labels")
54 |     print("Skip the ugly convsettings/convopts and just let the user specify an opts table with the necessary values")
55 |     print("Maybe write a print function for conv rbms")
56 |     convsettings.toprbm = false
57 |     local convrbm,convopts_functions,convdebug = rbmconvsetup(convsettings,train,convopts)
58 |
59 |
60 |     convrbm = rbmtrain(convrbm,train,val)
61 |
62 |     local train2 = {}
63 |     train2.labels = train.labels:clone()
64 |     train2.data = rbmuppass(convrbm,train)
65 |
66 |     local val2 = {}
67 |     val2.labels = val.labels:clone()
68 |     val2.data = rbmuppass(convrbm,val)
69 |
70 |     local toprbm = rbmsetup(toprbmopts,train2)
71 |     toprbm = rbmtrain(toprbm,train2,val2)
72 |
73 |     return convrbm,toprbm
74 | end
75 |
76 | function trainstackrbmrbm()
77 |
78 | end
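-- A hypothetical invocation sketch (added; not part of the original file)
-- for trainstackconvtorbm, using the convsettings fields documented above.
-- All concrete values are illustrative assumptions only; with input_size=32
-- and filter_size=5 the hidden maps are 28x28, so pool_size=2 divides evenly
-- as required by conv.calcconvsizes. Inert unless called.
local function demo_trainstackconvtorbm(train, val)
    local convsettings = {usemaxpool = true, vistype = 'binary',
                          filter_size = 5, n_filters = 8, n_classes = 10,
                          input_size = 32, pool_size = 2}
    local convopts = {numepochs = 5}
    local toprbmopts = {n_hidden = 500, numepochs = 5}
    return trainstackconvtorbm(convsettings, convopts, toprbmopts, train, val)
end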
" + SEMISUP" end 94 | 95 | print(string.format("Training type : %s",ttype)) 96 | print(string.format("Pretraining : %s",rbm.pretrain)) 97 | print(string.format("Top RBM : %s",tostring(rbm.toprbm))) 98 | print(string.format("Number of visible : %i",rbm.n_visible)) 99 | print(string.format("Number of hidden : %i",rbm.n_hidden)) 100 | if rbm.toprbm then print(string.format("Number of classes : %i",rbm.n_classes)) end 101 | print("") 102 | print(string.format("Number of epocs : %i",rbm.numepochs)) 103 | print(string.format("Current epoc : %i",rbm.currentepoch)) 104 | print(string.format("Learning rate : %f",rbm.learningrate)) 105 | print(string.format("Momentum : %f",rbm.momentum)) 106 | print(string.format("alpha : %f",rbm.alpha)) 107 | print(string.format("beta : %f",rbm.beta)) 108 | print(string.format("batchsize : %i",rbm.batchsize)) 109 | print(string.format("Temp file : %s",rbm.tempfile)) 110 | print("") 111 | 112 | 113 | print("TRAINING TYPE") 114 | print(string.format("Type : %s",rbm.traintype)) 115 | print(string.format("Gibbs steps : %i",rbm.cdn)) 116 | if rbm.traintype == 'PCD' then print(string.format("Number of PCD chains : %i",rbm.npcdchains)) end 117 | 118 | print("") 119 | print("REGULARIZATON") 120 | print(string.format("Patience : %i",rbm.patience)) 121 | print(string.format("Sparisty : %f",rbm.sparsity)) 122 | print(string.format("L1 : %f",rbm.L1)) 123 | print(string.format("L2 : %f",rbm.L2)) 124 | print(string.format("DropOut : %f",rbm.dropout)) 125 | print("------------------------------------------------------------") 126 | 127 | end 128 | 129 | function initcrbm(m,n,inittype,std) 130 | -- Creates initial weights. 131 | -- If inittype is 'crbm' then init weights from uniform distribution 132 | -- initilize weigts from uniform distribution. 
133 |     -- Learning Algorithms for the Classification Restricted Boltzmann
134 |     -- Machine
135 |     -- If inittype is 'gauss' init from N(0,(10^std)^2); std defaults to -2
136 |     -- If inittype is not specified use 'crbm'
137 |     local weights
138 |     if inittype == nil then
139 |         inittype = 'crbm'
140 |     end
141 |
142 |     if std == nil then
143 |         std = -2
144 |     end
145 |
146 |     if inittype == 'crbm' then
147 |         local M,interval_max, interval_min
148 |         M = math.max(m,n);
149 |         interval_max = math.pow(M,-0.5);
150 |         interval_min = -interval_max;
151 |         weights = torch.rand(m,n):mul(interval_max-interval_min):add(interval_min)
152 |     elseif inittype == 'gauss' then
153 |         weights = torch.randn(m,n) * math.pow(10,std)
154 |
155 |     else
156 |         assert(false, 'unknown inittype')
157 |     end
158 |     return weights
159 | end
160 |
161 |
162 | function rbmsetup(opts,train,semisup)
163 |     local rbm = {}
164 |
165 |     rbm.progress = opts.progress or 1
166 |
167 |     --assert(train.data:dim() == 3)
168 |     -- if semisup then
169 |     --     assert(semisup.data:dim() == 3)
170 |     -- end
171 |
172 |     -- the `or` idiom does not work for booleans
173 |     if opts.toprbm ~= nil then
174 |         rbm.toprbm = opts.toprbm
175 |     else
176 |         rbm.toprbm = true
177 |     end
178 |     local n_visible,n_samples,n_classes,n_input,n_hidden
179 |     n_samples = train:size()
180 |     local geometry = train:geometry()
181 |     n_visible = geometry[1]*geometry[2] -- channels * channelwidth
182 |     n_hidden = opts.n_hidden or assert(false, 'opts.n_hidden must be set')
183 |
184 |
185 |
186 |
187 |     rbm.boost = opts.boost or 'none'
188 |     if opts.boost == 'diff' then
189 |         rbm.yprobs = torch.Tensor(train:size(),#train:classnames())
190 |     end
191 |
192 |     rbm.batchsize = opts.batchsize or 1
193 |
194 |     rbm.W = opts.W or initcrbm(n_hidden,n_visible)
195 |     rbm.b = opts.b or torch.zeros(n_visible,1)
196 |     rbm.c = opts.c or torch.zeros(n_hidden,1)
197 |
198 |     rbm.vW = torch.zeros(rbm.W:size())
199 |     rbm.vb = torch.zeros(rbm.b:size())
200 |     rbm.vc = torch.zeros(rbm.c:size())
201 |
202 |     rbm.dW = torch.Tensor(rbm.W:size()):zero()
203 |     rbm.db = torch.Tensor(rbm.b:size()):zero()
204 |     rbm.dc = torch.Tensor(rbm.c:size()):zero()
205 |     rbm.rand = function(m,n) return torch.rand(m,n) end
206 |     rbm.n_visible = n_visible
207 |     rbm.n_samples = n_samples
208 |     rbm.n_hidden = n_hidden
209 |     rbm.errorfunction = opts.errorfunction or function(conf) return 1-conf:accuracy() end
210 |
211 |     rbm.hidden_by_one = torch.ones(rbm.W:size(1),1)
212 |
213 |     rbm.numepochs = opts.numepochs or 1
214 |     rbm.currentepoch = 1
215 |     rbm.learningrate = opts.learningrate or 0.05
216 |     rbm.momentum = opts.momentum or 0
217 |     rbm.traintype = opts.traintype or 'CD' -- CD, PCD or meanfield
218 |     rbm.cdn = opts.cdn or 1
219 |     rbm.npcdchains = opts.npcdchains or 1
220 |
221 |     -- OBJECTIVE
222 |     rbm.alpha = opts.alpha or 1
223 |     rbm.beta = opts.beta or 0
224 |
225 |     -- REGULARIZATION
226 |     rbm.dropout = opts.dropout or 0
227 |     rbm.L1 = opts.L1 or 0
228 |     rbm.L2 = opts.L2 or 0
229 |     rbm.sparsity = opts.sparsity or 0
230 |     rbm.patience = opts.patience or 15
231 |
232 |
233 |     rbm.pretrain = opts.pretrain or 'none'
234 |
235 |     -- Set up and down functions + generative statistics functions
236 |     -- see Deep Boltzmann Machines, Salakhutdinov 2009, sec. 3.1
237 |     -- pretraining also modifies downy
238 |     if rbm.pretrain == 'none' then
239 |         rbm.up = opts.up or rbmup
240 |         rbm.downx = opts.downx or rbmdownx
241 |         if rbm.toprbm then
242 |             rbm.downy = opts.downy or rbmdowny
243 |         end
244 |     elseif rbm.pretrain == 'top' then
245 |         -- we double the down weights
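        -- (added note) cf. rbm-helpers.lua: rbmdownxpretrain and
        -- rbmdownypretrain multiply the down activations by 2, i.e.
        -- sigm(2*h*W + b') and softmax(2*h*U + d'), to compensate for the
        -- input each layer loses when the greedily trained stack is later
        -- composed into a DBM (Salakhutdinov 2009, sec. 3.1); the up pass
        -- stays unmodified here.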
246 |         rbm.up = rbmup
247 |         rbm.downx = rbmdownxpretrain
248 |         if rbm.toprbm then
249 |             rbm.downy = rbmdownypretrain
250 |         end
251 |     elseif rbm.pretrain == 'bottom' then
252 |         -- double up weights
253 |         rbm.up = rbmuppretrain
254 |         rbm.downx = rbmdownx
255 |         if rbm.toprbm then
256 |             rbm.downy = rbmdowny
257 |         end
258 |     else
259 |         print('unknown pretrain option')
260 |         error()
261 |     end
262 |
263 |
264 |     rbm.visxsampler = opts.visxsampler or bernoullisampler
265 |     rbm.hidsampler = opts.hidsampler or bernoullisampler
266 |
267 |
268 |     rbm.generativestatistics = opts.generativestatistics or grads.generativestatistics
269 |     rbm.generativegrads = opts.generativegrads or grads.generativegrads
270 |
271 |
272 |     rbm.tempfile = opts.tempfile or "temp_rbm.asc"
273 |     rbm.finalfile = opts.finalfile or "final_rbm.asc"
274 |     rbm.isgpu = opts.isgpu or 0
275 |     rbm.precalctcwx = opts.precalctcwx or 1
276 |
277 |     rbm.err_recon_train = torch.Tensor(rbm.numepochs):fill(-1)
278 |     rbm.err_train = torch.Tensor(rbm.numepochs):fill(-1)
279 |     rbm.err_val = torch.Tensor(rbm.numepochs):fill(-1)
280 |     rbm.cur_err = torch.zeros(1)
281 |
282 |     if rbm.traintype == 'PCD' then -- init PCD chains
283 |         rbm.chx = torch.Tensor(rbm.npcdchains,n_visible)
284 |         if rbm.toprbm then
285 |             rbm.chy = torch.Tensor(rbm.npcdchains,#train:classnames())
286 |         end
287 |
288 |         for i = 1,rbm.npcdchains do
289 |             local idx = math.floor(torch.uniform(1,train:size()+0.999999999))
290 |             rbm.chx[{ i,{} }] = train[idx][1]:clone():view(1,-1)
291 |             if rbm.toprbm then
292 |                 rbm.chy[{ i,{} }] = train[idx][2]:clone():view(1,-1)
293 |             end
294 |
295 |         end
296 |
297 |
298 |         if rbm.beta > 0 then
299 |             local kk_semisup = torch.randperm(semisup.x:size(1))
300 |             kk_semisup = kk_semisup[{ {1, rbm.npcdchains} }]
301 |             rbm.chx_semisup = semisup.x[{kk_semisup,{} }]:clone()
302 |             if rbm.toprbm then
303 |                 rbm.chy_semisup = semisup.y[{kk_semisup,{} }]:clone()
304 |             end
305 |         end
306 |     elseif rbm.traintype == 'CD' then
307 |         -- no chains needed for CD
308 |     elseif rbm.traintype == 'meanfield' then
309 |         -- no chains needed for meanfield
310 |     else
311 |         print("unknown traintype")
312 |         error()
313 |     end
314 |
315 |
316 |     if rbm.toprbm then
317 |         n_classes = #train:classnames()
318 |         rbm.n_classes = n_classes
319 |         rbm.conf = ConfusionMatrix(train:classnames())
320 |         rbm.d = opts.d or torch.zeros(n_classes,1)
321 |         rbm.U = opts.U or initcrbm(n_hidden,n_classes)
322 |         rbm.vU = torch.zeros(rbm.U:size())
323 |         rbm.vd = torch.zeros(rbm.d:size())
324 |         rbm.dU = torch.Tensor(rbm.U:size()):zero()
325 |         rbm.dd = torch.Tensor(rbm.d:size()):zero()
326 |         rbm.one_by_classes = torch.ones(1,rbm.U:size(2))
327 |         rbm.discriminativegrads = opts.discriminativegrads or grads.discriminativegrads
328 |         rbm.pygivenx = opts.pygivenx or grads.pygivenx
329 |         rbm.pygivenxdropout = opts.pygivenxdropout or grads.pygivenxdropout
330 |     end
331 |
332 |     if rbm.toprbm == false then
333 |         assert(rbm.alpha == 1) -- discriminative training makes no sense for non-top RBMs
334 |     end
335 |
336 |     return(rbm)
337 | end
338 |
339 | function checkequality(t1,t2,prec,pr)
340 |     local not_same_dim = not t1:isSameSizeAs(t2)
341 |
342 |     if pr then
343 |         print(t1)
344 |         print(t2)
345 |     end
346 |     local prec = prec or -4
347 |
348 |     local diff = t1 - t2
349 |     local err = diff:abs():max()
350 |     local numeric_err = (err > math.pow(10,prec) )
351 |     if numeric_err then
352 |         print('ASSERT: Numeric Error')
353 |     elseif not_same_dim then
354 |         print('ASSERT: Dimension Error')
355 |     else
356 |         print('Assert: Passed')
357 |     end
358 |     return (not not_same_dim) and (not numeric_err)
359 |
360 | end
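-- A minimal usage sketch (added; not in the original file): checkequality
-- treats two tensors as equal when their max absolute difference is below
-- 10^prec (default 10^-4). The helper below is hypothetical and inert
-- unless called.
local function demo_checkequality()
    local a = torch.rand(2,3)
    local b = torch.add(a, 1e-6)          -- shift every entry by 1e-6
    assert(checkequality(a, b))           -- passes: 1e-6 is below 10^-4
    assert(not checkequality(a, b, -7))   -- fails at the stricter 10^-7
end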
361 | 362 | -- Stupid function to save an RBM in CSV... 363 | -- Use loadrbm.m to load the RBM in matlab 364 | function writerbmtocsv(rbm,folder) 365 | folder = folder or '' 366 | require('csvigo') 367 | function createtable(weight) 368 | local weighttable = {} 369 | for i = 1,weight:size(1) do --rows 370 | local row = {} 371 | for j = 1,weight:size(2) do -- columns 372 | row[j] = weight[{i,j}] 373 | end 374 | weighttable[i] = row 375 | end 376 | 377 | return weighttable 378 | 379 | end 380 | 381 | function readerr(err) 382 | e = {} 383 | for i = 1, err:size(1) do 384 | if err[i] ~= -1 then 385 | e[i] = err[i] 386 | end 387 | end 388 | ret = {} 389 | ret[1] = e 390 | return(ret) 391 | end 392 | csvigo.save{data=createtable(rbm.stat_gen.hk), path=paths.concat(folder,'rbmhk.csv'),verbose = false} 393 | csvigo.save{data=createtable(rbm.stat_gen.vkx), path=paths.concat(folder,'rbmvkx.csv'),verbose = false} 394 | csvigo.save{data=createtable(rbm.stat_gen.vky), path=paths.concat(folder,'rbmvky.csv'),verbose = false} 395 | csvigo.save{data=createtable(rbm.stat_gen.h0), path=paths.concat(folder,'rbmh0.csv'),verbose = false} 396 | csvigo.save{data=createtable(rbm.W), path=paths.concat(folder,'rbmW.csv'),verbose = false} 397 | csvigo.save{data=createtable(rbm.dU), path=paths.concat(folder,'rbmdU.csv'),verbose = false} 398 | csvigo.save{data=createtable(rbm.dW), path=paths.concat(folder,'rbmdW.csv'),verbose = false} 399 | csvigo.save{data=createtable(rbm.U), path=paths.concat(folder,'rbmU.csv'),verbose = false} 400 | csvigo.save{data=createtable(rbm.b), path=paths.concat(folder,'rbmb.csv'),verbose = false} 401 | csvigo.save{data=createtable(rbm.c), path=paths.concat(folder,'rbmc.csv'),verbose = false} 402 | csvigo.save{data=createtable(rbm.d), path=paths.concat(folder,'rbmd.csv'),verbose = false} 403 | csvigo.save{data=readerr(rbm.err_val), path=paths.concat(folder,'rbmerr_val.csv'),verbose = false} 404 | csvigo.save{data=readerr(rbm.err_train), path=paths.concat(folder,'rbmerr_train.csv'),verbose = false} 405 | csvigo.save{data=readerr(rbm.err_recon_train), path=paths.concat(folder,'rbmerr_recon_train.csv'),verbose = false} 406 | end 407 | 408 | 409 | function writetensor(tensor,filename) 410 | -- writes tensor to csv file 411 | require('csvigo') 412 | function createtable(weight) 413 | local weighttable = {} 414 | for i = 1,weight:size(1) do --rows 415 | local row = {} 416 | for j = 1,weight:size(2) do -- columns 417 | row[j] = weight[{i,j}] 418 | end 419 | weighttable[i] = row 420 | end 421 | 422 | return weighttable 423 | 424 | end 425 | 426 | local tab = createtable(tensor) 427 | csvigo.save{data=tab, path=filename,verbose = false} 428 | end 429 | 430 | 431 | 432 | function isRowVec(x) 433 | -- checks if x is a vector is 1xn 434 | if x:dim() == 2 and x:size(1) == 1 then 435 | res = true 436 | else 437 | print ("isRowVector vec size: ",x:size() ) 438 | res = false 439 | end 440 | return res 441 | end 442 | 443 | function isVec(x) 444 | -- checks if x is a vector 445 | if x:dim() == 1 then 446 | res = true 447 | else 448 | print ("isVec size: ",x:size() ) 449 | res = false 450 | end 451 | return res 452 | end 453 | 454 | function isMatrix(x) 455 | -- checks if x is a vector is mxn 456 | if x:dim() == 2 then 457 | res = true 458 | else 459 | print ("IsMatrix vec size: ",x:size() ) 460 | res = falses 461 | end 462 | return res 463 | end 464 | 465 | -------------------------------------------------------------------------------- /code/rbm-visualisation.lua: 
-------------------------------------------------------------------------------- 1 | -- visualisation function for the rbm training process. 2 | 3 | require('image') 4 | 5 | function create_weight_image(rbm, image_dimensions, filename) 6 | -- Create an image from the weights of the current rbm. 7 | 8 | -- print(rbm) 9 | -- print(image_dimensions) 10 | w = image_dimensions[1] 11 | h = image_dimensions[2] 12 | assert(rbm.W:size(2) == w*h) 13 | 14 | -- TODO: This may be the same as rbm.n_hidden 15 | n_filters = rbm.W:size(1) 16 | n_channels = 1 17 | 18 | pad = 1 19 | nrows = math.ceil(math.sqrt(n_filters)) 20 | 21 | local weight = rbm.W:view(n_filters, n_channels, w, h) 22 | local filters = image.toDisplayTensor{input=weight, padding=pad, 23 | nrow=nrows, scaleeach=true, symmetric=false} 24 | 25 | image.save(filename, filters) 26 | 27 | end -------------------------------------------------------------------------------- /code/rbm.lua: -------------------------------------------------------------------------------- 1 | require 'paths' 2 | require('nn') 3 | require('pl') 4 | require('torch') 5 | require 'sys' 6 | require 'xlua' 7 | require(codeFolder.. 'rbm-util') 8 | require(codeFolder.. 'rbm-helpers') 9 | require(codeFolder.. 'rbm-regularization') 10 | require(codeFolder..'rbm-grads') 11 | require(codeFolder..'MyConfusionMatrix') 12 | 13 | function rbmtrain(rbm,train,val,semisup) 14 | local x_train,y_train,x_val,y_val,x_semisup 15 | if semisup then 16 | print("semisupervised not implemented") 17 | error() 18 | end 19 | 20 | -- train RBM 21 | local x_tr,y_tr,x_semi, total_time, epoch_time,acc_train, best_val_err,patience, best_rbm,best 22 | printrbm(rbm,train,val,semisup) 23 | 24 | patience = rbm.patience 25 | total_time = os.time() 26 | 27 | -- extend error tensors if resuming training 28 | if rbm.err_train:size(1) < rbm.numepochs then 29 | best_val_err = rbm.err_val[rbm.currentepoch] 30 | rbm.err_recon_train = extendTensor(rbm, rbm.err_recon_train,rbm.numepochs) 31 | rbm.err_train = extendTensor(rbm,rbm.err_train,rbm.numepochs) 32 | rbm.err_val = extendTensor(rbm,rbm.err_val,rbm.numepochs) 33 | best_rbm = cprbm(rbm) 34 | end 35 | 36 | best_val_err = best_val_err or 1/0 37 | print("Best Val err",best_val_err) 38 | --print(y_train) 39 | for epoch = rbm.currentepoch, rbm.numepochs do 40 | --print("epcoh",epoch) 41 | epoch_time = os.time() 42 | rbm.cur_err = torch.zeros(1) 43 | rbm.currentepoch = epoch 44 | 45 | for i = 1, train:size() do -- iter over samples 46 | --x_tr,y_tr,x_semi = getsamples(rbm,x_train,y_train,x_semisup,i) 47 | 48 | if rbm.boost == 'none' then 49 | train_sample = train:next() 50 | else 51 | train_sample,skipped =train:nextboost() 52 | end 53 | 54 | x_tr = train_sample[1]:view(1,-1) 55 | y_tr = train_sample[2]:view(1,-1) 56 | regularization.dropout(rbm) -- create dropout mask for hidden units 57 | calculategrads(rbm,x_tr,y_tr,x_semi,i) -- calculates dW, dU, db, dc and dd 58 | 59 | -- update vW, vU, vb, vc and vd, formulae: vx = vX*mom + dX 60 | updateweights(rbm,i) 61 | --print(">>>>updateweights rbm.db: ",rbm.db) -- updates W,U,b,c and d, formulae: X = X + vX 62 | 63 | if rbm.progress > 0 and (i % 100) == 0 then 64 | xlua.progress(i, train:size()) 65 | end 66 | 67 | -- Force garbagecollector to collect 68 | collectgarbage() 69 | 70 | if rbm.csv then 71 | if i % rbm.csv == 1 then 72 | sf = paths.concat('e'..epoch..'s'..i) 73 | os.execute('mkdir -p ' .. 
sf) 74 | 75 | writerbmtocsv(rbm,sf) 76 | print('Saving to '..sf) 77 | end 78 | end 79 | 80 | 81 | if rbm.boost ~= 'none' and skipped + i > train:size() then 82 | -- we have passed once through the dataset 83 | break 84 | end 85 | end -- end samples loop 86 | epoch_time = os.time() - epoch_time 87 | 88 | -- calc. train recon err and train pred error 89 | rbm.err_recon_train[epoch] = rbm.cur_err:div(rbm.n_samples) 90 | 91 | if rbm.toprbm then 92 | local err,probs 93 | 94 | err,probs = geterror(rbm,train) 95 | rbm.err_train[epoch] = err 96 | if rbm.boost ~= 'none' then 97 | train:setyprobs(probs) 98 | end 99 | end 100 | 101 | if val and rbm.toprbm then 102 | rbm.err_val[epoch] = geterror(rbm,val) 103 | if rbm.err_val[epoch] < best_val_err then 104 | best_val_err = rbm.err_val[epoch] 105 | patience = rbm.patience 106 | 107 | if rbm.tempfile then 108 | torch.save(rbm.tempfile,rbm) 109 | end 110 | best_rbm = cprbm(rbm) -- save best weights 111 | best = '***' 112 | else 113 | patience = patience - 1 114 | best = '' 115 | end 116 | end 117 | displayprogress(rbm,epoch,epoch_time,patience,best or '') 118 | 119 | 120 | 121 | if patience < 0 then -- Stop training 122 | -- Cp weights from best_rbm 123 | rbm.W = best_rbm.W:clone() 124 | rbm.U = best_rbm.U:clone() 125 | rbm.b = best_rbm.b:clone() 126 | rbm.c = best_rbm.c:clone() 127 | rbm.d = best_rbm.d:clone() 128 | print("BREAK") 129 | break 130 | end 131 | 132 | end -- end epoch loop 133 | total_time = os.time() - total_time 134 | 135 | if rbm.finalfile then 136 | torch.save(rbm.finalfile,rbm) 137 | end 138 | print("Mean epoch time:", total_time / rbm.numepochs) 139 | return(rbm) 140 | end 141 | 142 | 143 | function calculategrads(rbm,x_tr,y_tr,x_semi,samplenum) 144 | -- add the grads to dW 145 | local dW_gen, dU_gen, db_gen, dc_gen, dd_gen, vkx, tcwx 146 | local dW_dis, dU_dis, dc_dis, dd_dis, p_y_given_x 147 | local dW_semi, dU_semi,db_semi, dc_semi, dd_semi, y_semi 148 | local h0,h0_rnd, hk,vkx,vkx_rnd,vky_rnd 149 | 150 | -- reset accumulators 151 | -- Assert correct formats 152 | assert(isMatrix(x_tr)) 153 | 154 | if rbm.toprbm then 155 | assert(isRowVec(y_tr)) 156 | end 157 | 158 | if x_semi then 159 | assert(isMatrix(x_semi)) 160 | end 161 | 162 | if rbm.precalctcwx == 1 then 163 | tcwx = torch.mm( x_tr,rbm.W:t() ):add( rbm.c:t() ) -- precalc tcwx 164 | end 165 | -- GENERATIVE GRADS 166 | if rbm.alpha > 0 then 167 | stat_gen = rbm.generativestatistics(rbm,x_tr,y_tr,tcwx) 168 | 169 | --print(x_tr:type(),y_tr:type(),h0_gen:type(),hk_gen:type(),vkx_rnd_gen:type(),vky_rnd_gen:type()) 170 | grads_gen = rbm.generativegrads(rbm,x_tr,y_tr,stat_gen) 171 | rbm.dW:add( grads_gen.dW:mul( rbm.alpha*rbm.learningrate )) 172 | rbm.db:add( grads_gen.db:mul( rbm.alpha*rbm.learningrate )) 173 | rbm.dc:add( grads_gen.dc:mul( rbm.alpha*rbm.learningrate )) 174 | 175 | if rbm.toprbm then 176 | rbm.dU:add( grads_gen.dU:mul( rbm.alpha*rbm.learningrate )) 177 | rbm.dd:add( grads_gen.dd:mul( rbm.alpha*rbm.learningrate )) 178 | end 179 | rbm.cur_err:add( torch.sum(torch.add(x_tr,-stat_gen.vkx):pow(2)) ) 180 | 181 | 182 | rbm.stat_gen = stat_gen 183 | rbm.grads_gen = grads_gen 184 | end 185 | 186 | -- DISCRIMINATIVE GRADS 187 | if rbm.alpha < 1 then 188 | grads_dis, p_y_given_x = rbm.discriminativegrads(rbm,x_tr,y_tr,tcwx) 189 | rbm.dW:add( grads_dis.dW:mul( (1-rbm.alpha)*rbm.learningrate )) 190 | rbm.dU:add( grads_dis.dU:mul( (1-rbm.alpha)*rbm.learningrate )) 191 | rbm.dc:add( grads_dis.dc:mul( (1-rbm.alpha)*rbm.learningrate )) 192 | rbm.dd:add( grads_dis.dd:mul( 
186 | -- DISCRIMINATIVE GRADS 
187 | if rbm.alpha < 1 then 
188 | grads_dis, p_y_given_x = rbm.discriminativegrads(rbm,x_tr,y_tr,tcwx) 
189 | rbm.dW:add( grads_dis.dW:mul( (1-rbm.alpha)*rbm.learningrate )) 
190 | rbm.dU:add( grads_dis.dU:mul( (1-rbm.alpha)*rbm.learningrate )) 
191 | rbm.dc:add( grads_dis.dc:mul( (1-rbm.alpha)*rbm.learningrate )) 
192 | rbm.dd:add( grads_dis.dd:mul( (1-rbm.alpha)*rbm.learningrate )) 
193 | end 
194 | 
195 | -- SEMISUPERVISED GRADS 
196 | if rbm.beta > 0 then 
197 | if rbm.precalctcwx == 1 then 
198 | tcwx_semi = torch.mm( x_semi,rbm.W:t() ):add( rbm.c:t() ) -- precalc tcwx 
199 | end 
200 | 
201 | if rbm.toprbm then 
202 | p_y_given_x = p_y_given_x or rbm.pygivenx(rbm,x_tr,tcwx_semi) 
203 | y_semi = samplevec(p_y_given_x,rbm.rand):resize(1,rbm.n_classes) 
204 | else 
205 | y_semi = {} 
206 | end 
207 | 
208 | stat_semi = rbm.generativestatistics(rbm,x_semi,y_semi,tcwx_semi) 
209 | grads_semi = rbm.generativegrads(rbm,x_semi,y_semi,stat_semi) 
210 | print("FIX problem with PCD chains in semisupervised learning") 
211 | 
212 | rbm.dW:add( grads_semi.dW:mul( rbm.beta*rbm.learningrate )) 
213 | rbm.db:add( grads_semi.db:mul( rbm.beta*rbm.learningrate )) 
214 | rbm.dc:add( grads_semi.dc:mul( rbm.beta*rbm.learningrate )) 
215 | 
216 | if rbm.toprbm then 
217 | rbm.dU:add( grads_semi.dU:mul( rbm.beta*rbm.learningrate )) 
218 | rbm.dd:add( grads_semi.dd:mul( rbm.beta*rbm.learningrate )) 
219 | end 
220 | end 
221 | end 
222 | 
223 | function displayprogress(rbm,epoch,epoch_time,patience,best) 
224 | local strepoch, lrmom, err_recon, err_train, err_val, epoch_time_patience 
225 | 
226 | strepoch = string.format("%i/%i | ",epoch,rbm.numepochs) 
227 | lrmom = string.format("LR: %f MOM %f | ",rbm.learningrate,rbm.momentum) 
228 | err_recon = string.format("ERROR: Recon %4.1f ",rbm.err_recon_train[epoch]) 
229 | err_train = string.format("TR ERR: %f ", rbm.err_train[epoch] ) 
230 | err_val = string.format("VAL ERR: %f |", rbm.err_val[epoch] ) 
231 | epoch_time_patience = string.format("time: %4.0f Patience %i",epoch_time,patience) 
232 | 
233 | outstr = strepoch .. lrmom .. err_recon .. err_train .. err_val .. epoch_time_patience .. best 
234 | print(outstr) 
235 | 
236 | end 
237 | 
238 | 
239 | function updateweights(rbm,currentsample) 
240 | -- update gradients and weights 
241 | 
242 | -- for every minibatch, update the weights 
243 | if (currentsample % rbm.batchsize) == 0 then 
244 | regularization.applyregularization(rbm) -- APPLY REGULARIZATION BEFORE WEIGHT UPDATE 
245 | if rbm.momentum > 0 then 
246 | -- update the momentum variables: vX = (vX + dX)*mom 
247 | rbm.vW:add( rbm.dW ):mul(rbm.momentum) 
248 | rbm.vb:add( rbm.db ):mul(rbm.momentum) 
249 | rbm.vc:add( rbm.dc ):mul(rbm.momentum) 
250 | 
251 | -- fold the momentum terms into the gradients: dX = dX + vX 
252 | rbm.dW:add(rbm.vW) 
253 | rbm.db:add(rbm.vb) 
254 | rbm.dc:add(rbm.vc) 
255 | 
256 | 
257 | if rbm.toprbm then 
258 | rbm.vU:add( rbm.dU ):mul(rbm.momentum) 
259 | rbm.vd:add( rbm.dd ):mul(rbm.momentum) 
260 | rbm.dU:add(rbm.vU) 
261 | rbm.dd:add(rbm.vd) 
262 | 
263 | end 
264 | 
265 | end 
266 | 
267 | -- normalize the weight update by the minibatch size 
268 | if rbm.batchsize > 1 then 
269 | rbm.dW:mul(1/rbm.batchsize) 
270 | rbm.db:mul(1/rbm.batchsize) 
271 | rbm.dc:mul(1/rbm.batchsize) 
272 | 
273 | if rbm.toprbm then 
274 | rbm.dU:mul(1/rbm.batchsize) 
275 | rbm.dd:mul(1/rbm.batchsize) 
276 | end 
277 | end 
278 | 
279 | -- update weights 
280 | rbm.W:add(rbm.dW) 
281 | rbm.b:add(rbm.db) 
282 | rbm.c:add(rbm.dc) 
283 | 
284 | -- reset the gradient accumulators 
285 | rbm.dW:fill(0) 
286 | rbm.db:fill(0) 
287 | rbm.dc:fill(0) 
288 | 
289 | if rbm.toprbm then 
290 | rbm.d:add(rbm.dd) 
291 | rbm.U:add(rbm.dU) 
292 | 
293 | rbm.dd:fill(0) 
294 | rbm.dU:fill(0) 
295 | end 
296 | end 
297 | end 
298 | 
299 | 
300 | function cprbm(rbm) 
301 | local newrbm = {} 
302 | newrbm.W = rbm.W:clone() 
303 | newrbm.U = rbm.U:clone() 
304 | newrbm.b = rbm.b:clone() 
305 | newrbm.c = rbm.c:clone() 
306 | newrbm.d = rbm.d:clone() 
307 | return(newrbm) 
308 | end 
309 | 
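-- [Editor's note] Illustrative sketch, not part of the original file: the
-- update rule implemented by updateweights() above, reduced to a single
-- parameter tensor. The names (theta, grad, velocity) are hypothetical; the
-- real code applies the same pattern to W, U, b, c and d.
local function example_update(theta, grad, velocity, momentum, batchsize)
   velocity:add(grad):mul(momentum) -- v = (v + g)*mom, as in updateweights
   grad:add(velocity)               -- fold the momentum term into the gradient
   grad:div(batchsize)              -- average over the minibatch
   theta:add(grad)                  -- theta = theta + g
   grad:zero()                      -- reset the accumulator for the next batch
end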
310 | -- extend an old error tensor to newsize, keeping values up to the current epoch 
311 | function extendTensor(rbm,oldtensor,newsize,fill) 
312 | fill = fill or -1 
313 | local newtensor 
314 | newtensor = torch.Tensor(newsize):fill(fill) 
315 | newtensor[{{1,rbm.currentepoch}}] = oldtensor[{{1,rbm.currentepoch}}]:clone() 
316 | return newtensor 
317 | end 
-------------------------------------------------------------------------------- 
/examples/examples.txt: 
-------------------------------------------------------------------------------- 
1 | # EXAMPLE RBMS 
2 | 
3 | # DISCRIMINATIVE RBM 
4 | th runrbm.lua -eta 0.05 -alpha 0 -nhidden 500 -folder test_discriminative 
5 | 
6 | # DISCRIMINATIVE DROPOUT RBM 
7 | th runrbm.lua -eta 0.05 -alpha 0 -nhidden 500 -folder test_discriminative_dropout -dropout 0.5 
8 | 
9 | # GENERATIVE PCD RBM 
10 | th runrbm.lua -eta 0.05 -alpha 1 -nhidden 500 -folder test_generative_pcd -traintype PCD 
11 | 
12 | # HYBRID RBM 
13 | th runrbm.lua -eta 0.05 -alpha 0.01 -nhidden 1500 -folder test_hybrid 
14 | 
15 | # HYBRID DROPOUT RBM 
16 | th runrbm.lua -eta 0.05 -alpha 0.01 -nhidden 1500 -folder test_hybrid_dropout -dropout 0.5 
17 | 
18 | # HYBRID SPARSITY RBM 
19 | th runrbm.lua -eta 0.05 -alpha 0.01 -nhidden 3000 -folder test_hybrid_sparsity -sparsity 0.0001 
20 | 
21 | # HYBRID SPARSITY DROPOUT RBM 
22 | th runrbm.lua -eta 0.05 -alpha 0.01 -nhidden 3000 -folder test_hybrid_sparsity_dropout -sparsity 0.0001 -dropout 0.5 
23 | 
24 | # GENERATIVE RBM 
25 | th runrbm.lua -eta 0.05 -alpha 1 -nhidden 1000 -folder test_generative 
26 | 
27 | # GENERATIVE DROPOUT RBM 
28 | th runrbm.lua -eta 0.05 -alpha 1 -nhidden 2000 -folder test_generative_dropout -dropout 0.5 
-------------------------------------------------------------------------------- 
/examples/rbm_tests.lua: 
-------------------------------------------------------------------------------- 
1 | codeFolder = '../code/' 
2 | 
3 | require('torch') 
4 | require(codeFolder..'rbm') 
5 | require(codeFolder..'ProFi') 
6 | require(codeFolder..'dataset-from-tensor.lua') 
7 | require 'paths' 
8 | torch.manualSeed(101) 
9 | torch.setdefaulttensortype('torch.FloatTensor') 
10 | torch.setnumthreads(2) 
11 | 
12 | --train_mnist,val_mnist,test_mnist = mnist.loadMnist(50) 
13 | -- { 
14 | -- 1 : FloatTensor - size: 1x32x32 
15 | -- 2 : FloatTensor - size: 10 
16 | -- } 
17 | 
18 | opts = {} 
19 | opts.n_hidden = 500 
20 | opts.numepochs = 500 
21 | opts.learningrate = 0.05 
22 | opts.alpha = 0 
23 | opts.beta = 0 
24 | opts.dropconnect = 0 
25 | 
26 | 
27 | 
28 | n_classes = 4; 
29 | n_hidden = 7; 
30 | n_visible = 3; 
31 | 
32 | -- create test data and test weights 
33 | x = torch.Tensor({0.4170,0.7203, 0.0001}):resize(1,1,n_visible) 
34 | x2d = x:view(1,3) 
35 | y = torch.Tensor({0,0,1, 0}):resize(1,n_classes) 
36 | 
37 | U = torch.Tensor({{ 0.0538, -0.0113, -0.0688, -0.0074}, 
38 | { 0.0564, 0.0654, 0.0357, 0.0584}, 
39 | {-0.0593, 0.0047, 0.0698, -0.0295}, 
40 | {-0.0658, 0.0274, 0.0355, -0.0303}, 
41 | {-0.0472, -0.0264, -0.0314, -0.0529}, 
42 | { 0.0540, 0.0266, 0.0413, -0.0687}, 
43 | {-0.0574, 0.0478, -0.0567, 0.0255}}) 
44 | W = torch.Tensor({{-0.0282 , -0.0115 , 0.0084}, 
45 | {-0.0505 , 0.0265 , -0.0514}, 
46 | {-0.0582 , -0.0422 , -0.0431}, 
47 | {-0.0448 , 0.0540 , 0.0430}, 
48 | {-0.0221 , -0.0675 , 0.0669}, 
49 | {-0.0147 , 0.0244 , -0.0267}, 
50 | {0.0055 , -0.0118 , 0.0275}}) 
51 | b = torch.zeros(n_visible,1) 
52 | c = torch.zeros(n_hidden,1) 
53 | d = torch.zeros(n_classes,1) 
54 | vW = torch.Tensor(W:size()):zero() 
55 | vU = torch.Tensor(U:size()):zero() 
56 | vb = torch.Tensor(b:size()):zero() 
57 | vc = torch.Tensor(c:size()):zero() 
58 | vd = torch.Tensor(d:size()):zero() 
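-- [Editor's note] Hedged sketch, not part of the original file: the tests
-- build y as a one-of-K row vector by hand, while dataset-from-tensor uses a
-- oneOfK(nclasses, labels) helper. A minimal implementation consistent with
-- that call signature could look like this (labels holds integer class
-- indices starting at 1; the function name is marked as a sketch):
local function oneOfK_sketch(nclasses, labels)
   local n = labels:size(1)
   local encoded = torch.zeros(n, nclasses)
   for i = 1, n do
      encoded[{i, labels[i]}] = 1 -- set the column of the i-th sample's class
   end
   return encoded
end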
59 | 60 | 61 | 62 | rbm = {W = W:clone(), U = U:clone(), d = d:clone(), b = b:clone(), c = c:clone(), 63 | vW =vW,vU = vU,vb = vb, vc = vc, vd = vd,}; 64 | rbm.dW = torch.Tensor(rbm.W:size()):zero() 65 | rbm.dU = torch.Tensor(rbm.U:size()):zero() 66 | rbm.db = torch.Tensor(rbm.b:size()):zero() 67 | rbm.dc = torch.Tensor(rbm.c:size()):zero() 68 | rbm.dd = torch.Tensor(rbm.d:size()):zero() 69 | 70 | rbm.rand = function(m,n) return torch.Tensor(m,n):fill(1):mul(0.5)end -- for testing 71 | rbm.n_classes = n_classes 72 | rbm.n_visible = n_visible 73 | rbm.n_samples = 1 74 | rbm.n_input = 1 75 | rbm.numepochs = 1 76 | rbm.learningrate = 0.1 77 | rbm.alpha = 1 78 | rbm.beta = 0 79 | rbm.momentum = 0 80 | rbm.dropout = opts.dropout or 0 81 | rbm.dropouttype = "bernoulli" 82 | rbm.dropconnect = opts.dropconnect or 0 83 | rbm.L1 = opts.L1 or 0 84 | rbm.L2 = opts.L2 or 0 85 | rbm.sparsity = opts.sparsity or 0 86 | rbm.err_recon_train = torch.Tensor(1):fill(-1) 87 | rbm.err_train = torch.Tensor(1):fill(-1) 88 | rbm.err_val = torch.Tensor(1):fill(-1) 89 | rbm.temp_file = "blabla" 90 | rbm.patience = 15 91 | rbm.one_by_classes = torch.ones(1,rbm.U:size(2)) 92 | rbm.hidden_by_one = torch.ones(rbm.W:size(1),1) 93 | rbm.traintype = 'CD' -- CD 94 | rbm.npcdchains = 1 95 | rbm.cdn = 1 96 | rbm.n_hidden = n_hidden 97 | rbm.currentepoch = 1 98 | rbm.bottomrbm = 1 99 | rbm.toprbm = true 100 | rbm.samplex = false 101 | rbm.lrdecay = 0 -- no decay 102 | rbm.up = rbmup 103 | rbm.downx = rbmdownx 104 | rbm.downy = rbmdowny 105 | rbm.errorfunction = function(conf) return 1-conf:accuracy() end 106 | rbm.precalctcwx = 1 107 | rbm.generativestatistics = grads.generativestatistics 108 | rbm.generativegrads = grads.generativegrads 109 | rbm.discriminativegrads = grads.discriminativegrads 110 | rbm.pygivenx = grads.pygivenx 111 | rbm.pygivenxdropout = grads.pygivenxdropout 112 | rbm.batchsize = 1 113 | rbm.visxsampler = bernoullisampler 114 | rbm.hidsampler = bernoullisampler 115 | rbm.conf = ConfusionMatrix(4) 116 | 117 | --------------------------------------------------------- 118 | -- TRUE VALUES rbm-util 119 | --------------------------------------------------------- 120 | h0_true = torch.Tensor({ 0.4778,0.5084,0.5038,0.5139,0.4777,0.5132,0.4843}):resize(1,n_hidden) 121 | h0_rnd_true = torch.Tensor({0,1,1,1,0,1,0}):resize(1,n_hidden) 122 | vkx_true = torch.Tensor({ 0.4580,0.5156,0.4805}):resize(1,n_visible) 123 | vkx_rnd_true = torch.Tensor({0,1,0}):resize(1,n_visible) 124 | vky_true = torch.Tensor({ 0.2319,0.2664,0.2824,0.2194}):resize(1,n_classes) 125 | vky_rnd_true = torch.Tensor({ 0,0,1,0}):resize(1,n_classes) 126 | p_y_given_x_true = torch.Tensor({0.2423,0.2672,0.2532,0.2373}):resize(1,n_classes) 127 | 128 | 129 | --------------------------------------------------------- 130 | -- GENERATIVE TRUE WEIGHTS 131 | --------------------------------------------------------- 132 | dw_gen_true = torch.Tensor({ 133 | {0.1992, -0.1358, 0.0001}, 134 | {0.2120, -0.1493, 0.0001}, 135 | {0.2101, -0.1440, 0.0001}, 136 | {0.2143, -0.1522, 0.0001}, 137 | {0.1992, -0.1312, 0.0001}, 138 | {0.2140, -0.1468, 0.0001}, 139 | {0.2020, -0.1340, 0.0001} 140 | }) 141 | du_gen_true = torch.Tensor({ 142 | {0, 0, -0.0021, 0}, 143 | {0, 0, -0.0071, 0}, 144 | {0, 0, -0.0031, 0}, 145 | {0, 0, -0.0084, 0}, 146 | {0, 0, 0.0024, 0}, 147 | {0, 0, -0.0032, 0}, 148 | {0, 0, 0.0014, 0} 149 | }) 150 | 151 | 152 | 153 | 154 | db_gen_true = torch.Tensor({0.4170,-0.2797,0.0001}):resize(n_visible,1) 155 | dc_gen_true = 
torch.Tensor({-0.0021,-0.0071,-0.0031,-0.0084,0.0024,-0.0032,0.0014}):resize(n_hidden,1) 
156 | dd_gen_true = torch.Tensor({0,0,0,0}):resize(n_classes,1) 
157 | 
158 | --------------------------------------------------------- 
159 | --- DISCRIMINATIVE TRUE WEIGHTS 
160 | --------------------------------------------------------- 
161 | dw_dis_true = torch.Tensor({ 
162 | {-0.0062, -0.0107, -0.0000}, 
163 | {-0.0019, -0.0033, -0.0000}, 
164 | { 0.0075, 0.0130, 0.0000}, 
165 | { 0.0044, 0.0076, 0.0000}, 
166 | { 0.0008, 0.0014, 0.0000}, 
167 | { 0.0028, 0.0049, 0.0000}, 
168 | {-0.0049, -0.0085, -0.0000}}) 
169 | 
170 | du_dis_true = torch.Tensor({ 
171 | {-0.1232, -0.1315, 0.3568, -0.1170}, 
172 | {-0.1245, -0.1378, 0.3797, -0.1220}, 
173 | {-0.1143, -0.1303, 0.3762, -0.1136}, 
174 | {-0.1184, -0.1368, 0.3838, -0.1180}, 
175 | {-0.1148, -0.1280, 0.3567, -0.1121}, 
176 | {-0.1251, -0.1361, 0.3832, -0.1152}, 
177 | {-0.1173, -0.1364, 0.3616, -0.1198}}) 
178 | 
179 | 
180 | dc_dis_true = torch.Tensor({-0.0149,-0.0046,0.0181,0.0106,0.0019,0.0067,-0.0118}):resize(n_hidden,1) 
181 | dd_dis_true = torch.Tensor({ -0.2423,-0.2672,0.7468,-0.2373}):resize(n_classes,1) 
182 | 
183 | --------------------------------------------------------- 
184 | --- CHECK FOR SIDE EFFECTS 
185 | --------------------------------------------------------- 
186 | -- if they have side effects on x, y or rbm then the generative tests will fail 
187 | _h0 = rbmup(rbm,x2d,y) -- UP 
188 | _h0_rnd = bernoullisampler(_h0,rbm.rand) 
189 | _vkx = rbmdownx(rbm,_h0_rnd) -- DOWNX 
190 | _vkx_rnd = bernoullisampler(_vkx,rbm.rand) 
191 | _vky = rbmdowny(rbm,_h0_rnd) 
192 | _vky_rnd = samplevec(_vky,rbm.rand) 
193 | _p_y_given_x = grads.pygivenx(rbm,x2d) 
194 | rbm.learningrate = 0 -- to avoid updating weights 
195 | --rbm = rbmtrain(rbm,x,y) 
196 | 
197 | 
198 | --------------------------------------------------------- 
199 | --- CALCULATE VALUES FOR TESTING 
200 | --------------------------------------------------------- 
201 | tcwx = torch.mm( x2d,rbm.W:t() ):add( rbm.c:t() ) 
202 | 
203 | print(x2d,y,tcwx) 
204 | print(rbm) 
205 | stat_gen = grads.generativestatistics(rbm,x2d,y,tcwx) 
206 | grads_gen = grads.generativegrads(rbm,x2d,y,stat_gen) 
207 | grads_dis, p_y_given_x_dis = grads.discriminativegrads(rbm,x2d,y,tcwx) 
208 | -- -- calculate values 
209 | -- h0 = rbmup(rbm,x,y) -- UP 
210 | -- h0_rnd = sampler(h0,rbm.rand) 
211 | -- vkx = rbmdownx(rbm,h0_rnd) -- DOWNX 
212 | -- vkx_rnd = sampler(vkx,rbm.rand) 
213 | -- vky = rbmdowny(rbm,h0_rnd) 
214 | -- vky_rnd = samplevec(vky,rbm.rand) 
215 | p_y_given_x = grads.pygivenx(rbm,x2d) 
216 | 
217 | 
218 | -- --------------------------------------------------------- 
219 | -- --- TEST RBM-UTIL FUNCTIONS 
220 | -- --------------------------------------------------------- 
221 | assert(checkequality(stat_gen.h0, h0_true,-4),'Check h0 failed') 
222 | --assert(checkequality(stat_gen.h0_rnd, h0_rnd_true),'Check h0_rnd failed') 
223 | --assert(checkequality(stat_gen.vkx, vkx_true),'Check vkx failed') 
224 | assert(checkequality(stat_gen.vkx_unsampled, vkx_true),'Check vkx_unsampled failed') 
225 | 
226 | 
227 | assert(checkequality(stat_gen.vkx, vkx_rnd_true),'Check vkx_rnd failed') 
228 | 
229 | 
230 | assert(checkequality(stat_gen.vky, vky_rnd_true),'Check vky_rnd failed') 
231 | 
232 | assert(checkequality(p_y_given_x, p_y_given_x_true),'Check p_y_given_x failed') 
233 | assert(checkequality(p_y_given_x, p_y_given_x_dis,-3),'Check p_y_given_x_dis failed') 
234 | 
235 | 
236 | 
237 | -- print "TEST of RBM-UTIL gradients : PASSED" 
238 | 
239 | -- 
--------------------------------------------------------- 240 | -- --- TEST GENERATIVE WEIGHTS 241 | -- --------------------------------------------------------- 242 | assert(checkequality(grads_gen.dW, dw_gen_true,-3),'Check dw failed') 243 | assert(checkequality(grads_gen.dU, du_gen_true),'Check du failed') 244 | assert(checkequality(grads_gen.db, db_gen_true,-3),'Check db failed') 245 | assert(checkequality(grads_gen.dc, dc_gen_true),'Check dc failed') 246 | assert(checkequality(grads_gen.dd, dd_gen_true),'Check dd failed') 247 | print "TEST of GENERATIVE gradients : PASSED" 248 | 249 | 250 | -- --------------------------------------------------------- 251 | -- --- TEST DISCRIMINATIVE WEIGHTS 252 | -- --------------------------------------------------------- 253 | assert(checkequality(grads_dis.dW, dw_dis_true,-3),'Check dw failed') 254 | assert(checkequality(grads_dis.dU, du_dis_true),'Check du failed') 255 | assert(checkequality(grads_dis.dc, dc_dis_true,-3),'Check dc failed') 256 | assert(checkequality(grads_dis.dd, dd_dis_true),'Check dd failed') 257 | print "TEST of DISCRIMINATIVE gradients : PASSED" 258 | 259 | 260 | -- --------------------------------------------------------- 261 | -- --- TEST RBMTRAIN 262 | -- --------------------------------------------------------- 263 | 264 | trainData = datatensor.createDataset(x,y,{'A','B','C','D',},{1,3}) 265 | 266 | 267 | rbm.beta = 0 268 | rbm.learningrate = 0.1 269 | rbm.dropout = 0 270 | rbm.dropconnect = 0 271 | rbm.boost = 'none' 272 | rbm.progress = 0 273 | 274 | --train ={} 275 | --train.data = x 276 | --train.labels = torch.Tensor({3}):float() 277 | 278 | orgrbm = {} 279 | orgrbm.W = rbm.W:clone() 280 | orgrbm.U = rbm.U:clone() 281 | orgrbm.b = rbm.b:clone() 282 | orgrbm.c = rbm.c:clone() 283 | orgrbm.d = rbm.d:clone() 284 | 285 | 286 | rbm = rbmtrain(rbm,trainData) 287 | 288 | 289 | 290 | 291 | 292 | -- check generative 293 | rbm.alpha = 1 294 | assert(checkequality(rbm.W, torch.add(W ,torch.mul(dw_gen_true,rbm.learningrate)) ,-3),'Check rbm.W failed') 295 | assert(checkequality(rbm.U, torch.add(U ,torch.mul(du_gen_true,rbm.learningrate)) ,-3),'Check rbm.U failed') 296 | assert(checkequality(rbm.b, torch.add(b ,torch.mul(db_gen_true,rbm.learningrate)) ,-3),'Check rbm.b failed') 297 | assert(checkequality(rbm.c, torch.add(c ,torch.mul(dc_gen_true,rbm.learningrate)) ,-3),'Check rbm.c failed') 298 | assert(checkequality(rbm.d, torch.add(d ,torch.mul(dd_gen_true,rbm.learningrate)) ,-3),'Check rbm.d failed') 299 | print('Generative Training : OK') 300 | 301 | -- check discriminative 302 | 303 | -- rbm.W = orgrbm.W:clone() 304 | -- rbm.U = orgrbm.U:clone() 305 | -- rbm.b = orgrbm.b:clone() 306 | -- rbm.c = orgrbm.c:clone() 307 | -- rbm.d = orgrbm.d:clone() 308 | 309 | -- assert(checkequality(rbm.W, orgrbm.W ,-3),'Check rbm.W failed') 310 | -- assert(checkequality(rbm.U, orgrbm.U ,-3),'Check rbm.U failed') 311 | -- assert(checkequality(rbm.b, orgrbm.b ,-3),'Check rbm.b failed') 312 | -- assert(checkequality(rbm.c, orgrbm.c ,-3),'Check rbm.c failed') 313 | -- assert(checkequality(rbm.d, orgrbm.d ,-3),'Check rbm.d failed') 314 | 315 | 316 | rbm.W = W:clone() 317 | rbm.U = U:clone() 318 | rbm.b = b:clone() 319 | rbm.c = c:clone() 320 | rbm.d = d:clone() 321 | rbm.vW:fill(0) 322 | rbm.vU:fill(0) 323 | rbm.vb:fill(0) 324 | rbm.vc:fill(0) 325 | rbm.vd:fill(0) 326 | rbm.dW:fill(0) 327 | rbm.dU:fill(0) 328 | rbm.db:fill(0) 329 | rbm.dc:fill(0) 330 | rbm.dd:fill(0) 331 | rbm.alpha = 0 332 | rbm = rbmtrain(rbm,trainData) 333 | 334 | 335 | 
rbm.learningrate = 0.1 336 | assert(checkequality(rbm.W, torch.add(W ,torch.mul(dw_dis_true,rbm.learningrate)) ,-3),'Check rbm.W failed') 337 | assert(checkequality(rbm.U, torch.add(U ,torch.mul(du_dis_true,rbm.learningrate)) ,-3),'Check rbm.U failed') 338 | --assert(checkequality(rbm.b, torch.add(b ,torch.mul(db_gen_true,rbm.learningrate)) ,-3),'Check rbm.b failed') 339 | assert(checkequality(rbm.c, torch.add(c ,torch.mul(dc_dis_true,rbm.learningrate)) ,-3),'Check rbm.c failed') 340 | assert(checkequality(rbm.d, torch.add(d ,torch.mul(dd_dis_true,rbm.learningrate)) ,-3),'Check rbm.d failed') 341 | print('Discriminative Training : OK') 342 | -- print "TEST of rbmtrain 343 | 344 | 345 | -- check hybrid 346 | rbm.W = W:clone() 347 | rbm.U = U:clone() 348 | rbm.b = b:clone() 349 | rbm.c = c:clone() 350 | rbm.d = d:clone() 351 | rbm.vW:fill(0) 352 | rbm.vU:fill(0) 353 | rbm.vb:fill(0) 354 | rbm.vc:fill(0) 355 | rbm.vd:fill(0) 356 | rbm.dW:fill(0) 357 | rbm.dU:fill(0) 358 | rbm.db:fill(0) 359 | rbm.dc:fill(0) 360 | rbm.dd:fill(0) 361 | rbm.alpha = 0.1 362 | rbm = rbmtrain(rbm,trainData) 363 | assert(checkequality(rbm.W, torch.add(W ,torch.mul(dw_dis_true,rbm.learningrate*(1-rbm.alpha))):add(torch.mul(dw_gen_true,rbm.learningrate*rbm.alpha)) ,-3),'Check rbm.W failed') 364 | assert(checkequality(rbm.U, torch.add(U ,torch.mul(du_dis_true,rbm.learningrate*(1-rbm.alpha))):add(torch.mul(du_gen_true,rbm.learningrate*rbm.alpha)) ,-3),'Check rbm.U failed') 365 | assert(checkequality(rbm.b, torch.add(b ,torch.mul(db_gen_true,rbm.learningrate*(rbm.alpha))) ,-3),'Check rbm.b failed') 366 | assert(checkequality(rbm.c, torch.add(c ,torch.mul(dc_dis_true,rbm.learningrate*(1-rbm.alpha))):add(torch.mul(dc_gen_true,rbm.learningrate*rbm.alpha)) ,-3),'Check rbm.c failed') 367 | assert(checkequality(rbm.d, torch.add(d ,torch.mul(dd_dis_true,rbm.learningrate*(1-rbm.alpha))):add(torch.mul(dd_gen_true,rbm.learningrate*rbm.alpha)) ,-3),'Check rbm.d failed') 368 | print('Hybrid Training : OK') 369 | 370 | 371 | 372 | 373 | 374 | -- : PASSED" 375 | 376 | 377 | -- print(stat_gen['vky'],vky_rnd_true) 378 | -- error() 379 | 380 | 381 | 382 | -- opts_mnist = {} 383 | -- opts_mnist.n_hidden = 10 384 | -- rbm_mnist = rbmsetup(opts_mnist,train_mnist) 385 | -- rbm_mnist.alpha = 0 386 | -- rbm_mnist.beta = 0 387 | -- rbm_mnist.learningrate = 0.1 388 | -- rbm_mnist.dropout = 0.5 389 | -- rbm_mnist.dropconnect = 0 390 | -- rbm_mnist = rbmtrain(rbm_mnist,train_mnist,val_mnist) 391 | 392 | 393 | -- -- extend training with new objective 394 | -- rbm_mnist.dropout = 0.1 395 | -- rbm_mnist.numepochs = 10 396 | -- rbm_mnist.learningrate =0.5 397 | -- rbm_mnist.alpha = 0.5 398 | -- rbm_mnist = rbmtrain(rbm_mnist,train_mnist,val_mnist) 399 | 400 | -- uppass = rbmuppass(rbm_mnist,train_mnist) 401 | -- uppass = rbmuppass(rbm_mnist,test_mnist) 402 | 403 | 404 | -- rbm_mnist.toprbm = false 405 | -- rbm_mnist.alpha = 1 406 | -- rbm_mnist.currentepoch = 1 407 | -- rbm_mnist.U = nil 408 | -- rbm_mnist.dU = nil 409 | -- rbm_mnist.d = nil 410 | -- rbm_mnist.dd = nil 411 | -- rbm_mnist.vd = nil 412 | -- rbm_mnist.vU = nil 413 | 414 | -- rbm_mnist = rbmtrain(rbm_mnist,train_mnist) 415 | 416 | -- uppass = rbmuppass(rbm_mnist,train_mnist) 417 | -- uppass = rbmuppass(rbm_mnist,test_mnist) 418 | 419 | 420 | 421 | 422 | -- -- MNIST TEST 423 | 424 | 425 | 426 | -- --assert(checkequality(torch.add(U,du_true):mul(rbm.learningrate), rbm.U,-3),'Check rbm.W failed') 427 | -- --assert(checkequality(torch.add(b,db_true):mul(rbm.learningrate), rbm.b,-3),'Check 
rbm.b failed') 
428 | -- --assert(checkequality(torch.add(c,dc_true):mul(rbm.learningrate), rbm.c,-3),'Check rbm.c failed') 
429 | -- --assert(checkequality(torch.add(d,dd_true):mul(rbm.learningrate), rbm.d,-3),'Check rbm.d failed') 
430 | -- --assert(checkequality(torch.add(U,du_true), rbm.U,-3),'Check rbm.U failed') 
431 | -- --assert(checkequality(torch.add(b,db_true), rbm.b,-3),'Check rbm.b failed') 
432 | -- --assert(checkequality(torch.add(c,dc_true), rbm.c,-3),'Check rbm.c failed') 
433 | -- --assert(checkequality(torch.add(d,dd_true), rbm.d,-3),'Check rbm.d failed') 
434 | -- --assert(checkequality(du, du_true),'Check du failed') 
435 | -- --assert(checkequality(db, db_true),'Check db failed') 
436 | -- --assert(checkequality(dc, dc_true),'Check dc failed') 
437 | -- --assert(checkequality(dd, dd_true),'Check dd failed') 
438 | 
439 | -- -- Init rbm 
440 | -- --[[torch.setdefaulttensortype('torch.FloatTensor') 
441 | 
442 | 
443 | 
444 | 
445 | 
446 | -- x = torch.zeros(10,784) 
447 | -- y = torch.zeros(10,10) 
448 | -- ]] 
449 | 
450 | 
451 | 
452 | 
453 | 
454 | 
455 | 
456 | -- --[[print(rbm.W) 
457 | -- print(rbm.U) 
458 | -- print(rbm.b) 
459 | -- print(rbm.c) 
460 | -- print(rbm.d)--]] 
461 | 
462 | 
463 | -- -- tcwx = torch.mm( x,rbm.W:t() ):add( rbm.c:t() ) 
464 | -- --tcwx = -0.4333 0.1500 0.1500 0.1133 0.1500 0.1500 -0.1167 
465 | 
466 | -- ---print(tcwx) 
467 | 
468 | 
-------------------------------------------------------------------------------- 
/examples/runrbm.lua: 
-------------------------------------------------------------------------------- 
1 | codeFolder = '../code/' 
2 | 
3 | require('torch') 
4 | require(codeFolder..'rbm') 
5 | require(codeFolder..'dataset-mnist') 
6 | require(codeFolder..'ProFi') 
7 | require 'paths' 
8 | 
9 | 
10 | if not opts then 
11 | print '==> processing options' 
12 | cmd = torch.CmdLine() 
13 | cmd:text() 
14 | cmd:text('MNIST/Optimization') 
15 | cmd:text() 
16 | cmd:text('Options:') 
17 | cmd:option('-eta', 0.05, 'LearningRate') 
18 | cmd:option('-save', 'logs', 'subdirectory to save/log experiments in') 
19 | cmd:option('-datasetsize', 'full', 'small|full size of dataset') 
20 | cmd:option('-dataset', 'MNIST', 'Select dataset') 
21 | cmd:option('-seed', 101, 'random seed') 
22 | cmd:option('-folder', '../rbmtest', 'folder where models are saved') 
23 | cmd:option('-traintype', 'CD', 'CD|PCD') 
24 | cmd:option('-ngibbs', 1, 'Number of Gibbs steps, e.g. CD-5') 
25 | cmd:option('-numepochs', 500, 'Number of epochs') 
26 | cmd:option('-patience', 15, 'Early stopping patience') 
27 | cmd:option('-alpha', 0.5, '0=discriminative, 1=generative, ]0-1[ = hybrid') 
28 | cmd:option('-beta', 0, 'semisupervised training (NOT IMPLEMENTED)') 
29 | cmd:option('-dropout', 0, 'dropout probability') 
30 | cmd:option('-progress', 1, 'display progressbar') 
31 | cmd:option('-L2', 0, 'L2 weight decay') 
32 | cmd:option('-L1', 0, 'L1 weight decay') 
33 | cmd:option('-momentum', 0, 'momentum') 
34 | cmd:option('-sparsity', 0, 'sparsity') 
35 | cmd:option('-inittype', 'crbm', 'crbm|gauss Gaussian or uniformly drawn initial weights') 
36 | cmd:option('-nhidden', 500, 'number of hidden units') 
37 | cmd:option('-toprbm', true, 'non-toprbms are trained generatively, used for stacking RBMs') 
38 | cmd:option('-batchsize', 1, 'Minibatch size') 
39 | cmd:option('-errfunc', 'acc', 'acc|classacc|spec|sens|mcc|ppv|npv|fpr|fdr|F1 error measure') 
40 | cmd:option('-pretrain', 'none', 'none|top|bottom specify if rbm will be used in DBM as top or bottom (untested)') 
41 | cmd:text() 
42 | opts = cmd:parse(arg or {}) 
43 | end 
44 | 
45 | 
46 | 
torch.manualSeed(opts.seed) 
47 | torch.setdefaulttensortype('torch.FloatTensor') 
48 | 
49 | -- geometry: width and height of input images 
50 | if opts.dataset == "MNIST" then 
51 | geometry = {32,32} 
52 | if opts.datasetsize == 'full' then 
53 | trainData,valData = mnist.loadTrainAndValSet(geometry,'none') 
54 | testData = mnist.loadTestSet(nbTestingPatches, geometry) 
55 | elseif opts.datasetsize == 'small' then 
56 | print(' only using 2000 samples to train quickly (use flag -full to use 60000 samples)') 
57 | trainData = mnist.loadTrainSet(2000, geometry,'none') 
58 | testData = mnist.loadTestSet(1000, geometry) 
59 | valData = mnist.loadTestSet(1000, geometry) 
60 | else 
61 | print('Unknown datasize') 
62 | error() 
63 | end 
64 | trainData:toProbability() 
65 | valData:toProbability() 
66 | testData:toProbability() 
67 | 
68 | local errfunc 
69 | local class_to_optimize = 1 
70 | 
71 | if opts.errfunc == "acc" then 
72 | print('Using 1-accuracy error') 
73 | errfunc = function(conf) return 1-conf:accuracy() end 
74 | elseif opts.errfunc == "spec" then 
75 | print('Using 1-specificity error for class',class_to_optimize) 
76 | errfunc = function(conf) return 1-conf:specificity()[{1,class_to_optimize}] end 
77 | elseif opts.errfunc == "sens" then 
78 | print('Using 1-sensitivity error for class',class_to_optimize) 
79 | errfunc = function(conf) return 1-conf:sensitivity()[{1,class_to_optimize}] end 
80 | elseif opts.errfunc == "mcc" then 
81 | print('Using 1-Matthews correlation error for class',class_to_optimize) 
82 | errfunc = function(conf) return 1-conf:matthewsCorrelation()[{1,class_to_optimize}] end 
83 | elseif opts.errfunc == "ppv" then 
84 | print('Using 1-positive predictive value error for class',class_to_optimize) 
85 | errfunc = function(conf) return 1-conf:positivePredictiveValue()[{1,class_to_optimize}] end 
86 | elseif opts.errfunc == "npv" then 
87 | print('Using 1-negative predictive value error for class',class_to_optimize) 
88 | errfunc = function(conf) return 1-conf:negativePredictiveValue()[{1,class_to_optimize}] end 
89 | elseif opts.errfunc == "fpr" then 
90 | print('Using 1-false positive rate error for class',class_to_optimize) 
91 | errfunc = function(conf) return 1-conf:falsePositiveRate()[{1,class_to_optimize}] end 
92 | elseif opts.errfunc == "fdr" then 
93 | print('Using 1-false discovery rate error for class',class_to_optimize) 
94 | errfunc = function(conf) return 1-conf:falseDiscoveryRate()[{1,class_to_optimize}] end 
95 | elseif opts.errfunc == "F1" then 
96 | print('Using 1-F1 error for class',class_to_optimize) 
97 | errfunc = function(conf) return 1-conf:F1()[{1,class_to_optimize}] end 
98 | elseif opts.errfunc == "classacc" then 
99 | print('Using 1-class accuracy error for class',class_to_optimize) 
100 | errfunc = function(conf) return 1-conf:classAccuracy()[{1,class_to_optimize}] end 
101 | else 
102 | print('unknown errorfunction') 
103 | error() 
104 | end 
105 | 
106 | opts.errorfunction = errfunc 
107 | 
108 | else 
109 | print('unknown dataset') 
110 | error() 
111 | end 
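-- [Editor's note] Illustrative sketch, not part of the original file: each
-- -errfunc branch above wraps one MyConfusionMatrix method into a function
-- with the shape conf -> error. A custom measure only needs the same shape;
-- the combined measure below is a hypothetical example, not a supported flag.
local example_errfunc = function(conf)
   local acc_err = 1 - conf:accuracy()
   local f1_err = 1 - conf:F1()[{1,1}] -- F1 of class 1, as in the branches above
   return 0.5*(acc_err + f1_err)       -- equal-weight mix of the two errors
end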
112 | num_threads = 1 
113 | torch.setnumthreads(num_threads) 
114 | if torch.getnumthreads() < num_threads then 
115 | print("Setting number of threads had no effect. Maybe install with gcc 4.9 for openMP?") 
116 | end 
117 | 
118 | -- SETUP RBM 
119 | 
120 | os.execute('mkdir -p ' .. opts.folder) -- create the save folder if it does not exist 
121 | opts.finalfile = paths.concat(opts.folder,'final.asc') 
122 | opts.tempfile = paths.concat(opts.folder,'temp.asc') -- current best model is saved to this file 
123 | opts.learningrate = opts.eta 
124 | opts.n_hidden = opts.nhidden 
125 | opts.cdn = opts.ngibbs 
126 | 
127 | -- DO TRAINING 
128 | rbm = rbmsetup(opts,trainData) 
129 | rbm = rbmtrain(rbm,trainData,valData) 
130 | local err_train = geterror(rbm,trainData) 
131 | local err_val = geterror(rbm,valData) 
132 | local err_test = geterror(rbm,testData) 
133 | print('Train error : ', err_train) 
134 | print('Validation error : ', err_val) 
135 | print('Test error : ', err_test) 
136 | 
-------------------------------------------------------------------------------- 
/examples/stackrbms.lua: 
-------------------------------------------------------------------------------- 
1 | codeFolder = '../code/' 
2 | 
3 | require('torch') 
4 | require(codeFolder..'rbm') 
5 | require(codeFolder..'dataset-mnist') 
6 | require(codeFolder..'ProFi') 
7 | require(codeFolder..'dataset-from-tensor.lua') 
8 | require 'paths' 
9 | local opts 
10 | if not opts then 
11 | print '==> processing options' 
12 | cmd = torch.CmdLine() 
13 | cmd:text() 
14 | cmd:text('MNIST/Optimization') 
15 | cmd:text() 
16 | cmd:text('Options:') 
17 | cmd:option('-eta', 0.05, 'LearningRate') 
18 | cmd:option('-save', 'logs', 'subdirectory to save/log experiments in') 
19 | cmd:option('-datasetsize', 'small', 'small|full size of dataset') 
20 | cmd:option('-dataset', 'MNIST', 'MNIST|SIGP select dataset') 
21 | cmd:option('-seed', 101, 'random seed') 
22 | cmd:option('-folder', '../rbmtest', 'folder where models are saved') 
23 | cmd:option('-traintype', 'CD', 'CD|PCD') 
24 | cmd:option('-ngibbs', 1, 'Number of Gibbs steps, e.g. CD-5') 
25 | cmd:option('-numepochs1', 1, 'Number of epochs rbm1') 
26 | cmd:option('-numepochs2', 1, 'Number of epochs rbm2') 
27 | cmd:option('-numepochs3', 1, 'Number of epochs rbm3') 
28 | cmd:option('-numepochs4', 1, 'Number of epochs rbm4') 
29 | cmd:option('-patience', 3, 'Early stopping patience') 
30 | cmd:option('-alpha', 0.0, '0=discriminative, 1=generative, ]0-1[ = hybrid') 
31 | cmd:option('-beta', 0, 'semisupervised training (NOT IMPLEMENTED)') 
32 | cmd:option('-dropout', 0, 'dropout probability') 
33 | cmd:option('-progress', 1, 'display progressbar') 
34 | cmd:option('-L2', 0, 'L2 weight decay') 
35 | cmd:option('-L1', 0, 'L1 weight decay') 
36 | cmd:option('-momentum', 0, 'momentum') 
37 | cmd:option('-sparsity', 0, 'sparsity') 
38 | cmd:option('-inittype', 'crbm', 'crbm|gauss Gaussian or uniformly drawn initial weights') 
39 | cmd:option('-nhidden1', 10, 'number of hidden units in RBM1') 
40 | cmd:option('-nhidden2', 10, 'number of hidden units in RBM2') 
41 | cmd:option('-nhidden3', 10, 'number of hidden units in RBM3') 
42 | cmd:option('-nhidden4', 10, 'number of hidden units in RBM4') 
43 | cmd:option('-batchsize', 1, 'Minibatch size') 
44 | cmd:option('-errfunc', 'acc', 'acc|classacc|spec|sens|mcc|ppv|npv|fpr|fdr|F1 error measure (does not apply to SIGP data)') 
45 | cmd:option('-boost', 'none', 'none|diff enable/disable boosting') 
46 | cmd:option('-flip', 0, 'flip data probability, SIGP dataset only') 
47 | cmd:text() 
48 | opts = cmd:parse(arg or {}) 
49 | end 
50 | 
51 | torch.manualSeed(opts.seed) 
52 | torch.setdefaulttensortype('torch.FloatTensor') 
53 | 
54 | if opts.dataset == "MNIST" then 
55 | geometry = {32,32} 
56 | if opts.datasetsize == 'full' then 
57 | trainData,valData = mnist.loadTrainAndValSet(geometry,opts.boost) 
58 | testData = mnist.loadTestSet(nbTestingPatches, geometry) 
59 | elseif opts.datasetsize == 'small' then 
60 | print(' only using 2000 samples to train quickly (use flag -full to use 60000 samples)') 
61 | trainData = mnist.loadTrainSet(2000, geometry,opts.boost) 
62 | testData = mnist.loadTestSet(1000, geometry) 
63 | valData = mnist.loadTestSet(1000, geometry) 
64 | else 
65 | print('Unknown datasize') 
66 | error() 
67 | end 
68 | trainData:toProbability() 
69 | valData:toProbability() 
70 | testData:toProbability() 
71 | 
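-- [Editor's note] Clarifying aside, not part of the original file:
-- toProbability() maps the raw pixel intensities into [0,1] so each visible
-- unit can be read as a Bernoulli probability by the binary RBM. The exact
-- implementation lives in dataset-mnist.lua and is assumed here; conceptually
-- it is a rescaling along the lines of data:div(data:max()).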
72 | local errfunc 
73 | local class_to_optimize = 1 
74 | 
75 | if opts.errfunc == "acc" then 
76 | print('Using 1-accuracy error') 
77 | errfunc = function(conf) return 1-conf:accuracy() end 
78 | elseif opts.errfunc == "spec" then 
79 | print('Using 1-specificity error for class',class_to_optimize) 
80 | errfunc = function(conf) return 1-conf:specificity()[{1,class_to_optimize}] end 
81 | elseif opts.errfunc == "sens" then 
82 | print('Using 1-sensitivity error for class',class_to_optimize) 
83 | errfunc = function(conf) return 1-conf:sensitivity()[{1,class_to_optimize}] end 
84 | elseif opts.errfunc == "mcc" then 
85 | print('Using 1-Matthews correlation error for class',class_to_optimize) 
86 | errfunc = function(conf) return 1-conf:matthewsCorrelation()[{1,class_to_optimize}] end 
87 | elseif opts.errfunc == "ppv" then 
88 | print('Using 1-positive predictive value error for class',class_to_optimize) 
89 | errfunc = function(conf) return 1-conf:positivePredictiveValue()[{1,class_to_optimize}] end 
90 | elseif opts.errfunc == "npv" then 
91 | print('Using 1-negative predictive value error for class',class_to_optimize) 
92 | errfunc = function(conf) return 1-conf:negativePredictiveValue()[{1,class_to_optimize}] end 
93 | elseif opts.errfunc == "fpr" then 
94 | print('Using 1-false positive rate error for class',class_to_optimize) 
95 | errfunc = function(conf) return 1-conf:falsePositiveRate()[{1,class_to_optimize}] end 
96 | elseif opts.errfunc == "fdr" then 
97 | print('Using 1-false discovery rate error for class',class_to_optimize) 
98 | errfunc = function(conf) return 1-conf:falseDiscoveryRate()[{1,class_to_optimize}] end 
99 | elseif opts.errfunc == "F1" then 
100 | print('Using 1-F1 error for class',class_to_optimize) 
101 | errfunc = function(conf) return 1-conf:F1()[{1,class_to_optimize}] end 
102 | elseif opts.errfunc == "classacc" then 
103 | print('Using 1-class accuracy error for class',class_to_optimize) 
104 | errfunc = function(conf) return 1-conf:classAccuracy()[{1,class_to_optimize}] end 
105 | else 
106 | print('unknown errorfunction') 
107 | error() 
108 | end 
109 | 
110 | 
111 | opts.errorfunction = errfunc 
112 | --print(opts.errorfunction,errfunc) 
113 | --error() 
114 | elseif opts.dataset == 'SIGP' then 
115 | --th runRBM.lua -eta 0.05 -alpha 1 -nhidden 500 -folder sigp_discriminative -dataset SIGP -progress 0 
116 | require(codeFolder..'dataset-sigp') 
117 | geometry = {1,882} 
118 | classes = {'SP','CS','TM','OTHER'} 
119 | if opts.datasetsize == "full" then 
120 | print("Loading full SIGP dataset...") 
121 | trainFiles = {'../sigp-data/Eukar1.dat', 
122 | '../sigp-data/Eukar2.dat', 
123 | '../sigp-data/Eukar3.dat'} 
124 | trainData = sigp.loadsigp(trainFiles,opts.boost,opts.flip) 
125 | elseif opts.datasetsize == "small" then 
126 | print("Loading small SIGP dataset...") 
127 | trainFiles = {'../sigp-data/Eukar1.dat'} 
128 | trainData = sigp.loadsigp(trainFiles,opts.boost,opts.flip) 
129 | elseif opts.datasetsize == "schmidhuber" then 
130 | --th runRBM.lua -dataset SIGP -datasetsize schmidhuber -nhidden 100 -numepochs 
5000 -alpha 0 -folder schmidhuber_probs 131 | require(codeFolder..'dataset-from-tensor.lua') 132 | local trainTensor = torch.load('../sigp-data/schmidhuber_sigp_inputs.dat'):view(-1,1,882) 133 | local trainLabels = torch.load('../sigp-data/schmidhuber_sigp_targets.dat'):view(-1,1,4) 134 | trainData = datatensor.createDataset(trainTensor,trainLabels,classes,geometry) 135 | -- for i = 1,10000 do 136 | -- if trainData:next()[2]:sum() > 1 then 137 | -- print(i) 138 | -- end 139 | -- end 140 | -- error() 141 | elseif opts.datasetsize == "schmidhuberweighted" then 142 | --th runRBM.lua -dataset SIGP -datasetsize schmidhuberweighted -nhidden 100 -numepochs 5000 -alpha 0 -folder schmidhuber_weighted 143 | require(codeFolder..'dataset-from-tensor.lua') 144 | local trainTensor = torch.load('../sigp-data/schmidhuber_sigp_inputs_weighted.dat'):view(-1,1,882) 145 | local trainLabels = torch.load('../sigp-data/schmidhuber_sigp_targets_weighted.dat'):view(-1,1,4) 146 | trainData = datatensor.createDataset(trainTensor,trainLabels,classes,geometry) 147 | else 148 | print('unknown datasize') 149 | error() 150 | end 151 | valFiles = {'../sigp-data/Eukar4.dat'} 152 | testFiles = {'../sigp-data/Eukar5.dat'} 153 | 154 | valData = sigp.loadsigp(valFiles,'none',0) 155 | testData = sigp.loadsigp(testFiles,'none',0) 156 | 157 | 158 | if opts.datasetsize ~= "schmidhuber" then 159 | errorfunc = function(conf) 160 | conf:updateValids() 161 | conf:printscore('mcc') 162 | local mcc = conf:matthewsCorrelation() 163 | return 1-mcc[{1,2}] 164 | end 165 | else 166 | errorfunc = function(conf) 167 | conf:updateValids() 168 | conf:printscore('acc') 169 | local mcc = conf:matthewsCorrelation() 170 | return 1-conf:accuracy() 171 | end 172 | end 173 | opts.errorfunction =errorfunc 174 | elseif opts.dataset == 'NEWS' then 175 | print("Loading NEWS data") 176 | require(codeFolder..'dataset-from-tensor.lua') 177 | trainTensor = torch.load('../20news-data/trainData.dat'):view(-1,1,5000) 178 | valTensor = torch.load('../20news-data/valData.dat'):view(-1,1,5000) 179 | testTensor = torch.load('../20news-data/testData.dat'):view(-1,1,5000) 180 | trainLabels = torch.load('../20news-data/trainLabels.dat'):view(-1) 181 | valLabels = torch.load('../20news-data/valLabels.dat'):view(-1) 182 | testLabels = torch.load('../20news-data/testLabels.dat'):view(-1) 183 | classes = {} 184 | for i = 1,20 do classes[i] = tostring(i) end 185 | geometry = {1,5000} 186 | trainData = datatensor.createDataset(trainTensor,oneOfK(20,trainLabels),classes,geometry) 187 | valData = datatensor.createDataset(valTensor,oneOfK(20,valLabels),classes,geometry) 188 | testData = datatensor.createDataset(testTensor,oneOfK(20,testLabels),classes,geometry) 189 | 190 | 191 | else 192 | 193 | print('unknown dataset') 194 | error() 195 | end 196 | 197 | --print(test) 198 | num_threads = 1 199 | torch.setnumthreads(num_threads) 200 | if torch.getnumthreads() < num_threads then 201 | print("Setting number of threads had no effect. Maybe install with gcc 4.9 for openMP?") 202 | end 203 | 204 | 205 | topAlpha = opts.alpha 206 | -- SETUP RBM 207 | 208 | os.execute('mkdir -p ' .. opts.folder) -- create tempfolder if it does not exist 209 | opts.learningrate = opts.eta 210 | 211 | 212 | geometry = {1,opts.nhidden1} 213 | classes = trainData:classnames() 214 | function train(layer,nhidden,numepochs,toprbm,alpha,dropout,traindata,valdata,testdata) 215 | print("################# TRAINING RBM "..layer .." 
#################") 216 | returnLabels = true 217 | opts.finalfile = paths.concat(opts.folder,'final_rbm'..layer..'.asc') 218 | opts.tempfile = paths.concat(opts.folder,'temp_rbm'..layer..'.asc') -- current best is saved to this folder 219 | opts.n_hidden = nhidden 220 | opts.numepochs = numepochs 221 | opts.toprbm = toprbm 222 | opts.isgpu = 0 223 | opts.alpha = alpha -- generative for non top rbm 224 | opts.dropout = dropout 225 | 226 | local rbm = rbmsetup(opts,traindata) 227 | rbm = rbmtrain(rbm,traindata,valdata) 228 | 229 | 230 | trainTensor2,trainLabels2 = rbmuppass(rbm,traindata,returnLabels) 231 | valTensor2,valLabels2 = rbmuppass(rbm,valdata,returnLabels) 232 | testTensor2,testLabels2 = rbmuppass(rbm,testdata,returnLabels) 233 | 234 | 235 | traindata2 = datatensor.createDataset(trainTensor2,trainLabels2,classes,geometry) 236 | valdata2 = datatensor.createDataset(valTensor2,valLabels2,classes,geometry) 237 | testdata2 = datatensor.createDataset(testTensor2,testLabels2,classes,geometry) 238 | return rbm,traindata2,valdata2,testdata2 239 | end 240 | 241 | 242 | 243 | 244 | -- Train layer one 245 | print(trainData,valData,testData) 246 | rbm1,trainData2,valData2,testData2 = train(1,opts.nhidden1,opts.numepochs1,false,1,opts.dropout,trainData,valData,testData) 247 | collectgarbage() 248 | rbm2,trainData3,valData3,testData3 = train(2,opts.nhidden2,opts.numepochs2,false,1,opts.dropout,trainData2,valData2,testData2) 249 | collectgarbage() 250 | rbm3,trainData4,valData4,testData4 = train(3,opts.nhidden3,opts.numepochs3,false,1,opts.dropout,trainData3,valData3,testData3) 251 | collectgarbage() 252 | rbm4,trainData5,valData5,testData5 = train(4,opts.nhidden4,opts.numepochs4,true,opts.alpha,opts.dropout,trainData4,valData4,testData4) 253 | 254 | 255 | 256 | 257 | 258 | 259 | -- DO TRAINING 260 | --rbm = trainAndPrint(opts,train,val,test,tempfolder,finalfile) 261 | 262 | local err_train = geterror(rbm4,trainData4) 263 | local err_val = geterror(rbm4,valData4) 264 | local err_test = geterror(rbm4,testData4) 265 | print('Train error : ', err_train) 266 | print('Validation error : ', err_val) 267 | print('Test error : ', err_test) 268 | 269 | 270 | 271 | -------------------------------------------------------------------------------- /examples/test_image.lua: -------------------------------------------------------------------------------- 1 | -- An example which uses the create_weight_image to display the filter for the 2 | -- rbm. 3 | -- The RBM is trained on a small data set (for 1 epoch) after which the image 4 | -- is generated. 5 | 6 | codeFolder = '../code/' 7 | 8 | require('torch') 9 | require(codeFolder..'rbm') 10 | require(codeFolder..'dataset-mnist') 11 | require(codeFolder..'rbm-visualisation.lua') 12 | ProFi = require(codeFolder..'ProFi') 13 | require('paths') 14 | 15 | -- create the options 16 | if not opts then 17 | cmd = torch.CmdLine() 18 | cmd:option('-n_hidden', 500, 'number of hidden units') 19 | cmd:option('-datasetsize', 'full', 'small|full size of dataset') 20 | cmd:option('-image_name', 'demo.png', 'the filename with which the generated image should be saved') 21 | opts = cmd:parse(arg or {}) 22 | end 23 | 24 | torch.setdefaulttensortype('torch.FloatTensor') 25 | 26 | 27 | -- The supplied MNIST images are 32x32 pixels in size. 28 | geometry = {32,32} 29 | 30 | -- Only load the small dataset to start with. 
31 | if opts.datasetsize == 'full' then 
32 | trainData,valData = mnist.loadTrainAndValSet(geometry,'none') 
33 | testData = mnist.loadTestSet(nbTestingPatches, geometry) 
34 | else 
35 | dataSize = 2000 
36 | trainData = mnist.loadTrainSet(dataSize, geometry,'none') 
37 | testData = mnist.loadTestSet(dataSize/2, geometry) 
38 | valData = mnist.loadTestSet(dataSize/2, geometry) 
39 | end 
40 | 
41 | -- The datasets need to be converted to probabilities 
42 | trainData:toProbability() 
43 | valData:toProbability() 
44 | testData:toProbability() 
45 | 
46 | -- Create the rbm 
47 | rbm = rbmsetup(opts, trainData) 
48 | 
49 | -- Train the rbm 
50 | rbm = rbmtrain(rbm,trainData,valData) 
51 | 
52 | -- Output the rbm weights 
53 | create_weight_image(rbm, geometry, opts.image_name) 
-------------------------------------------------------------------------------- 
/examples/testconv.lua: 
-------------------------------------------------------------------------------- 
1 | 
2 | 
3 | codeFolder = '../code/' 
4 | 
5 | require('torch') 
6 | require(codeFolder..'rbm') 
7 | require(codeFolder..'dataset-mnist') 
8 | require(codeFolder..'ProFi') 
9 | require 'paths' 
10 | torch.setdefaulttensortype('torch.FloatTensor') 
11 | require 'nn' 
12 | 
13 | 
14 | tester = torch.Tester() 
15 | 
16 | nInput = 1 
17 | filterSize = 2 
18 | pad = filterSize -1 
19 | nFilters = 3 
20 | poolSize = 2 
21 | inputSize = 5 
22 | hidSize = inputSize - filterSize +1 
23 | 
24 | 
25 | maxPool = function(filters,poolsize) 
26 | -- maxpool over several filters 
27 | -- filters should be a 3d tensor 
28 | local pool = function(x) 
29 | -- Calculate exp(x) / [sum(exp(x)) + 1] in a numerically stable way 
30 | local m = torch.max(x) 
31 | local exp_x = torch.exp(x - m) 
32 | -- normalizer = sum(exp(x)) + 1 in the scaled domain 
33 | local normalizer = torch.exp(-m) + exp_x:sum() 
34 | exp_x:cdiv( torch.Tensor(exp_x:nElement()):fill(normalizer) ) 
35 | return exp_x 
36 | end 
37 | 
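   -- [Editor's note] Clarifying aside, not part of the original file: pool()
   -- above implements the probabilistic max-pooling unit of convolutional
   -- RBMs, P(h_i = 1) = exp(I_i) / (1 + sum_j exp(I_j)) over a pooling block,
   -- where the I_i are the filter responses. Subtracting m = max(I) before
   -- exponentiating leaves the ratio unchanged but keeps exp() from
   -- overflowing; the "+1" term (exp(-m) after rescaling) is the probability
   -- that every unit in the block is off.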
38 | local maxPoolSingle = function(hf,hfres,poolsize) 
39 | -- Performs probabilistic maxpooling. 
40 | -- For each block of poolsize x poolsize calculate 
41 | -- exp(h_i) / [sum_i(exp(h_i)) + 1] 
42 | -- hf should be a 2d matrix 
43 | local height = hf:size(1) 
44 | local width = hf:size(2) 
45 | --poshidprobs = torch.Tensor(height,width):typeAs(hf) 
46 | -- notation h_(i,j) 
47 | for i_start = 1,height,poolsize do 
48 | i_end = i_start+poolsize -1 
49 | for j_start = 1,width,poolsize do -- columns 
50 | j_end = j_start+poolsize -1 
51 | hfres[{{i_start,i_end},{j_start,j_end}}] = pool(hf[{{i_start,i_end},{j_start,j_end}}]) 
52 | end 
53 | end 
54 | end 
55 | 
56 | 
57 | 
58 | local dest = torch.Tensor(filters:size()):typeAs(filters) 
59 | for i = 1, filters:size(1) do 
60 | maxPoolSingle(filters[{i,{},{}}],dest[{ i,{},{} }],poolsize) 
61 | end 
62 | return dest 
63 | end 
64 | 
65 | 
66 | function invertWeights(x) 
67 | -- flip the columns of an m x n matrix (left-right) 
68 | local xtemp = torch.Tensor(x:size()) 
69 | local idx = x:size(2) 
70 | for i = 1,x:size(2) do 
71 | xtemp[{{},idx}] = x[{{},i}] 
72 | idx = idx -1 
73 | end 
74 | return xtemp 
75 | end 
76 | 
77 | function invertWeights3d(x) 
78 | local res = torch.Tensor(x:size()) 
79 | for i = 1,x:size(1) do 
80 | res[{i,{},{}}] = invertWeights(x[{i,{},{}}]) 
81 | end 
82 | return res 
83 | end 
84 | 
85 | 
86 | W = torch.Tensor({1,-2,-3,7,2,1,-3,2,-1,2,5,2}) -- Filter: | 1, -2| 
87 | W = W:resize(nFilters,4) -- |-3, 7| 
88 | 
89 | 
90 | 
91 | modelup = nn.Sequential() 
92 | --modelup:add(nn.Reshape(1,inputSize,inputSize)) 
93 | modelup:add(nn.SpatialConvolutionMM(1,nFilters,filterSize,filterSize)) 
94 | --modelup:add(nn.Sigmoid()) 
95 | 
96 | 
97 | 
98 | modeldownx = nn.Sequential() 
99 | modeldownx:add(nn.SpatialZeroPadding(pad, pad, pad, pad)) -- pad (filterwidth -1) 
100 | modeldownx:add(nn.SpatialConvolution(nFilters,1,filterSize,filterSize)) 
101 | --modeldownx:add(nn.Sigmoid()) 
102 | 
103 | 
104 | -- SET TESTING WEIGHTS OF UP MODEL 
105 | modelup.modules[1].weight = W 
106 | modelup.modules[1].bias = torch.zeros(nFilters) 
107 | 
108 | 
109 | -- -- SET TESTING WEIGHTS OF DOWNX MODEL 
110 | Wtilde = invertWeights(W) -- the weights are unrolled to row vectors, so flipping the columns yields the rotated filter 
111 | modeldownx.modules[2].weight = Wtilde:resize(modeldownx.modules[2].weight:size()) 
112 | modeldownx.modules[2].bias = torch.zeros(1) 
113 | 
114 | 
115 | rbmup = function(x) 
116 | local res = modelup:forward(x):clone() 
117 | return maxPool(res,filterSize) 
118 | end 
119 | 
120 | rbmdownx = function(x) 
121 | return modeldownx:forward(x):clone() 
122 | end 
123 | 
124 | 
125 | x = torch.Tensor({1,2,7,1,0,3,4,3,4,1,2,3,4,1,2,5,4,3,1,1,3,3,2,1,2}):resize(1,5,5) / 20 
126 | h0 = rbmup(x) 
127 | 
128 | -- sampler 
129 | v1 = rbmdownx(h0) 
130 | 
131 | -- sampler 
132 | h1 = rbmup(v1) 
133 | 
134 | nWeights = nFilters * filterSize * filterSize 
135 | 
136 | shrink = filterSize -1 
137 | nHidden = (x:size(2)-shrink) * (x:size(3)-shrink) 
138 | hidBias = ( h0:sum(3):sum(2):squeeze() - h1:sum(3):sum(2):squeeze() ) / nHidden 
139 | 
140 | visBias = torch.Tensor({( x:sum() - v1:sum() ) / (x:nElement())}) 
141 | 
142 | 
143 | -- W grads 
144 | 
145 | shrink = filterSize -1 
146 | x_h = x:size(2) 
147 | x_w = x:size(3) 
148 | hid_h = x_h-filterSize+1 
149 | hid_w = x_w-filterSize+1 
150 | --(filterSize:x_h-filterSize+1,filterSize:x_w-filterSize+1 
151 | 
152 | x_in = x[ {{}, {filterSize, x_h-filterSize+1 }, {filterSize, x_w-filterSize+1 }}] 
153 | h0_in = h0[{{}, {filterSize, hid_h-filterSize+1}, {filterSize, hid_w-filterSize+1}}] 
154 | h0_in_filter = h0_in:resize(nFilters,nInput,h0_in:size(2),h0_in:size(3)) 
155 | 
156 | nnw = nn.SpatialConvolution(nInput,nFilters,filterSize,filterSize) 
157 | -- -- -- -- 
gradsnn.weight is poshidprobs(Wfilter:Hhidden-Wfilter+1,Wfilter:Whidden-Wfilter+1,:,:) 158 | nnw.weight = h0_in_filter 159 | nnw.bias = torch.zeros(nFilters) 160 | 161 | dw_pos = nnw:forward(x_in) 162 | 163 | -- -- -- gradsnn:add(nn.Reshape(1,inputSize-filterSize,inputSize-filterSize)) 164 | 165 | 166 | 167 | 168 | -- 169 | -- DEFINING TESTS RESULTS 170 | -- 171 | x_test = torch.Tensor({ 172 | 0.0500, 0.1000, 0.3500, 0.0500, 0, 173 | 0.1500, 0.2000, 0.1500, 0.2000, 0.0500, 174 | 0.1000, 0.1500, 0.2000, 0.0500, 0.1000, 175 | 0.2500, 0.2000, 0.1500, 0.0500, 0.0500, 176 | 0.1500, 0.1500, 0.1000, 0.0500, 0.1000}):resize(1,5,5) 177 | 178 | poshidprobs_test = torch.Tensor({ 179 | 0.2756, 0.1066, 0.4334, 0.1069, 180 | 0.2042, 0.2898, 0.0792, 0.2500, 181 | 0.2405, 0.1873, 0.1723, 0.1811, 182 | 0.2405, 0.1782, 0.1904, 0.2840, 183 | 184 | 0.1723, 0.1904, 0.3180, 0.1058, 185 | 0.2445, 0.2445, 0.1579, 0.2603, 186 | 0.1586, 0.1937, 0.1956, 0.2056, 187 | 0.2749, 0.2141, 0.2056, 0.2162, 188 | 189 | 0.2073, 0.3777, 0.2121, 0.2465, 190 | 0.1614, 0.1972, 0.3327, 0.1224, 191 | 0.3485, 0.2582, 0.2598, 0.2024, 192 | 0.1819, 0.1566, 0.2127, 0.2024 193 | }):resize(3,4,4) 194 | 195 | -- -- V1 tests 196 | v1_test = torch.Tensor({ 197 | 0.4129, 0.1453, 1.5898, -0.0524, 0.3851, 198 | 0.2243, 4.4263, 0.8105, 6.0015, 1.4578, 199 | -0.3297, 2.3153, 4.5030, 1.3770, 2.7639, 200 | 1.1535, 3.4532, 2.9705, 2.7326, 2.1365, 201 | -0.6367, 2.2036, 1.8644, 1.6806, 2.8251}):resize(1,5,5) 202 | 203 | 204 | neghidprobs_test = torch.Tensor({ 205 | 0.9549, 0.0000, 1.0000, 0.0000, 206 | 0.0000, 0.0451, 0.0000, 0.0000, 207 | 0.9848, 0.0000, 0.2265, 0.0000, 208 | 0.0152, 0.0000, 0.0001, 0.7734, 209 | 210 | 0.0558, 0.0000, 0.1035, 0.0000, 211 | 0.2138, 0.7304, 0.0000, 0.8965, 212 | 0.0009, 0.0006, 0.1927, 0.0009, 213 | 0.9923, 0.0062, 0.1175, 0.6887, 214 | 215 | 0.0000, 0.9999, 0.0000, 0.0621, 216 | 0.0000, 0.0001, 0.9379, 0.0000, 217 | 0.0000, 1.0000, 0.0287, 0.9688, 218 | 0.0000, 0.0000, 0.0010, 0.0015}):resize(3,4,4) 219 | 220 | visBias_test = torch.Tensor({-1.7306}) 221 | hidBias_test = torch.Tensor({ -0.0363, -0.0401, -0.0200}) 222 | dw_test = torch.Tensor({ 223 | -1.0870, -0.2431, 224 | -0.6546, -0.7233, 225 | 226 | -3.9614, -0.7434, 227 | -2.1294, -3.7218, 228 | 229 | -3.0255, -10.0109, 230 | -7.5756, -4.2333}) 231 | 232 | dw_pos_test = torch.Tensor({ 233 | 0.1324, 0.1054, 234 | 0.1226, 0.0986, 235 | 236 | 0.1408, 0.1168, 237 | 0.1363, 0.0956, 238 | 239 | 0.1800, 0.1607, 240 | 0.1867, 0.1078 241 | }):resize(3,2,2) 242 | 243 | print "################ TESTS ############################" 244 | assert(checkequality(x,x_test,-4,false)) 245 | assert(checkequality(h0,poshidprobs_test,-4,false)) 246 | assert(checkequality(v1,v1_test,-4,false)) 247 | assert(checkequality(h1,neghidprobs_test,-4,false)) 248 | 249 | print("Testing Gradients...") 250 | assert(checkequality(visBias,visBias_test,-4,false)) 251 | assert(checkequality(hidBias,hidBias_test,-4,false)) 252 | --assert(checkequality(dw_pos,dw_pos_test,-4,false)) 253 | print('OK') 254 | -- -- ---UPDATES 255 | 256 | 257 | 258 | 259 | 260 | print('ADD SIGMOIDS!!! 
') 261 | -------------------------------------------------------------------------------- /examples/testconv3d.lua: -------------------------------------------------------------------------------- 1 | 2 | 3 | codeFolder = '../code/' 4 | 5 | require('torch') 6 | require(codeFolder..'rbm') 7 | conv = require(codeFolder..'conv-functions') 8 | 9 | include(codeFolder..'mnist.lua') 10 | require(codeFolder..'ProFi') 11 | require 'paths' 12 | torch.setdefaulttensortype('torch.FloatTensor') 13 | require 'nn' 14 | 15 | num_threads = 2 16 | torch.setnumthreads(num_threads) 17 | -- W : vis - hid weights [ #hid x #vis ] 18 | -- U : label - hid weights [ #hid x #n_classes ] 19 | -- b : bias of visible layer [ #vis x 1] 20 | -- c : bias of hidden layer [ #hid x 1] 21 | -- d : bias of label layer [ #n_classes x 1] 22 | 23 | -- --n_input = 3 -- color channels 24 | -- filter_size = 2 -- 25 | -- n_input = 3 26 | -- n_filters = 5 -- how many filters to use 27 | -- pool_size =2 -- how large the probabilistic maxpooling 28 | -- input_size = 5 -- the size of the input image 29 | -- n_classes = 4 30 | 31 | -- -- create test data 32 | -- -- X is always a row vector when it is stored. 33 | -- x3d = torch.Tensor({ 34 | -- 0.0500, 0.1000, 0.3500, 0.0500, 0, 35 | -- 0.1500, 0.2000, 0.1500, 0.2000, 0.0500, 36 | -- 0.1000, 0.1500, 0.2000, 0.0500, 0.1000, 37 | -- 0.2500, 0.2000, 0.1500, 0.0500, 0.0500, 38 | -- 0.1500, 0.1500, 0.1000, 0.0500, 0.1000, 39 | 40 | -- 0.0400, 0.0800, 0.1200, 0.1200, 0.1200, 41 | -- 0.1200, 0.1600, 0.1200, 0.1600, 0.0400, 42 | -- 0.0800, 0.1200, 0.1600, 0.0400, 0.0800, 43 | -- 0.2000, 0.1600, 0.1200, 0.0400, 0.0400, 44 | -- 0.1200, 0.1200, 0.0800, 0.0400, 0.0800, 45 | 46 | -- 0.0500, -0.1000, 0.3500, 0.0500, 0, 47 | -- 0.1500, -0.2000, 0.1500, 0.2000, 0.0500, 48 | -- 0.1000, 0.1500, 0.2000, 0.0500, 0.1000, 49 | -- 0.2500, 0.2000, 0.1500, 0.0500, 0.0500, 50 | -- 0.1500, 0.4000, 0.1000, 0.0500, 0.1000}):resize(1,n_input,math.pow(input_size,2)) 51 | 52 | -- x2d = x3d:view(1,75) 53 | 54 | -- y = torch.Tensor({0,0,1, 0}):resize(1,n_classes) 55 | 56 | 57 | -- -- Rows are 1 filters for 3 color channels 58 | -- W = torch.Tensor({ 1,-2,-3,7, 1 ,2, 3,4, -2 ,4,5 ,-2, 59 | -- 2,1,-3,2 , -1,-4,2,1 , 2 ,3,-4,2, 60 | -- 2,1,-3,2 , -1,-4,2,1 , 2 ,3,-4,2, 61 | -- 2,1,-3,2 , -1,-4,2,1 , 2 ,3,-4,2, 62 | -- 2,1,-3,2 , -1,-4,2,1 , 2 ,3,-4,2}):resize(1,60) 63 | 64 | -- labels = torch.Tensor{4} 65 | -- train = {} 66 | -- train.data = x3d--:view(1,1,75) 67 | -- train.labels = labels 68 | 69 | 70 | 71 | -- sizes = conv.calcconvsizes(filter_size,n_filters,n_classes,input_size,pool_size,train) 72 | -- --opts.n_hidden = 10 73 | -- opts = {} 74 | -- conv.setupsettings(opts,sizes) 75 | -- opts.W = W -- Testing weights 76 | -- opts.U = torch.zeros(opts.U:size()) -- otherwise tests fails 77 | -- opts.numepochs = 1 78 | -- rbm = rbmsetup(opts,train) 79 | -- debug_1 = conv.setupfunctions(rbm,sizes) -- modofies RBM to use conv functions 80 | 81 | -- opts_2 = {} 82 | -- conv.setupsettings(opts_2,sizes) 83 | -- opts_2.W = W:clone() -- Testing weights 84 | -- opts_2.U = torch.zeros(opts.U:size()) -- otherwise tests fails 85 | -- opts.numepochs = 5 86 | -- rbm_2 = rbmsetup(opts_2,train) 87 | -- debug_2 = conv.setupfunctions(rbm_2,sizes) 88 | -- --print(rbm) 89 | -- --- Clone original weighs 90 | 91 | 92 | -- W_org = rbm.W:clone() 93 | -- U_org = rbm.U:clone() 94 | -- b_org = rbm.b:clone() 95 | -- c_org = rbm.c:clone() 96 | -- d_org = rbm.d:clone() 97 | 98 | 99 | 100 | 101 | -- -- Setup a trainng session with a single epoch and 
generative training 102 | -- rbm.rand = function(m,n) return torch.Tensor(m,n):fill(1):mul(0.53)end -- for testing 103 | -- rbm.alpha = 1 -- generative training 104 | -- rbm.momentum = 0 105 | 106 | 107 | -- -- Test the values grads.calculate grads produce 108 | -- grads.calculategrads(rbm,x2d,y) 109 | -- dc_cgrads = rbm.dc:clone() 110 | -- db_cgrads = rbm.db:clone() 111 | -- dW_cgrads = rbm.dW:clone() 112 | 113 | 114 | -- print("#######################---EVALAUTE---#######################") 115 | 116 | -- h0,h0_rnd,hk,vkx,vkx_rnd,vky_rnd = rbm.generativestatistics(rbm,x2d,y,tcwx) 117 | -- dW, dU, db, dc, dd= rbm.generativegrads(x2d,y,h0,hk,vkx_rnd,vky_rnd) 118 | -- rbm.generativegrads(x2d,y,h0,hk,vkx_rnd,vky_rnd) 119 | -- updategradsandmomentum(rbm) 120 | -- --print(">>>>updategradsandmomentum rbm.db: ",rbm.db) 121 | -- -- update vW, vU, vb, vc and vd, formulae: vx = vX*mom + dX 122 | -- --updateweights(rbm) 123 | 124 | -- print(">>>>>>RBMTRAIN") 125 | 126 | -- rbm = rbmtrain(rbm,train) 127 | -- print("<<<<<<<<<<<<<<<") 128 | 129 | -- print(">>>>>>RBMTRAIN MULTIPLE updates") 130 | -- rbm_2.numepochs = 5 131 | -- rbm_2 = rbmtrain(rbm_2,train,train) 132 | -- print("<<<<<<<<<<<<<<<") 133 | 134 | -- print("OK\n\n") 135 | -- -- TESTING VALUES 136 | -- h0_test = torch.Tensor({ 137 | -- 0.2437, 0.0720, 0.4036, 0.1796, 138 | -- 0.0368, 0.6303, 0.2626, 0.1180, 139 | -- 0.3400, 0.2146, 0.2562, 0.2036, 140 | -- 0.1024, 0.3299, 0.2077, 0.2562, 141 | 142 | -- 0.0326, 0.7167, 0.4101, 0.0376, 143 | -- 0.0612, 0.0999, 0.1063, 0.3162, 144 | -- 0.1335, 0.1680, 0.1829, 0.2021, 145 | -- 0.4997, 0.0718, 0.2042, 0.2469, 146 | 147 | -- 0.0326, 0.7167, 0.4101, 0.0376, 148 | -- 0.0612, 0.0999, 0.1063, 0.3162, 149 | -- 0.1335, 0.1680, 0.1829, 0.2021, 150 | -- 0.4997, 0.0718, 0.2042, 0.2469, 151 | 152 | -- 0.0326, 0.7167, 0.4101, 0.0376, 153 | -- 0.0612, 0.0999, 0.1063, 0.3162, 154 | -- 0.1335, 0.1680, 0.1829, 0.2021, 155 | -- 0.4997, 0.0718, 0.2042, 0.2469, 156 | 157 | -- 0.0326, 0.7167, 0.4101, 0.0376, 158 | -- 0.0612, 0.0999, 0.1063, 0.3162, 159 | -- 0.1335, 0.1680, 0.1829, 0.2021, 160 | -- 0.4997, 0.0718, 0.2042, 0.2469 161 | -- }):resize(1,5*4*4) 162 | -- h0_rnd_test = torch.Tensor({ 163 | -- 1 1 1 1 1 1 164 | -- 1 1 1 1 1 0 165 | -- 1 1 1 1 1 1 166 | -- 1 0 0 1 1 0 167 | -- 1 0 0 1 1 1 168 | -- 1 0 1 0 0 0 169 | 170 | -- 0 0 0 1 1 0 171 | -- 1 0 0 1 1 0 172 | -- 1 0 0 0 1 0 173 | -- 0 1 1 1 0 0 174 | -- 1 0 0 0 1 0 175 | -- 1 0 0 1 0 0 176 | 177 | -- 1 1 1 0 0 1 178 | -- 0 0 0 0 0 0 179 | -- 0 0 0 1 0 1 180 | -- 1 1 0 0 0 1 181 | -- 0 1 1 0 0 1 182 | -- 1 0 1 1 1 1 183 | 184 | -- 0 0 0 0 0 1 185 | -- 1 1 1 1 1 1 186 | -- 0 1 1 1 1 1 187 | -- 1 1 0 0 1 1 188 | -- 1 1 1 1 1 1 189 | -- 1 0 0 1 0 0 190 | 191 | -- 1 0 1 1 0 0 192 | -- 0 0 0 0 0 1 193 | -- 0 0 0 0 0 1 194 | -- 0 1 1 1 0 1 195 | -- 0 1 0 1 0 0 196 | -- 0 1 1 0 1 1}):resize(1,5*6*6) 197 | -- vkx_rnd_test = torch.Tensor({ 198 | -- 0.5046, 5.4483, 6.4071, 1.3132, -0.2088, 199 | -- -0.5959, -5.2479, 0.3576, 7.6641, 2.5863, 200 | -- 0.5627, -0.9298, 5.1102, 0.5802, 3.7569, 201 | -- 1.4782, 3.4863, 1.3514, 2.8530, 3.5173, 202 | -- -6.3037, 2.8630, -0.1891, -0.6443, 3.7686, 203 | 204 | -- 0.1133, -2.8290, -12.5596, -5.7247, -0.2421, 205 | -- 0.7840, 6.3792, 7.1448, 1.7716, -3.9546, 206 | -- 0.4063, 1.1692, 1.8246, 1.3398, -1.0906, 207 | -- 0.1914, -3.8660, 2.6643, 0.4013, -1.8153, 208 | -- 4.3049, 3.9727, 3.8634, 4.3912, 2.0124, 209 | 210 | -- -0.2266, 6.9558, 11.3613, 6.4770, 1.1691, 211 | -- 1.1131, -10.9126, 5.0921, 7.3895, 4.2077, 212 | -- -0.4076, 
5.8450, 2.9765, 0.2859, 5.5336, 213 | -- 3.3570, 5.0936, 2.6688, 3.4780, 5.1977, 214 | -- -7.4833, 4.2936, -2.3139, -1.4514, 1.4628 215 | -- }):resize(1,n_input*5*5) 216 | -- vkx_sigm_rnd_test = torch.Tensor({ 217 | -- 1, 1, 1, 1, 0, 218 | -- 0, 0, 0, 1, 1, 219 | -- 1, 1, 1, 1, 1, 220 | -- 0, 1, 1, 1, 1, 221 | -- 0, 1, 0, 0, 1, 222 | 223 | -- 1, 0, 0, 0, 1, 224 | -- 1, 1, 1, 1, 0, 225 | -- 0, 0, 0, 0, 0, 226 | -- 1, 1, 1, 1, 0, 227 | -- 1, 1, 1, 1, 1, 228 | 229 | -- 0, 1, 1, 1, 1, 230 | -- 1, 0, 0, 1, 1, 231 | -- 1, 1, 1, 1, 1, 232 | -- 0, 1, 1, 1, 1, 233 | -- 0, 1, 0, 0, 1}):resize(1,n_input*5*5) 234 | -- vkx_sigm_test = torch.Tensor({ 235 | -- 0.7311, 0.9975, 1.0000, 0.9526, 0.1192, 236 | -- 0.0474, 0.0180, 0.0003, 1.0000, 1.0000, 237 | -- 0.9999, 0.9997, 1.0000, 0.9975, 1.0000, 238 | -- 0.0009, 0.9933, 0.9991, 1.0000, 1.0000, 239 | -- 0.0000, 0.9933, 0.0003, 0.5000, 1.0000, 240 | 241 | -- 0.7311, 0.1192, 0.0000, 0.0000, 0.8808, 242 | -- 0.9526, 1.0000, 1.0000, 0.9999, 0.0000, 243 | -- 0.0474, 0.0000, 0.0000, 0.0067, 0.0000, 244 | -- 0.9991, 0.9820, 1.0000, 0.8808, 0.0025, 245 | -- 0.9997, 0.9991, 1.0000, 1.0000, 0.9997, 246 | 247 | -- 0.1192, 1.0000, 1.0000, 1.0000, 0.9820, 248 | -- 0.9933, 0.0000, 0.2689, 1.0000, 1.0000, 249 | -- 0.9975, 1.0000, 1.0000, 0.9820, 1.0000, 250 | -- 0.0474, 0.9933, 0.9933, 1.0000, 1.0000, 251 | -- 0.0000, 1.0000, 0.0000, 0.0067, 0.9975}):resize(1,n_input*5*5) 252 | -- hk_sigm_test = torch.Tensor({ 253 | -- 0.7304, 0.5001, 0.5285, 0.6989, 254 | -- 0.5001, 0.5006, 0.5105, 0.5005, 255 | -- 0.5039, 0.5288, 0.6708, 0.5033, 256 | -- 0.7013, 0.5039, 0.5033, 0.5651, 257 | 258 | -- 0.5002, 0.7309, 0.7311, 0.5000, 259 | -- 0.5000, 0.5000, 0.5000, 0.5000, 260 | -- 0.7309, 0.5002, 0.5006, 0.5002, 261 | -- 0.5000, 0.5000, 0.5001, 0.7303, 262 | 263 | -- 0.5002, 0.7309, 0.7311, 0.5000, 264 | -- 0.5000, 0.5000, 0.5000, 0.5000, 265 | -- 0.7309, 0.5002, 0.5006, 0.5002, 266 | -- 0.5000, 0.5000, 0.5001, 0.7303, 267 | 268 | -- 0.5002, 0.7309, 0.7311, 0.5000, 269 | -- 0.5000, 0.5000, 0.5000, 0.5000, 270 | -- 0.7309, 0.5002, 0.5006, 0.5002, 271 | -- 0.5000, 0.5000, 0.5001, 0.7303, 272 | 273 | -- 0.5002, 0.7309, 0.7311, 0.5000, 274 | -- 0.5000, 0.5000, 0.5000, 0.5000, 275 | -- 0.7309, 0.5002, 0.5006, 0.5002, 276 | -- 0.5000, 0.5000, 0.5001, 0.7303}):resize(1,5*4*4) 277 | -- h0_filter_test = torch.Tensor({ 278 | -- 0.6303, 0.2626, 279 | -- 0.2146, 0.2562, 280 | 281 | -- 0.0999, 0.1063, 282 | -- 0.1680, 0.1829, 283 | 284 | -- 0.0999, 0.1063, 285 | -- 0.1680, 0.1829, 286 | 287 | -- 0.0999, 0.1063, 288 | -- 0.1680, 0.1829, 289 | 290 | -- 0.0999, 0.1063, 291 | -- 0.1680, 0.1829}):resize(n_filters,1,2,2) 292 | -- h0_filter_test = sigm(h0_filter_test) 293 | -- x_filter_test = torch.Tensor({ 294 | -- 0.2000, 0.1500, 0.2000, 295 | -- 0.1500, 0.2000, 0.0500, 296 | -- 0.2000, 0.1500, 0.0500, 297 | 298 | -- 0.1600, 0.1200, 0.1600, 299 | -- 0.1200, 0.1600, 0.0400, 300 | -- 0.1600, 0.1200, 0.0400, 301 | 302 | -- -0.2000, 0.1500, 0.2000, 303 | -- 0.1500, 0.2000, 0.0500, 304 | -- 0.2000, 0.1500, 0.0500}):resize(n_input,3,3) 305 | -- x_filter_test = sigm(x_filter_test) 306 | -- db_test = torch.Tensor({-0.5540,-0.4976,-0.6080}):resize(3,1) 307 | -- dc_test = torch.Tensor({ 0.0003,-0.0041,-0.0041,-0.0041,-0.0041}):resize(5,1) 308 | -- dW_test = torch.Tensor({ 309 | -- -0.1972, -0.3401, 310 | -- -0.4511, -0.4852, 311 | 312 | -- -0.1706, -0.1828, 313 | -- -0.2187, -0.2459, 314 | 315 | -- -0.2624, -0.3401, 316 | -- -0.4511, -0.4852, 317 | 318 | -- -0.1566, -0.2953, 319 | -- -0.4066, -0.4402, 320 | 321 
| -- -0.1751, -0.1861,
322 | -- -0.1754, -0.2022,
323 | 
324 | -- -0.2091, -0.2953,
325 | -- -0.4066, -0.4402,
326 | 
327 | -- -0.1566, -0.2953,
328 | -- -0.4066, -0.4402,
329 | 
330 | -- -0.1751, -0.1861,
331 | -- -0.1754, -0.2022,
332 | 
333 | -- -0.2091, -0.2953,
334 | -- -0.4066, -0.4402,
335 | 
336 | -- -0.1566, -0.2953,
337 | -- -0.4066, -0.4402,
338 | 
339 | -- -0.1751, -0.1861,
340 | -- -0.1754, -0.2022,
341 | 
342 | -- -0.2091, -0.2953,
343 | -- -0.4066, -0.4402,
344 | 
345 | -- -0.1566, -0.2953,
346 | -- -0.4066, -0.4402,
347 | 
348 | -- -0.1751, -0.1861,
349 | -- -0.1754, -0.2022,
350 | 
351 | -- -0.2091, -0.2953,
352 | -- -0.4066, -0.4402,
353 | -- }):resize(1,60)
354 | 
355 | 
356 | -- -- Calculate expected values before update
357 | -- --print(W)
358 | -- W_train_test = W_org + dW_test * rbm.learningrate
359 | -- b_train_test = b_org + db_test * rbm.learningrate
360 | -- c_train_test = c_org + dc_test * rbm.learningrate
361 | -- -- Setup check after RBMtrain
362 | -- -- dd and dU are not tested
363 | 
364 | 
365 | 
366 | 
367 | -- print "################ TESTS ############################"
368 | -- -- The debug table has references to the up/down/unroll networks
369 | 
370 | -- -- a) Assert that rbm.W and rbm.c share memory with modelup's weight and bias
371 | 
372 | -- print('Test network memory sharing')
373 | -- checkequality(conv.toFlat(rbm.W),conv.toFlat(debug_1.modelup.modules[2].weight),-4,false)
374 | -- assert(rbm.W:storage() == debug_1.modelup.modules[2].weight:storage())
375 | -- assert(rbm.c:storage() == debug_1.modelup.modules[2].bias:storage())
376 | 
377 | -- -- b) for modeldownx the bias should be shared
378 | -- assert(rbm.b:storage() == debug_1.modeldownx.modules[3].bias:storage())
379 | 
380 | -- -- ADD MORE TESTS WITH NETWORKS
381 | -- -- E.g. test conversion functions
382 | -- print("OK")
383 | 
384 | 
385 | -- print("Testing Statistics...")
386 | -- assert(checkequality(h0,h0_test,-4,false))
387 | 
388 | -- assert(checkequality(h0_rnd,h0_rnd_test,-4,false))
389 | -- assert(checkequality(vkx,vkx_sigm_test,-4,false))
390 | -- assert(checkequality(vkx_rnd,vkx_sigm_rnd_test,-4,false))
391 | -- assert(checkequality(hk,hk_sigm_test,-4,false))
392 | -- print("OK")
393 | 
394 | -- print("Testing Gradients...")
395 | -- assert(checkequality(db,db_test,-4,false))
396 | -- assert(checkequality(dc,dc_test,-4,false))
397 | -- assert(checkequality(dW,dW_test,-4,false))
398 | 
399 | -- assert(checkequality(dc_cgrads,dc_test,-4,false))
400 | -- assert(checkequality(db_cgrads,db_test,-4,false))
401 | -- assert(checkequality(dW_cgrads,dW_test,-4,false))
402 | 
403 | -- assert(checkequality(rbm.c,c_train_test,-4,false))
404 | -- assert(checkequality(rbm.b,b_train_test,-4,false))
405 | -- assert(checkequality(rbm.W,W_train_test,-4,false))
406 | 
407 | -- print('OK')
408 | 
409 | 
410 | print("################--Check against MEDAL----###############")
411 | n_classes = 4
412 | n_input = 1
413 | input_size = 8
414 | n_filters = 5
415 | filter_size = 3
416 | pool_size = 2
417 | 
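-- With filter_size 3 on an 8x8 input, a valid convolution (stride 1) gives
-- feature maps of side input_size - filter_size + 1 = 6, so the hidden layer
-- holds n_filters*6*6 = 180 units and 2x2 pooling leaves 180/4 = 45 of them.
-- A quick sanity check of that geometry, assuming valid convolution:
local hid_side = input_size - filter_size + 1      -- 8 - 3 + 1 = 6
assert(n_filters * hid_side^2 == 5*6*6)            -- hidden units, cf. resize(1,5*6*6) below
assert(n_filters * hid_side^2 / pool_size^2 == 45) -- units left after max-pooling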
418 | y = torch.Tensor({0,0,1, 0}):resize(1,n_classes)
419 | 
420 | x3d = torch.Tensor({
421 | 0.0118, 0.0706, 0.0706, 0.0706, 0.4941, 0.5333, 0.6863, 0.1020,
422 | 0.6667, 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.8824, 0.6745,
423 | 0.9922, 0.9922, 0.9922, 0.9922, 0.9922, 0.9843, 0.3647, 0.3216,
424 | 0.9922, 0.9922, 0.7765, 0.7137, 0.9686, 0.9451, 0, 0,
425 | 0.9922, 0.8039, 0.0431, 0, 0.1686, 0.6039, 0, 0,
426 | 0.9922, 0.3529, 0, 0, 0, 0, 0, 0,
427 | 0.9922, 0.7451, 0.0078, 0, 0, 0, 0, 0,
428 | 0.7451, 0.9922, 0.2745, 0, 0, 0, 0, 0
429 | }):resize(1,n_input,input_size*input_size)
430 | 
431 | 
432 | x2d = x3d:view(1,input_size*input_size)
433 | W = torch.Tensor({
434 | -0.0184, -0.0439, -0.0697,
435 | 0.0490, -0.0785, -0.0343,
436 | -0.1111, -0.0906, -0.0229,
437 | 
438 | 0.0086, -0.0657, 0.0379,
439 | -0.0180, 0.0840, -0.0184,
440 | 0.0412, -0.1050, 0.0130,
441 | 
442 | -0.0799, 0.1041, 0.0836,
443 | -0.0671, -0.0415, 0.0877,
444 | 0.0668, 0.0427, -0.0922,
445 | 
446 | -0.1024, -0.0893, 0.0074,
447 | -0.0734, -0.0175, 0.0426,
448 | 0.0840, 0.1018, -0.0410,
449 | 
450 | 0.0414, 0.0556, -0.0488,
451 | 0.0744, 0.1086, 0.0643,
452 | -0.1070, 0.0551, -0.0882
453 | }):resize(1,45)
454 | 
455 | labels_medal = torch.Tensor{4}
456 | train_medal = {}
457 | train_medal.data = x3d--:view(1,1,81)
458 | train_medal.labels = labels_medal
459 | 
460 | sizes_medal = conv.calcconvsizes(filter_size,n_filters,n_classes,input_size,pool_size,train_medal)
461 | --opts.n_hidden = 10
462 | opts_medal = {}
463 | conv.setupsettings(opts_medal,sizes_medal)
464 | opts_medal.W = W -- Testing weights
465 | opts_medal.U = torch.zeros(opts_medal.U:size()) -- otherwise tests fail
466 | opts_medal.numepochs = 1
467 | rbm_medal = rbmsetup(opts_medal,train_medal)
468 | debug_3 = conv.setupfunctions(rbm_medal,sizes_medal) -- modifies RBM to use conv functions
469 | 
470 | 
471 | W_org = rbm_medal.W:clone()
472 | U_org = rbm_medal.U:clone()
473 | b_org = rbm_medal.b:clone()
474 | c_org = rbm_medal.c:clone()
475 | d_org = rbm_medal.d:clone()
476 | 
477 | 
478 | print("------> TEST MEDAL STATISTICS <-----------------")
479 | conv_rbmup_m,conv_rbmdownx_m,conv_pygivenx_m,conv_pygivenxdropout_m,debugupdown_m = conv.createupdownpygivenx(rbm_medal,sizes_medal)
480 | 
481 | rbm_medal.rand = function(m,n) return torch.Tensor(m,n):fill(1):mul(0.2)end
482 | 
483 | stat_gen = rbm_medal.generativestatistics(rbm_medal,x2d,y,tcwx) -- tcwx is not defined at this point, so nil is passed
484 | grads_gen = rbm_medal.generativegrads(rbm_medal,x2d,y,stat_gen)
485 | 
486 | rbm_medal = rbmtrain(rbm_medal,train_medal,train_medal)
487 | 
488 | 
489 | h0_test = torch.Tensor({
490 | 0.1919, 0.1948, 0.1899, 0.1859, 0.1761, 0.1960,
491 | 0.1744, 0.1770, 0.1820, 0.1793, 0.1744, 0.2041,
492 | 0.1675, 0.1869, 0.1893, 0.1815, 0.1775, 0.1981,
493 | 0.1845, 0.2184, 0.2039, 0.1961, 0.1939, 0.2166,
494 | 0.1824, 0.2079, 0.2010, 0.1939, 0.1957, 0.1994,
495 | 0.1807, 0.2032, 0.1977, 0.2037, 0.2016, 0.2016,
496 | 
497 | 0.2015, 0.2005, 0.2038, 0.1985, 0.1974, 0.2028,
498 | 0.1959, 0.2007, 0.2010, 0.1952, 0.1952, 0.2045,
499 | 0.1960, 0.2070, 0.2022, 0.2040, 0.1943, 0.2025,
500 | 0.2032, 0.1958, 0.1982, 0.1959, 0.2004, 0.2011,
501 | 0.1897, 0.2087, 0.2002, 0.2012, 0.1938, 0.2023,
502 | 0.1951, 0.2035, 0.2010, 0.1988, 0.2013, 0.2013,
503 | 
504 | 0.1937, 0.1886, 0.1961, 0.2057, 0.2087, 0.1941,
505 | 0.2181, 0.2118, 0.2040, 0.2057, 0.2108, 0.2008,
506 | 0.2174, 0.2094, 0.2022, 0.1934, 0.2045, 0.1997,
507 | 0.1987, 0.1900, 0.2038, 0.2172, 0.2041, 0.1845,
508 | 0.2082, 0.1965, 0.1983, 0.2099, 0.2100, 0.1904,
509 | 0.1890, 0.2031, 0.1996, 0.1961, 0.1999, 0.1999,
510 | 
511 | 0.2208, 0.2143, 0.2219, 0.2138, 0.2162, 0.1975,
512 | 0.1886, 0.1789, 0.1784, 0.1822, 0.1950, 0.1799,
513 | 0.2079, 0.1899, 0.1849, 0.1840, 0.1879, 0.1974,
514 | 0.1931, 0.1845, 0.2008, 0.2012, 0.1865, 0.1991,
515 | 0.1890, 0.1976, 0.1991, 0.1975, 0.1912, 0.1930,
516 | 0.1942, 0.2117, 0.2042, 0.1996, 0.2053, 0.2053,
517 | 
518 | 0.1955, 0.2008, 0.1954, 0.1998, 0.2090, 0.2057,
519 | 0.2111, 0.2126, 0.2106, 0.2154, 0.2164, 0.1927,
520 | 0.2193, 0.2096, 0.2222, 0.2216, 0.2321, 0.1868,
521 | 0.2028, 0.1911,
0.1842, 0.1948, 0.2106, 0.1931, 522 | 0.2156, 0.1860, 0.2008, 0.1982, 0.2055, 0.2024, 523 | 0.2236, 0.1855, 0.1965, 0.2022, 0.1974, 0.1974}):resize(1,5*6*6) 524 | 525 | 526 | h0_rnd_test = torch.Tensor({ 527 | 0, 0, 0, 0, 0, 0, 528 | 0, 0, 0, 0, 0, 1, 529 | 0, 0, 0, 0, 0, 0, 530 | 0, 1, 1, 0, 0, 1, 531 | 0, 1, 1, 0, 0, 0, 532 | 0, 1, 0, 1, 1, 1, 533 | 534 | 1, 1, 1, 0, 0, 1, 535 | 0, 1, 1, 0, 0, 1, 536 | 0, 1, 1, 1, 0, 1, 537 | 1, 0, 0, 0, 1, 1, 538 | 0, 1, 1, 1, 0, 1, 539 | 0, 1, 1, 0, 1, 1, 540 | 541 | 0, 0, 0, 1, 1, 0, 542 | 1, 1, 1, 1, 1, 1, 543 | 1, 1, 1, 0, 1, 0, 544 | 0, 0, 1, 1, 1, 0, 545 | 1, 0, 0, 1, 1, 0, 546 | 0, 1, 0, 0, 0, 0, 547 | 548 | 1, 1, 1, 1, 1, 0, 549 | 0, 0, 0, 0, 0, 0, 550 | 1, 0, 0, 0, 0, 0, 551 | 0, 0, 1, 1, 0, 0, 552 | 0, 0, 0, 0, 0, 0, 553 | 0, 1, 1, 0, 1, 1, 554 | 555 | 0, 1, 0, 0, 1, 1, 556 | 1, 1, 1, 1, 1, 0, 557 | 1, 1, 1, 1, 1, 0, 558 | 1, 0, 0, 0, 1, 0, 559 | 1, 0, 1, 0, 1, 1, 560 | 1, 0, 0, 1, 0, 0}):resize(1,5*6*6) 561 | 562 | vkx_test = torch.Tensor({ 563 | 0.4766, 0.4483, 0.4631, 0.4156, 0.4798, 0.5527, 0.5080, 0.4973, 564 | 0.4676, 0.5447, 0.5515, 0.5356, 0.5232, 0.5845, 0.6026, 0.5244, 565 | 0.4979, 0.5279, 0.6336, 0.6318, 0.5969, 0.5844, 0.4536, 0.4994, 566 | 0.4859, 0.5273, 0.4733, 0.4614, 0.5572, 0.5460, 0.4665, 0.4620, 567 | 0.5154, 0.6438, 0.4025, 0.3296, 0.4227, 0.6698, 0.4696, 0.4874, 568 | 0.4957, 0.4503, 0.4749, 0.5213, 0.4989, 0.3901, 0.4399, 0.5029, 569 | 0.5085, 0.5067, 0.3467, 0.5560, 0.4530, 0.4821, 0.4372, 0.4787, 570 | 0.4733, 0.5339, 0.4965, 0.4099, 0.4877, 0.4524, 0.4639, 0.4873 571 | }):resize(1,1*8*8) 572 | 573 | hk_test = torch.Tensor({ 574 | 0.1916, 0.1905, 0.1891, 0.1881, 0.1878, 0.1912, 575 | 0.1906, 0.1893, 0.1926, 0.1911, 0.1893, 0.1931, 576 | 0.1866, 0.1887, 0.1931, 0.1909, 0.1892, 0.1890, 577 | 0.1912, 0.1974, 0.1926, 0.1882, 0.1897, 0.1969, 578 | 0.1905, 0.1939, 0.1941, 0.1889, 0.1919, 0.1902, 579 | 0.1910, 0.1911, 0.1890, 0.1944, 0.1922, 0.1925, 580 | 581 | 0.2004, 0.1977, 0.1991, 0.1990, 0.1985, 0.2022, 582 | 0.1986, 0.2016, 0.2012, 0.1991, 0.1996, 0.1982, 583 | 0.1973, 0.2009, 0.2009, 0.2011, 0.1944, 0.2014, 584 | 0.2023, 0.1980, 0.1979, 0.1980, 0.2039, 0.1988, 585 | 0.1963, 0.2034, 0.1981, 0.2025, 0.1955, 0.2009, 586 | 0.2002, 0.1978, 0.2019, 0.1965, 0.2017, 0.1999, 587 | 588 | 0.1989, 0.1989, 0.2012, 0.2059, 0.2059, 0.1990, 589 | 0.2066, 0.2042, 0.1995, 0.2015, 0.2027, 0.2020, 590 | 0.2069, 0.2094, 0.2023, 0.1958, 0.2027, 0.2030, 591 | 0.1984, 0.1943, 0.2015, 0.2108, 0.2028, 0.1977, 592 | 0.2069, 0.1946, 0.1989, 0.2053, 0.2068, 0.2002, 593 | 0.1984, 0.2077, 0.2038, 0.1995, 0.1987, 0.2024, 594 | 595 | 0.1988, 0.2004, 0.2031, 0.2032, 0.2026, 0.1970, 596 | 0.1988, 0.1960, 0.1925, 0.1950, 0.1974, 0.1950, 597 | 0.2029, 0.1977, 0.1918, 0.1909, 0.1980, 0.2010, 598 | 0.1966, 0.1945, 0.2028, 0.2054, 0.1970, 0.1945, 599 | 0.1971, 0.1947, 0.2013, 0.2011, 0.1963, 0.1960, 600 | 0.1984, 0.2023, 0.1971, 0.1944, 0.1997, 0.2007, 601 | 602 | 0.2001, 0.2022, 0.1993, 0.1991, 0.2047, 0.2032, 603 | 0.2045, 0.2069, 0.2076, 0.2073, 0.2038, 0.2023, 604 | 0.2047, 0.2013, 0.2062, 0.2059, 0.2087, 0.1968, 605 | 0.2053, 0.2012, 0.1969, 0.2027, 0.2042, 0.2046, 606 | 0.2073, 0.2013, 0.2064, 0.1971, 0.2046, 0.2026, 607 | 0.2024, 0.2008, 0.2012, 0.2071, 0.2021, 0.2016 608 | }):resize(1,5*6*6) 609 | 610 | 611 | 612 | dW_test = torch.Tensor({ 613 | 614 | 0.6897, 0.1193, -0.4294, 615 | 0.7322, -0.1370, -0.7868, 616 | 0.1229, -0.8993, -1.6596, 617 | 618 | 0.7539, 0.2024, -0.3725, 619 | 0.8539, -0.0078, -0.7207, 620 | 0.2052, -0.8324, -1.6701, 621 | 622 | 
0.7792, 0.2744, -0.3199,
623 | 0.8652, 0.0210, -0.7132,
624 | 0.2138, -0.8103, -1.6942,
625 | 
626 | 0.6137, 0.0841, -0.4521,
627 | 0.8015, -0.0256, -0.7237,
628 | 0.2071, -0.7876, -1.6409,
629 | 
630 | 0.8560, 0.3313, -0.3115,
631 | 0.9545, 0.1280, -0.6752,
632 | 0.2382, -0.7432, -1.6661
633 | 
634 | }):mul(1/( (6 - 2 * 3 + 2) * (6 - 2 * 3 + 2) )):resize(1,5*3*3) -- normalizer (6 - 2*3 + 2)^2 = 4
635 | 
636 | db_test = torch.Tensor({-2.9119}):mul(1/(input_size * input_size)):resize(1,1)
637 | dc_test = torch.Tensor({0.0317,0.0099,-0.0115,-0.0395,0.0304}):mul(1/(6*6)):resize(5,1)
638 | 
639 | W_train_test = W_org + dW_test * rbm_medal.learningrate
640 | b_train_test = b_org + db_test * rbm_medal.learningrate
641 | c_train_test = c_org + dc_test * rbm_medal.learningrate
642 | 
643 | 
644 | assert(checkequality(stat_gen.h0,h0_test,-4,false))
645 | assert(checkequality(stat_gen.h0_rnd,h0_rnd_test,-4,false))
646 | assert(checkequality(stat_gen.vkx,vkx_test,-4,false))
647 | assert(checkequality(stat_gen.hk,hk_test,-4,false))
648 | 
649 | assert(checkequality(grads_gen.dW,dW_test*rbm_medal.learningrate,-4,false))
650 | assert(checkequality(grads_gen.db,db_test*rbm_medal.learningrate,-4,false))
651 | assert(checkequality(grads_gen.dc,dc_test*rbm_medal.learningrate,-4,false))
652 | 
653 | assert(checkequality(rbm_medal.c,c_train_test,-3,false))
654 | assert(checkequality(rbm_medal.b,b_train_test,-3,false))
655 | assert(checkequality(rbm_medal.W,W_train_test,-3,false))
656 | 
657 | 
658 | print("TODO")
659 | print("FIGURE OUT WHY THEY USE ")
660 | 
661 | 
662 | rbm_medal.numepochs = 5
663 | rbm_medal = rbmtrain(rbm_medal,train_medal,train_medal)
664 | 
665 | 
666 | 
667 | 
668 | rbm_medal.toprbm = false
669 | rbm_medal.currentepoch = 1
670 | rbm_medal.U = nil
671 | rbm_medal.dU = nil
672 | rbm_medal.d = nil
673 | rbm_medal.dd = nil
674 | rbm_medal.vd = nil
675 | rbm_medal.vU = nil
676 | rbm_medal = rbmtrain(rbm_medal,train_medal,train_medal)
677 | 
678 | 
679 | --uppass2 = rbmuppass(rbm_medal,train_medal)
680 | 
681 | 
682 | 
683 | -- TEST UPPASS FUNCTION
684 | 
685 | -- With usemaxpool the uppass output should be a factor of pool_size^2
686 | -- smaller than the number of hidden units
687 | 
688 | 
689 | settings = {}
690 | settings.usemaxpool = true
691 | settings.vistype = 'binary'
692 | 
693 | 
694 | settings.filter_size = filter_size
695 | settings.n_filters = n_filters
696 | settings.n_classes = n_classes
697 | settings.input_size = input_size
698 | settings.pool_size = pool_size
699 | settings.toprbm = true
700 | 
701 | 
702 | 
703 | rbm,opts,debug__ = setupconvrbm(settings,train_medal)
704 | rbm = rbmtrain(rbm,train_medal,train_medal)
705 | 
706 | 
707 | uppass1 = rbmuppass(rbm,train_medal)
708 | assert(uppass1:size(1) == train_medal.data:size(1))
709 | assert(uppass1:size(3) == ( rbm.n_hidden / math.pow(pool_size,2)) )
710 | 
711 | settings.usemaxpool = false
712 | rbm,opts,debug__ = setupconvrbm(settings,train_medal)
713 | rbm = rbmtrain(rbm,train_medal,train_medal)
714 | 
715 | uppass2 = rbmuppass(rbm,train_medal)
716 | assert(uppass2:size(1) == train_medal.data:size(1))
717 | assert(uppass2:size(3) == rbm.n_hidden )
718 | 
719 | 
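-- The up-passed activations are what a stacked model feeds to the next RBM,
-- and trainstackconvtorbm below automates exactly that hand-off. A minimal
-- sketch of doing it by hand (next_opts is a placeholder, not defined here):
--
-- local feats = rbmuppass(rbm,train_medal)          -- nSamples x 1 x features
-- local next_train = {data = feats, labels = train_medal.labels}
-- local next_rbm = rbmsetup(next_opts,next_train)   -- plain top RBM on pooled features
-- next_rbm = rbmtrain(next_rbm,next_train,next_train)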
720 | settings.toprbm = false
721 | settings.usemaxpool = true
722 | rbm,opts,debug__ = setupconvrbm(settings,train_medal)
723 | rbm = rbmtrain(rbm,train_medal,train_medal)
724 | uppass1 = rbmuppass(rbm,train_medal)
725 | assert(rbm.U == nil and rbm.dU == nil and rbm.vU == nil)
726 | assert(rbm.d == nil and rbm.dd == nil and rbm.vd == nil)
727 | assert(uppass1:size(1) == train_medal.data:size(1))
728 | assert(uppass1:size(3) == ( rbm.n_hidden / math.pow(pool_size,2)) )
729 | 
730 | 
731 | settings.toprbm = false
732 | settings.usemaxpool = false
733 | 
734 | testopts = {}
735 | testopts.n_hidden = 1001231
736 | 
737 | rbm,opts,debug__ = setupconvrbm(settings,train_medal,testopts)
738 | rbm = rbmtrain(rbm,train_medal,train_medal)
739 | uppass2 = rbmuppass(rbm,train_medal)
740 | assert(rbm.U == nil and rbm.dU == nil and rbm.vU == nil)
741 | assert(rbm.d == nil and rbm.dd == nil and rbm.vd == nil)
742 | assert(uppass2:size(1) == train_medal.data:size(1))
743 | assert(uppass2:size(3) == rbm.n_hidden )
744 | 
745 | 
746 | 
747 | print("TEST STACKING")
748 | trainstackconvtorbm(convsettings,convopts,toprbmopts,train,val) -- NOTE: convsettings, convopts, toprbmopts, train and val are not defined in this script and must be set up before this call
749 | -- settings.usemaxpool = false
750 | -- rbm,opts,debug__ = setupconvrbm(settings,train_medal)
751 | -- rbm = rbmtrain(rbm,train_medal,train_medal)
752 | 
753 | 
754 | 
755 | -- print("\n\n#########################TRAINING ON SMALL MNIST##############")
756 | -- torch.manualSeed(123)
757 | -- train_mnist,val_mnist,test_mnist = mnist.loadMnist(1000,false)
758 | -- print("TRAINING DATA:\n", train_mnist)
759 | -- filter_size = 11
760 | 
761 | -- n_filters = 40 -- how many filters to use
762 | -- pool_size = 2 -- size of the probabilistic max-pooling window
763 | -- input_size = 28 -- the size of the input image
764 | -- n_classes = 10
765 | -- n_input = 1
766 | 
767 | -- sizes_mnist = conv.calcconvsizes(filter_size,n_filters,n_classes,input_size,pool_size,train_mnist)
768 | -- print(sizes_mnist)
769 | -- --opts.n_hidden = 10
770 | -- opts_mnist = {}
771 | -- conv.setupsettings(opts_mnist,sizes_mnist)
772 | -- --print(opts_mnist)
773 | -- --opts.W = W -- Testing weights
774 | -- --opts.U = torch.zeros(opts.U:size()) -- otherwise tests fail
775 | -- opts_mnist.numepochs = 100
776 | -- rbm_mnist = rbmsetup(opts_mnist,train_mnist)
777 | -- vistype = 'binary' -- | 'gauss'
778 | -- debug_mnist = conv.setupfunctions(rbm_mnist,sizes_mnist,vistype) -- modifies RBM to use conv functions
779 | 
780 | -- --gfx.image(train_mnist.data[{1,5,{}}]:resize(28,28), {zoom=2, legend=''})
781 | -- rbm_mnist.learningrate = 0.001
782 | -- rbm_mnist.L2 = 0.01
783 | -- rbm_mnist.batchsize = 10
784 | -- rbm_mnist.sparsity = 0.000
785 | -- rbm_mnist.c:fill(-0.1) -- hidden bias as in Honglak Lee
786 | -- rbm_mnist.U:fill(0)
787 | -- print("#######->>>>>DATASETS COMPOSITIONS---<<<<<<<#########")
788 | -- print("train: ", torch.histc(train_mnist.labels,10):resize(1,10) )
789 | -- print("val: ", torch.histc(val_mnist.labels,10):resize(1,10) )
790 | -- print("test: ", torch.histc(test_mnist.labels,10):resize(1,10) )
791 | -- print("######################################################")
792 | -- rbm_mnist = rbmtrain(rbm_mnist,train_mnist,val_mnist)
793 | --stat_gen = rbm_medal.generativestatistics(rbm_medal,x2d_mnist,y,tcwx)
794 | 
795 | 
--------------------------------------------------------------------------------
/run_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | IMAGE_NAME=itorch
3 | 
4 | docker run -d -p 9999:9999 -v $(pwd):/root/mount -it $IMAGE_NAME /root/torch/install/bin/itorch notebook --profile=itorch_svr
5 | 
--------------------------------------------------------------------------------