├── .gitignore
├── DataProvider2D.lua
├── doc
│   └── images
│       └── model_arch.png
├── imtools.lua
├── models
│   └── caae-nopool_GAN_v2.lua
├── plot error.ipynb
├── printFigs2D.ipynb
├── readme.md
├── run_docker.sh
├── setup.lua
├── train.lua
├── train_model_2D.lua
├── train_model_2D.sh
└── utils.lua

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | trainAAE2D_256
2 | .ipynb_checkpoints
3 | .nfs*
--------------------------------------------------------------------------------
/DataProvider2D.lua:
--------------------------------------------------------------------------------
1 | require 'nn'
2 | require 'torchx'
3 | require 'imtools'
4 | 
5 | local DataProvider = {
6 |     data = nil,
7 |     labels = nil,
8 |     opts = nil
9 | }
10 | 
11 | function DataProvider.create(image_dir, opts)
12 | 
13 |     local save_dir = opts.parent_dir
14 | 
15 |     local data_file = save_dir .. '/data.t7'
16 |     local labels_file = save_dir .. '/labels.t7'
17 |     local data_info_file = save_dir .. '/data_info.t7'
18 | 
19 |     local data
20 |     local labels
21 |     local data_info
22 | 
23 |     data = {}
24 |     if paths.filep(data_file) then
25 |         print('Loading data from ' .. save_dir)
26 |         data = torch.load(data_file)
27 | 
28 |         print('Done')
29 |     else
30 | 
31 |         local c = 0
32 |         local images, image_paths, classes = {}, {}, {}
33 | 
34 |         for dir in paths.iterdirs(image_dir) do
35 |             print('Loading images from ' .. image_dir .. '/' .. dir)
36 | 
37 |             local images_tmp, image_paths_tmp = imtools.load_img(image_dir .. '/' .. dir .. '/', 'png', opts.image_sub_size)
38 | 
39 |             for i = 1,#image_paths_tmp do
40 |                 c = c+1
41 | 
42 |                 images[c] = nn.Unsqueeze(1):forward(images_tmp[i])
43 |                 image_paths[c] = image_paths_tmp[i]
44 | 
45 |                 local tokens = utils.split(image_paths_tmp[i], '/')
46 |                 classes[c] = tokens[#tokens-1]
47 |             end
48 |         end
49 |         images = torch.concat(images,1)
50 |         images = torch.FloatTensor(images:size()):copy(images)
51 |         classes, labels = utils.unique(classes)
52 | 
53 |         local nImgs = images:size()[1]
54 |         local nClasses = torch.max(labels)
55 |         local one_hot = torch.zeros(nImgs, nClasses):long()
56 |         for i = 1,nImgs do
57 |             one_hot[{i,labels[i]}] = 1
58 |         end
59 | 
60 |         -- save 5% of the data for testing
61 |         local rand_inds = torch.randperm(nImgs):long()
62 |         local nTest = torch.round(nImgs/20)
63 | 
64 | 
65 |         data.train = {}
66 |         data.train.inds = rand_inds[{{nTest+1,-1}}]
67 |         data.train.images = images:index(1, data.train.inds)
68 |         data.train.labels = one_hot:index(1, data.train.inds)
69 | 
70 |         data.test = {}
71 |         data.test.inds = rand_inds[{{1,nTest}}]
72 |         data.test.images = images:index(1, data.test.inds)
73 |         data.test.labels = one_hot:index(1, data.test.inds)
74 | 
75 |         data.image_paths = image_paths
76 |         data.classes = classes
77 | 
78 |         paths.mkdir(save_dir)
79 | 
80 |         torch.save(data_file, data)
81 |     end
82 | 
83 |     local self = data
84 | 
85 |     self.opts = {}
86 | 
87 |     -- copy these options, with defaults
88 |     self.opts.channel_inds_in = (opts.channel_inds_in or torch.LongTensor{1}):clone()
89 |     self.opts.channel_inds_out = (opts.channel_inds_out or torch.LongTensor{1}):clone()
90 |     self.opts.rotate = opts.rotate or false
91 | 
92 |     setmetatable(self, { __index = DataProvider })
93 |     return self
94 | 
95 | end
96 | 
97 | function DataProvider:getImages(indices, train_or_test)
98 | 
99 |     local images_in = self[train_or_test].images:index(1, indices):index(2, self.opts.channel_inds_in):clone()
100 |     local images_out = self[train_or_test].images:index(1, indices):index(2, self.opts.channel_inds_out):clone()
101 | 
102 |     if self.opts.rotate and torch.rand(1)[1] < 0.01 then
103 |         for i = 1,images_in:size()[1] do
104 |             local rad = (torch.rand(1)*2*math.pi)[1]
105 |             local flip = torch.rand(1)[1]>0.5
106 |             if flip then
107 |                 images_in[i] = image.hflip(images_in[i])
108 |                 images_out[i] = image.hflip(images_out[i])
109 |             end
110 | 
111 |             images_in[i] = image.rotate(images_in[i], rad)
112 |             images_out[i] = image.rotate(images_out[i], rad)
113 |         end
114 |     end
115 | 
116 |     return images_in, images_out
117 | end
118 | 
119 | function DataProvider:getLabels(indices, train_or_test)
120 |     local labels_in = self[train_or_test].labels:index(1, indices):typeAs(self.test.labels)
121 |     return labels_in
122 | end
123 | 
124 | return DataProvider
125 | 
126 | 
127 | 
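
-- [Editor's note -- a minimal sketch, not part of the original repository]
-- Typical use of this provider, based only on the API defined above; the
-- `model_opts` fields are the same ones setup.lua and the notebooks populate:
--
--   local DataProvider = require 'DataProvider2D'
--   local dataProvider = DataProvider.create(model_opts.image_dir, model_opts)
--
--   local inds = torch.LongTensor{1, 2, 3}                  -- minibatch indices
--   local x_in, x_out = dataProvider:getImages(inds, 'train')
--   local y = dataProvider:getLabels(inds, 'train')         -- one-hot label rows
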
--------------------------------------------------------------------------------
/doc/images/model_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AllenCellModeling/torch_integrated_cell/35037ae7b640e6c89357f27dfd510726088ce7fa/doc/images/model_arch.png
--------------------------------------------------------------------------------
/imtools.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'image'
3 | require 'paths'
4 | require 'colormap'
5 | require 'utils'
6 | 
7 | imtools = {}
8 | 
9 | function imtools.load_img(im_dir, im_pattern, im_scale)
10 |     if im_scale == nil then
11 |         im_scale = 1;
12 |     end
13 |     -- find the index of the last file separator
14 | 
15 |     -- tmp, ind = im_path:reverse():find('/')
16 |     -- ind = im_path:len() - ind + 1
17 | 
18 |     -- im_dir = im_path:sub(1,ind)
19 |     -- im_path = im_path:sub(ind, -1)
20 | 
21 |     local p = {}
22 |     local c = 1
23 | 
24 |     for f in paths.files(im_dir, im_pattern) do
25 |         p[c] = im_dir .. f
26 |         c = c+1
27 |     end
28 | 
29 |     -- natural sort
30 |     local p = utils.alphanumsort(p)
31 | 
32 |     local im_tmp = image.load(p[1], 3, 'double')
33 |     -- img_tmp = torch.DoubleTensor(img_tmp:size()):copy(img_tmp)
34 | 
35 |     local im_size_tmp = im_tmp:size()
36 | 
37 |     local im_size = torch.LongStorage(4)
38 | 
39 |     im_size[1] = #p
40 |     im_size[2] = im_size_tmp[1]
41 |     im_size[3] = torch.round(im_size_tmp[2]/im_scale)
42 |     im_size[4] = torch.round(im_size_tmp[3]/im_scale)
43 | 
44 |     local im_out = torch.Tensor(im_size)
45 | 
46 |     for i = 1,im_size[1] do
47 |         im_out[i] = image.scale(image.load(p[i], 3, 'double'), im_size[3], im_size[4], 'bilinear')
48 |     end
49 | 
50 |     return im_out, p
51 | end
52 | 
53 | 
54 | function imtools.mat2img(img)
55 |     if img:size(2) < 3 then
56 |         local padDims = 3-img:size(2)
57 |         img = torch.cat(img, torch.zeros(img:size(1), padDims, img:size(3), img:size(4)):typeAs(img), 2)
58 |     end
59 |     return img
60 | end
61 | 
62 | 
63 | function imtools.otsu(img)
64 |     -- Shamelessly adapted from
65 |     -- https://en.wikipedia.org/wiki/Otsu%27s_method
66 |     -- grj 4/28/16
67 | 
68 |     local nbins = 256
69 | 
70 |     img = img - torch.min(img);
71 |     img = torch.div(img, torch.max(img)) * (nbins-1)
72 | 
73 |     local counts = torch.histc(img, nbins, 0, nbins)
74 |     local p = counts / torch.sum(counts)
75 | 
76 |     local sigma_b = torch.DoubleTensor(nbins-1) -- between-class variance; a float tensor so fractional values are not truncated
77 | 
78 |     for t = 1,(nbins-1) do
79 |         local q_L = torch.sum(p[{{1, t}}])
80 |         local q_H = torch.sum(p[{{t + 1, -1}}])
81 | 
82 |         local mul = torch.cmul( p[{{1,t}}], torch.linspace(1, t, t))
83 |         local miu_L = torch.sum( mul ) / q_L;
84 | 
85 |         local muh = torch.cmul(p[{{t+1, -1}}],torch.linspace(t+1, nbins, nbins-t))
86 |         local miu_H = torch.sum(muh) / q_H
87 | 
88 |         sigma_b[t] = q_L * q_H * ((miu_L - miu_H)^2)
89 |     end
90 | 
91 |     local y, i = torch.max(sigma_b, 1)
92 | 
93 |     local i = torch.FloatTensor(1):copy(i)
94 | 
95 |     local im2 = img:gt(i[1]);
96 |     return im2
97 | end
98 | 
99 | function imtools.im2projection(im)
100 |     local im_out = nil
101 |     if im:size():size() == 5 then
102 |         im_out = torch.zeros(im:size()[1], 3, im:size()[4], im:size()[5]):typeAs(im)
103 |         for i = 1,im:size()[1] do
104 |             im_out[i] = imtools.im2projection(im[i])
105 |         end
106 |     else
107 |         local imsize = im:size()
108 |         local nchan = imsize[1]
109 | 
110 |         colormap:setStyle('jet')
111 |         colormap:setSteps(nchan)
112 |         local colors = colormap:colorbar(nchan, 2)[{{},{},{1}}]
113 | 
114 |         if nchan == 3 then
115 |             colors = colors:index(2, torch.LongTensor{3,1,2})
116 |         end
117 | 
118 |         local im_flat = torch.max(im,2)
119 | 
120 |         im_out = torch.zeros(nchan, 3, imsize[3], imsize[4]):typeAs(im)
121 | 
122 |         for i = 1,nchan do
123 |             local im_chan = im_flat[i]
124 |             local im_chan = torch.repeatTensor(im_chan, 3, 1, 1)
125 |             local cmap = torch.repeatTensor(colors[{{},{i}}], 1, imsize[3], imsize[4]):typeAs(im_chan)
126 | 
127 |             im_out[{i}] = torch.cmul(im_chan, cmap)
128 | 
129 |             local im_max = torch.max(im_out[{i}])
130 |             if im_max ~= 0 then
131 |                 im_out[{i}] = torch.div(im_out[{i}], im_max)
132 |             end
133 |         end
134 | 
135 |         im_out = torch.squeeze(torch.sum(im_out, 1))
136 | 
137 |         local im_max = torch.max(im_out)
138 |         if im_max ~= 0 then
139 |             im_out = torch.div(im_out, im_max)
140 |         end
141 |     end
142 |     -- print(im_out)
143 |     return im_out
144 | end
145 | 
146 | 
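
-- [Editor's note -- a minimal usage sketch, not part of the original repository]
-- Driving imtools by hand, assuming a directory of PNG images; the argument
-- meanings follow the definitions above:
--
--   require 'imtools'
--   -- load every PNG under the directory, downscaling each side by 2x
--   local ims, im_paths = imtools.load_img('/path/to/pngs/', 'png', 2)
--   print(ims:size())                      -- nImgs x 3 x H/2 x W/2
--   local mask = imtools.otsu(ims[1][1])   -- Otsu-thresholded binary mask of one channel
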
--------------------------------------------------------------------------------
/models/caae-nopool_GAN_v2.lua:
--------------------------------------------------------------------------------
1 | -- updated with apparently working adversarial network
2 | 
3 | require 'nn'
4 | require 'dpnn'
5 | 
6 | local Model = {
7 |     zSize = 16 -- size of isotropic multivariate Gaussian Z
8 | }
9 | 
10 | local function weights_init(m)
11 |     local name = torch.type(m)
12 |     if name:find('Convolution') then
13 |         m.weight:normal(0.0, 0.02)
14 |         m:noBias()
15 |     elseif name:find('BatchNormalization') then
16 |         if m.weight then m.weight:normal(1.0, 0.02) end
17 |         if m.bias then m.bias:fill(0) end
18 |     end
19 | end
20 | 
21 | 
22 | 
23 | function Model:create(opts)
24 |     self.scale_factor = 8/(512/opts.imsize)
25 |     self.ganNoise = opts.ganNoise or 0.01
26 |     self.ganNoiseAllLayers = opts.ganNoiseAllLayers or false
27 | 
28 |     self.nLatentDims = opts.nLatentDims
29 |     self.nClasses = opts.nClasses
30 |     self.nChIn = opts.nChIn
31 |     self.nChOut = opts.nChOut
32 |     self.nOther = opts.nOther
33 | 
34 |     self.fullConv = opts.fullConv
35 | 
36 |     self.dropoutRate = opts.dropoutRate
37 | 
38 |     Model:createAutoencoder()
39 |     Model:createAdversary()
40 |     Model:createAdversaryGen()
41 | 
42 |     Model:assemble()
43 | end
44 | 
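-- [Editor's note -- a sketch, not part of the original file] The option fields
-- Model:create reads, with illustrative values matching the 2D defaults in
-- setup.lua (imsize = 256 is an assumption here):
--
--   local Model = require 'models/caae-nopool_GAN_v2'
--   Model:create{imsize = 256, nLatentDims = 16, nClasses = 0, nOther = 0,
--                nChIn = 2, nChOut = 2, dropoutRate = 0.2, fullConv = true,
--                ganNoise = 0.01}
--   -- Model.encoder ends in an nn.ConcatTable with one head per enabled output
--   -- (class scores, "other" features, latent code); Model.decoder joins them.
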
45 | function Model:createAutoencoder()
46 | 
47 |     local scale = self.scale_factor
48 | 
49 |     -- Create encoder
50 |     self.encoder = nn.Sequential()
51 | 
52 |     self.encoder:add(nn.SpatialConvolution(self.nChIn, 64, 4, 4, 2, 2, 1, 1))
53 |     self.encoder:add(nn.SpatialBatchNormalization(64))
54 | 
55 |     self.encoder:add(nn.PReLU())
56 |     self.encoder:add(nn.SpatialConvolution(64, 128, 4, 4, 2, 2, 1, 1))
57 |     self.encoder:add(nn.SpatialBatchNormalization(128))
58 | 
59 |     self.encoder:add(nn.PReLU())
60 |     self.encoder:add(nn.SpatialConvolution(128, 256, 4, 4, 2, 2, 1, 1))
61 |     self.encoder:add(nn.SpatialBatchNormalization(256))
62 | 
63 |     self.encoder:add(nn.PReLU())
64 |     self.encoder:add(nn.SpatialConvolution(256, 512, 4, 4, 2, 2, 1, 1))
65 |     self.encoder:add(nn.SpatialBatchNormalization(512))
66 | 
67 |     self.encoder:add(nn.PReLU())
68 |     self.encoder:add(nn.SpatialConvolution(512, 1024, 4, 4, 2, 2, 1, 1))
69 |     self.encoder:add(nn.SpatialBatchNormalization(1024))
70 | 
71 |     self.encoder:add(nn.PReLU())
72 |     self.encoder:add(nn.SpatialConvolution(1024, 1024, 4, 4, 2, 2, 1, 1))
73 |     self.encoder:add(nn.SpatialBatchNormalization(1024))
74 | 
75 |     self.encoder:add(nn.View(1024*scale*scale))
76 |     self.encoder:add(nn.PReLU())
77 | 
78 | 
79 |     local encoder_out = nn.ConcatTable()
80 | 
81 |     if self.nClasses > 0 then
82 |         -- output for the class label
83 |         local pred_label = nn.Sequential()
84 |         pred_label:add(nn.Linear(1024*scale*scale, self.nClasses))
85 |         pred_label:add(nn.BatchNormalization(self.nClasses))
86 |         pred_label:add(nn.LogSoftMax())
87 | 
88 |         encoder_out:add(pred_label)
89 |     end
90 | 
91 |     if self.nOther > 0 then
92 |         -- output for the structure features
93 |         local pred_other = nn.Sequential()
94 |         pred_other:add(nn.Linear(1024*scale*scale, self.nOther))
95 |         pred_other:add(nn.BatchNormalization(self.nOther))
96 | 
97 |         encoder_out:add(pred_other)
98 |     end
99 | 
100 |     if self.nLatentDims > 0 then
101 |         -- output for the noise
102 |         local noise = nn.Sequential()
103 |         noise:add(nn.Linear(1024*scale*scale, self.nLatentDims))
104 |         noise:add(nn.BatchNormalization(self.nLatentDims))
105 | 
106 |         encoder_out:add(noise)
107 |     end
108 | 
109 |     -- join them all together
110 |     self.encoder:add(encoder_out)
111 | 
112 |     local ngf = 64
113 | 
114 |     -- Create decoder
115 |     self.decoder = nn.Sequential()
116 |     self.decoder:add(nn.JoinTable(-1))
117 |     self.decoder:add(nn.Linear(self.nClasses + self.nOther + self.nLatentDims, 1024*scale*scale))
118 |     self.decoder:add(nn.View(1024, scale, scale))
119 | 
120 |     -- -- input is Z, going into a 
convolution 121 | -- self.decoder:add(nn.SpatialFullConvolution(self.nClasses + self.nOther + self.nLatentDims, 1024, 4, 4, 1, 1)) 122 | -- self.decoder:add(nn.SpatialBatchNormalization(1024)) 123 | 124 | self.decoder:add(nn.PReLU()) 125 | self.decoder:add(nn.SpatialFullConvolution(1024, 1024, 4, 4, 2, 2, 1, 1)) 126 | self.decoder:add(nn.SpatialBatchNormalization(1024)) 127 | 128 | self.decoder:add(nn.PReLU()) 129 | self.decoder:add(nn.SpatialFullConvolution(1024, 512, 4, 4, 2, 2, 1, 1)) 130 | self.decoder:add(nn.SpatialBatchNormalization(512)) 131 | 132 | self.decoder:add(nn.PReLU()) 133 | self.decoder:add(nn.SpatialFullConvolution(512, 256, 4, 4, 2, 2, 1, 1)) 134 | self.decoder:add(nn.SpatialBatchNormalization(256)) 135 | 136 | self.decoder:add(nn.PReLU()) 137 | self.decoder:add(nn.SpatialFullConvolution(256, 128, 4, 4, 2, 2, 1, 1)) 138 | self.decoder:add(nn.SpatialBatchNormalization(128)) 139 | 140 | self.decoder:add(nn.PReLU()) 141 | self.decoder:add(nn.SpatialFullConvolution(128, 64, 4, 4, 2, 2, 1, 1)) 142 | self.decoder:add(nn.SpatialBatchNormalization(64)) 143 | 144 | self.decoder:add(nn.PReLU()) 145 | self.decoder:add(nn.SpatialFullConvolution(64, self.nChOut, 4, 4, 2, 2, 1, 1)) 146 | -- self.decoder:add(nn.SpatialBatchNormalization(self.nChOut)) 147 | self.decoder:add(nn.Sigmoid()) 148 | 149 | self.encoder:apply(weights_init) 150 | self.decoder:apply(weights_init) 151 | end 152 | 153 | function Model:assemble() 154 | self.autoencoder = nn.Sequential() 155 | self.autoencoder:add(self.encoder) 156 | self.autoencoder:add(self.decoder) 157 | end 158 | 159 | function Model:createAdversary() 160 | -- also modeled off of soumith dcgan 161 | noise = 0.1 162 | 163 | self.adversary = nn.Sequential() 164 | 165 | self.adversary:add(nn.Linear(self.nLatentDims, 1024)) 166 | self.adversary:add(nn.LeakyReLU(0.2, true)) 167 | 168 | self.adversary:add(nn.Linear(1024, 1024)) 169 | self.adversary:add(nn.BatchNormalization(1024)) 170 | self.adversary:add(nn.LeakyReLU(0.2, true)) 171 | 172 | self.adversary:add(nn.Linear(1024, 512)) 173 | self.adversary:add(nn.BatchNormalization(512)) 174 | self.adversary:add(nn.LeakyReLU(0.2, true)) 175 | 176 | self.adversary:add(nn.Linear(512, 1)) 177 | self.adversary:add(nn.Sigmoid(true)) 178 | 179 | self.adversary:apply(weights_init) 180 | 181 | end 182 | 183 | function Model:createAdversaryGen() 184 | noise = self.ganNoise 185 | ndf = 64 186 | 187 | -- l1weight = 0.001 188 | 189 | scale = self.scale_factor 190 | 191 | self.adversaryGen = nn.Sequential() 192 | 193 | -- input is (nc) x 64 x 64 194 | if noise > 0 then 195 | self.adversaryGen:add(nn.WhiteNoise(0, noise)) 196 | end 197 | self.adversaryGen:add(nn.SpatialConvolution(self.nChOut, ndf, 4, 4, 2, 2, 1, 1)) 198 | -- self.adversaryGen:add(nn.SpatialBatchNormalization(ndf)) 199 | self.adversaryGen:add(nn.LeakyReLU(0.2, true)) 200 | 201 | if self.ganNoiseAllLayers and noise > 0 then 202 | self.adversaryGen:add(nn.WhiteNoise(0, noise)) 203 | end 204 | self.adversaryGen:add(nn.SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1)) 205 | self.adversaryGen:add(nn.SpatialBatchNormalization(ndf * 2)) 206 | self.adversaryGen:add(nn.LeakyReLU(0.2, true)) 207 | 208 | if self.ganNoiseAllLayers and noise > 0 then 209 | self.adversaryGen:add(nn.WhiteNoise(0, noise)) 210 | end 211 | self.adversaryGen:add(nn.SpatialConvolution(ndf * 2, ndf * 4, 4, 4, 2, 2, 1, 1)) 212 | self.adversaryGen:add(nn.SpatialBatchNormalization(ndf * 4)) 213 | self.adversaryGen:add(nn.LeakyReLU(0.2, true)) 214 | 215 | if self.ganNoiseAllLayers and noise > 
0 then 216 | self.adversaryGen:add(nn.WhiteNoise(0, noise)) 217 | end 218 | self.adversaryGen:add(nn.SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1)) 219 | self.adversaryGen:add(nn.SpatialBatchNormalization(ndf * 8)) 220 | self.adversaryGen:add(nn.LeakyReLU(0.2, true)) 221 | 222 | if self.ganNoiseAllLayers and noise > 0 then 223 | self.adversaryGen:add(nn.WhiteNoise(0, noise)) 224 | end 225 | self.adversaryGen:add(nn.SpatialConvolution(ndf * 8, ndf * 8, 4, 4, 2, 2, 1, 1)) 226 | self.adversaryGen:add(nn.SpatialBatchNormalization(ndf * 8)) 227 | self.adversaryGen:add(nn.LeakyReLU(0.2, true)) 228 | 229 | -- state size: (ndf*8) x 4 x 4 230 | if noise > 0 then 231 | self.adversaryGen:add(nn.WhiteNoise(0, noise)) 232 | end 233 | 234 | if self.nClasses > 0 then 235 | -- output for the class label 236 | 237 | 238 | self.adversaryGen:add(nn.View(ndf * 8 * scale*2 * scale*2)) 239 | self.adversaryGen:add(nn.Linear(ndf * 8 * scale*2 * scale*2, self.nClasses+1)) 240 | 241 | -- self.adversaryGen:add(nn.SpatialConvolution(ndf * 8, self.nClasses+1, 4, 4)) 242 | -- self.adversaryGen:add(nn.View(self.nClasses+1)) 243 | -- self.adversaryGen:add(nn.BatchNormalization(self.nClasses+1)) 244 | self.adversaryGen:add(nn.LogSoftMax()) 245 | else 246 | 247 | self.adversaryGen:add(nn.View(ndf * 8 * scale*2 * scale*2)) 248 | self.adversaryGen:add(nn.Linear(ndf * 8 * scale*2 * scale*2, 1)) 249 | 250 | -- self.adversaryGen:add(nn.SpatialConvolution(ndf * 8, 16, 4, 4, 1, 1)) 251 | -- self.adversaryGen:add(nn.BatchNormalization(1)) 252 | -- self.adversaryGen:add(nn.LeakyReLU(0.2, true)) 253 | -- self.adversaryGen:add(nn.View(1)) 254 | self.adversaryGen:add(nn.Sigmoid()) 255 | end 256 | 257 | self.adversaryGen:apply(weights_init) 258 | 259 | end 260 | 261 | 262 | return Model 263 | 264 | 265 | -------------------------------------------------------------------------------- /printFigs2D.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 10, 6 | "metadata": { 7 | "collapsed": false, 8 | "deletable": true, 9 | "editable": true 10 | }, 11 | "outputs": [ 12 | { 13 | "data": { 14 | "text/plain": [ 15 | "GPU is ON\t\n", 16 | "Device 1:\t\n" 17 | ] 18 | }, 19 | "execution_count": 10, 20 | "metadata": {}, 21 | "output_type": "execute_result" 22 | }, 23 | { 24 | "data": { 25 | "text/plain": [ 26 | "\t Free memory 11500444160\t\n", 27 | "\t Total memory 12782075904\t\n", 28 | "Device 2:\t\n" 29 | ] 30 | }, 31 | "execution_count": 10, 32 | "metadata": {}, 33 | "output_type": "execute_result" 34 | }, 35 | { 36 | "data": { 37 | "text/plain": [ 38 | "\t Free memory 11997492224\t\n", 39 | "\t Total memory 12778799104\t\n", 40 | "Device 3:\t\n", 41 | "\t Free memory 12176064512\t\n", 42 | "\t Total memory 12782075904\t\n", 43 | "Setting up\t\n", 44 | "Has cudnn: \t\n", 45 | "true\t\n" 46 | ] 47 | }, 48 | "execution_count": 10, 49 | "metadata": {}, 50 | "output_type": "execute_result" 51 | }, 52 | { 53 | "data": { 54 | "text/plain": [ 55 | "Loading data from ./trainAAE2D_256\t\n" 56 | ] 57 | }, 58 | "execution_count": 10, 59 | "metadata": {}, 60 | "output_type": "execute_result" 61 | }, 62 | { 63 | "data": { 64 | "text/plain": [ 65 | "Done\t\n", 66 | "Loading Models\t\n" 67 | ] 68 | }, 69 | "execution_count": 10, 70 | "metadata": {}, 71 | "output_type": "execute_result" 72 | }, 73 | { 74 | "data": { 75 | "text/plain": [ 76 | "Done Loading Models\t\n", 77 | "/root/images/release_4_1_17/release_v2/aligned/2D/\t\n", 78 | 
"Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Beta actin\t\n" 79 | ] 80 | }, 81 | "execution_count": 10, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | }, 85 | { 86 | "data": { 87 | "text/plain": [ 88 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//ZO1\t\n" 89 | ] 90 | }, 91 | "execution_count": 10, 92 | "metadata": {}, 93 | "output_type": "execute_result" 94 | }, 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Sec61 beta\t\n" 99 | ] 100 | }, 101 | "execution_count": 10, 102 | "metadata": {}, 103 | "output_type": "execute_result" 104 | }, 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Alpha actinin\t\n" 109 | ] 110 | }, 111 | "execution_count": 10, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | }, 115 | { 116 | "data": { 117 | "text/plain": [ 118 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Tom20\t\n" 119 | ] 120 | }, 121 | "execution_count": 10, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | }, 125 | { 126 | "data": { 127 | "text/plain": [ 128 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Lamin B1\t\n" 129 | ] 130 | }, 131 | "execution_count": 10, 132 | "metadata": {}, 133 | "output_type": "execute_result" 134 | }, 135 | { 136 | "data": { 137 | "text/plain": [ 138 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Alpha tubulin\t\n" 139 | ] 140 | }, 141 | "execution_count": 10, 142 | "metadata": {}, 143 | "output_type": "execute_result" 144 | }, 145 | { 146 | "data": { 147 | "text/plain": [ 148 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Desmoplakin\t\n" 149 | ] 150 | }, 151 | "execution_count": 10, 152 | "metadata": {}, 153 | "output_type": "execute_result" 154 | }, 155 | { 156 | "data": { 157 | "text/plain": [ 158 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Myosin IIB\t\n" 159 | ] 160 | }, 161 | "execution_count": 10, 162 | "metadata": {}, 163 | "output_type": "execute_result" 164 | }, 165 | { 166 | "data": { 167 | "text/plain": [ 168 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Fibrillarin\t\n" 169 | ] 170 | }, 171 | "execution_count": 10, 172 | "metadata": {}, 173 | "output_type": "execute_result" 174 | }, 175 | { 176 | "data": { 177 | "text/plain": [ 178 | "{\n", 179 | " 1 : Beta actin\n", 180 | " 2 : ZO1\n", 181 | " 3 : Sec61 beta\n", 182 | " 4 : Alpha actinin\n", 183 | " 5 : Tom20\n", 184 | " 6 : Lamin B1\n", 185 | " 7 : Alpha tubulin\n", 186 | " 8 : Desmoplakin\n", 187 | " 9 : Myosin IIB\n", 188 | " 10 : Fibrillarin\n", 189 | "}\n" 190 | ] 191 | }, 192 | "execution_count": 10, 193 | "metadata": {}, 194 | "output_type": "execute_result" 195 | } 196 | ], 197 | "source": [ 198 | "require 'cutorch' -- Use CUDA if available\n", 199 | "require 'cudnn'\n", 200 | "require 'cunn'\n", 201 | "\n", 202 | "-- print(cudnn)\n", 203 | "\n", 204 | "-- Load dependencies\n", 205 | "optim = require 'optim'\n", 206 | "gnuplot = require 'gnuplot'\n", 207 | "image = require 'image'\n", 208 | "nn = require 'nn'\n", 209 | "\n", 210 | "require 'paths'\n", 211 | "\n", 212 | "count = cutorch.getDeviceCount()\n", 213 | "-- torch.setdefaulttensortype('torch.CudaTensor') \n", 214 | "print('GPU is ON')\n", 215 | "for i = 1, count do\n", 216 | " print('Device ' .. i .. 
':')\n",
217 | " freeMemory, totalMemory = cutorch.getMemoryUsage(i)\n",
218 | " print('\\t Free memory ' .. freeMemory)\n",
219 | " print('\\t Total memory ' .. totalMemory)\n",
220 | "end\n",
221 | "\n",
222 | "cutorch.setDevice(2)\n",
223 | "torch.setnumthreads(12)\n",
224 | "\n",
225 | "-- Set up Torch\n",
226 | "print('Setting up')\n",
227 | "torch.setdefaulttensortype('torch.FloatTensor')\n",
228 | "torch.manualSeed(1) \n",
229 | "cutorch.manualSeed(torch.random())\n",
230 | "-- end\n",
231 | "\n",
232 | "cuda = true\n",
233 | "hasCudnn = true\n",
234 | "\n",
235 | "print('Has cudnn: ')\n",
236 | "print(hasCudnn)\n",
237 | "\n",
238 | "\n",
239 | "package.loaded['./modelTools'] = nil\n",
240 | "DataProvider = require ('./modelTools')\n",
241 | "\n",
242 | "package.loaded['./DataProvider2D'] = nil\n",
243 | "DataProvider = require ('./DataProvider2D')\n",
244 | "\n",
245 | "package.loaded['setupAAEGANv2'] = nil\n",
246 | "setup = require 'setupAAEGANv2'\n",
247 | "\n",
248 | "model_opts = setup.getModelOpts()\n",
249 | "\n",
250 | "model_opts.parent_dir = './trainAAE2D_256'\n",
251 | "model_opts.model_name = 'caae-nopool_GAN_v2_AAEGANv1_4_alldat'\n",
252 | "model_opts.image_dir = '/root/images/release_4_1_17/release_v2/aligned/2D/'\n",
253 | "\n",
254 | "model_opts.save_dir = model_opts.parent_dir .. '/' .. model_opts.model_name\n",
255 | "\n",
256 | "model_opts.channel_inds_in = torch.LongTensor{1,3}\n",
257 | "model_opts.channel_inds_out = torch.LongTensor{1,3}\n",
258 | "model_opts.nChOut = 3\n",
259 | "model_opts.nChIn = 3\n",
260 | "model_opts.rotate = false\n",
261 | "\n",
262 | "dataProvider = DataProvider.create(model_opts.image_dir, model_opts)\n",
263 | "\n",
264 | "print('Loading Models')\n",
265 | "decoder = torch.load(model_opts.save_dir .. '/decoder.t7')\n",
266 | "encoder = torch.load(model_opts.save_dir .. '/encoder.t7')\n",
267 | "\n",
268 | "decoder:cuda()\n",
269 | "encoder:cuda()\n",
270 | "\n",
271 | "print('Done Loading Models')\n",
272 | "\n",
273 | "decoder:evaluate()\n",
274 | "encoder:evaluate()\n",
275 | "\n",
276 | "package.loaded['utils'] = nil\n",
277 | "setup = require 'utils'\n",
278 | "\n",
279 | "print(model_opts.image_dir)\n",
280 | "\n",
281 | "image_dir = model_opts.image_dir\n",
282 | "image_paths = {}\n",
283 | "classes = {}\n",
284 | "c = 0\n",
285 | "for dir in paths.iterdirs(image_dir) do\n",
286 | " print('Loading images from ' .. image_dir .. '/' .. dir)\n",
287 | "\n",
288 | " im_dir = image_dir .. '/' .. dir .. '/'\n",
289 | " im_pattern = 'png'\n",
290 | " \n",
291 | " local p = {}\n",
292 | " local c_files = 1\n",
293 | " \n",
294 | " for f in paths.files(im_dir, im_pattern) do\n",
295 | " p[c_files] = im_dir .. f\n",
296 | " c_files = c_files+1\n",
297 | " end\n",
298 | " \n",
299 | " -- natural sort\n",
300 | " local image_paths_tmp = utils.alphanumsort(p)\n",
301 | " \n",
302 | " \n",
303 | " for i = 1,#image_paths_tmp do\n",
304 | " c = c+1\n",
305 | "\n",
306 | " image_paths[c] = image_paths_tmp[i]\n",
307 | "\n",
308 | " tokens = utils.split(image_paths_tmp[i], '/')\n",
309 | " classes[c] = tokens[#tokens-1] \n",
310 | " end\n",
311 | "end\n",
312 | "\n",
313 | "-- print(classes)\n",
314 | "\n",
315 | "classes, labels = utils.unique(classes)\n",
316 | "\n",
317 | "print(classes)\n",
318 | "\n",
319 | "\n",
320 | "\n",
321 | "save_dir = model_opts.save_dir .. '/' .. 
'gridded'\n", 322 | "paths.mkdir(save_dir)\n", 323 | "\n", 324 | "nsteps = 21;\n", 325 | "\n", 326 | "pos = torch.linspace(-3, 3, nsteps);\n", 327 | "\n", 328 | "im_out = {}\n", 329 | "\n", 330 | "-- for k = 1, 10 do\n", 331 | " imgs_i = {}\n", 332 | " for i = 1,nsteps do\n", 333 | " imgs_j = {}\n", 334 | " for j = 1,nsteps do\n", 335 | " dat = {}\n", 336 | " \n", 337 | " dat[1] = torch.zeros(2,model_opts.nLatentDims):cuda()\n", 338 | " dat[1][{{},3}] = pos[i]\n", 339 | " dat[1][{{},4}] = pos[j]\n", 340 | "\n", 341 | " imgs_j[j] = decoder:forward(dat)[1]:float()\n", 342 | " \n", 343 | " end\n", 344 | " \n", 345 | " imgs_i[i] = torch.cat(imgs_j, 2)\n", 346 | " end\n", 347 | " \n", 348 | " imgs_struct = torch.cat(imgs_i,3)\n", 349 | "\n", 350 | " imgs_out = torch.zeros(3, imgs_struct:size(2), imgs_struct:size(2))\n", 351 | "\n", 352 | " imgs_out[1] = imgs_struct[1]\n", 353 | " imgs_out[3] = imgs_struct[2]\n", 354 | " \n", 355 | " save_path = save_dir .. '/cellnuc.png'\n", 356 | " image.save(save_path, imgs_out)\n", 357 | "-- end\n", 358 | "\n", 359 | "encoder = nil\n", 360 | "\n", 361 | "\n" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": 7, 367 | "metadata": { 368 | "collapsed": false 369 | }, 370 | "outputs": [ 371 | { 372 | "data": { 373 | "text/plain": [ 374 | "./trainAAE2D_256/caae-nopool_GAN_v2_AAEGANv1_4_alldat/gridded_ref\t\n" 375 | ] 376 | }, 377 | "execution_count": 7, 378 | "metadata": {}, 379 | "output_type": "execute_result" 380 | } 381 | ], 382 | "source": [ 383 | "save_dir" 384 | ] 385 | }, 386 | { 387 | "cell_type": "code", 388 | "execution_count": 1, 389 | "metadata": { 390 | "collapsed": false, 391 | "deletable": true, 392 | "editable": true 393 | }, 394 | "outputs": [ 395 | { 396 | "data": { 397 | "text/plain": [ 398 | "GPU is ON\t\n", 399 | "Device 1:\t\n" 400 | ] 401 | }, 402 | "execution_count": 1, 403 | "metadata": {}, 404 | "output_type": "execute_result" 405 | }, 406 | { 407 | "data": { 408 | "text/plain": [ 409 | "\t Free memory 12096372736\t\n", 410 | "\t Total memory 12782075904\t\n", 411 | "Device 2:\t\n" 412 | ] 413 | }, 414 | "execution_count": 1, 415 | "metadata": {}, 416 | "output_type": "execute_result" 417 | }, 418 | { 419 | "data": { 420 | "text/plain": [ 421 | "\t Free memory 12077301760\t\n", 422 | "\t Total memory 12778799104\t\n", 423 | "Device 3:\t\n" 424 | ] 425 | }, 426 | "execution_count": 1, 427 | "metadata": {}, 428 | "output_type": "execute_result" 429 | }, 430 | { 431 | "data": { 432 | "text/plain": [ 433 | "\t Free memory 12176064512\t\n", 434 | "\t Total memory 12782075904\t\n", 435 | "Setting up\t\n", 436 | "Has cudnn: \t\n", 437 | "true\t\n" 438 | ] 439 | }, 440 | "execution_count": 1, 441 | "metadata": {}, 442 | "output_type": "execute_result" 443 | }, 444 | { 445 | "data": { 446 | "text/plain": [ 447 | "Loading data from ./trainAAE2D_256\t\n" 448 | ] 449 | }, 450 | "execution_count": 1, 451 | "metadata": {}, 452 | "output_type": "execute_result" 453 | }, 454 | { 455 | "data": { 456 | "text/plain": [ 457 | "Done\t\n", 458 | "Loading Models\t\n" 459 | ] 460 | }, 461 | "execution_count": 1, 462 | "metadata": {}, 463 | "output_type": "execute_result" 464 | }, 465 | { 466 | "data": { 467 | "text/plain": [ 468 | "Done Loading Models\t\n" 469 | ] 470 | }, 471 | "execution_count": 1, 472 | "metadata": {}, 473 | "output_type": "execute_result" 474 | }, 475 | { 476 | "data": { 477 | "text/plain": [ 478 | "/root/images/release_4_1_17/release_v2/aligned/2D/\t\n", 479 | "Loading images from 
/root/images/release_4_1_17/release_v2/aligned/2D//Beta actin\t\n" 480 | ] 481 | }, 482 | "execution_count": 1, 483 | "metadata": {}, 484 | "output_type": "execute_result" 485 | }, 486 | { 487 | "data": { 488 | "text/plain": [ 489 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//ZO1\t\n" 490 | ] 491 | }, 492 | "execution_count": 1, 493 | "metadata": {}, 494 | "output_type": "execute_result" 495 | }, 496 | { 497 | "data": { 498 | "text/plain": [ 499 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Sec61 beta\t\n" 500 | ] 501 | }, 502 | "execution_count": 1, 503 | "metadata": {}, 504 | "output_type": "execute_result" 505 | }, 506 | { 507 | "data": { 508 | "text/plain": [ 509 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Alpha actinin\t\n" 510 | ] 511 | }, 512 | "execution_count": 1, 513 | "metadata": {}, 514 | "output_type": "execute_result" 515 | }, 516 | { 517 | "data": { 518 | "text/plain": [ 519 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Tom20\t\n" 520 | ] 521 | }, 522 | "execution_count": 1, 523 | "metadata": {}, 524 | "output_type": "execute_result" 525 | }, 526 | { 527 | "data": { 528 | "text/plain": [ 529 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Lamin B1\t\n" 530 | ] 531 | }, 532 | "execution_count": 1, 533 | "metadata": {}, 534 | "output_type": "execute_result" 535 | }, 536 | { 537 | "data": { 538 | "text/plain": [ 539 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Alpha tubulin\t\n" 540 | ] 541 | }, 542 | "execution_count": 1, 543 | "metadata": {}, 544 | "output_type": "execute_result" 545 | }, 546 | { 547 | "data": { 548 | "text/plain": [ 549 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Desmoplakin\t\n" 550 | ] 551 | }, 552 | "execution_count": 1, 553 | "metadata": {}, 554 | "output_type": "execute_result" 555 | }, 556 | { 557 | "data": { 558 | "text/plain": [ 559 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Myosin IIB\t\n" 560 | ] 561 | }, 562 | "execution_count": 1, 563 | "metadata": {}, 564 | "output_type": "execute_result" 565 | }, 566 | { 567 | "data": { 568 | "text/plain": [ 569 | "Loading images from /root/images/release_4_1_17/release_v2/aligned/2D//Fibrillarin\t\n" 570 | ] 571 | }, 572 | "execution_count": 1, 573 | "metadata": {}, 574 | "output_type": "execute_result" 575 | }, 576 | { 577 | "data": { 578 | "text/plain": [ 579 | "{\n", 580 | " 1 : Beta actin\n", 581 | " 2 : ZO1\n", 582 | " 3 : Sec61 beta\n", 583 | " 4 : Alpha actinin\n", 584 | " 5 : Tom20\n", 585 | " 6 : Lamin B1\n", 586 | " 7 : Alpha tubulin\n", 587 | " 8 : Desmoplakin\n", 588 | " 9 : Myosin IIB\n", 589 | " 10 : Fibrillarin\n", 590 | "}\n" 591 | ] 592 | }, 593 | "execution_count": 1, 594 | "metadata": {}, 595 | "output_type": "execute_result" 596 | } 597 | ], 598 | "source": [ 599 | "require 'cutorch' -- Use CUDA if available\n", 600 | "require 'cudnn'\n", 601 | "require 'cunn'\n", 602 | "\n", 603 | "-- print(cudnn)\n", 604 | "\n", 605 | "-- Load dependencies\n", 606 | "optim = require 'optim'\n", 607 | "gnuplot = require 'gnuplot'\n", 608 | "image = require 'image'\n", 609 | "nn = require 'nn'\n", 610 | "\n", 611 | "require 'paths'\n", 612 | "\n", 613 | "count = cutorch.getDeviceCount()\n", 614 | "-- torch.setdefaulttensortype('torch.CudaTensor') \n", 615 | "print('GPU is ON')\n", 616 | "for i = 1, count do\n", 617 | " print('Device ' .. i .. 
':')\n", 618 | " freeMemory, totalMemory = cutorch.getMemoryUsage(i)\n", 619 | " print('\\t Free memory ' .. freeMemory)\n", 620 | " print('\\t Total memory ' .. totalMemory)\n", 621 | "end\n", 622 | "\n", 623 | "cutorch.setDevice(2)\n", 624 | "torch.setnumthreads(12)\n", 625 | "\n", 626 | "-- Set up Torch\n", 627 | "print('Setting up')\n", 628 | "torch.setdefaulttensortype('torch.FloatTensor')\n", 629 | "torch.manualSeed(1) \n", 630 | "cutorch.manualSeed(torch.random())\n", 631 | "-- end\n", 632 | "\n", 633 | "cuda = true\n", 634 | "hasCudnn = true\n", 635 | "\n", 636 | "print('Has cudnn: ')\n", 637 | "print(hasCudnn)\n", 638 | "\n", 639 | "\n", 640 | "package.loaded['./modelTools'] = nil\n", 641 | "DataProvider = require ('./modelTools')\n", 642 | "\n", 643 | "package.loaded['./DataProvider2D'] = nil\n", 644 | "DataProvider = require ('./DataProvider2D')\n", 645 | "\n", 646 | "package.loaded['setupAAEGANv2'] = nil\n", 647 | "setup = require 'setupAAEGANv2'\n", 648 | "\n", 649 | "model_opts = setup.getModelOpts()\n", 650 | "\n", 651 | "model_opts.parent_dir = './trainAAE2D_256'\n", 652 | "model_opts.model_name = 'caae-nopool_GAN_v2_AAEGANv1_4_alldat_pattern'\n", 653 | "model_opts.image_dir = '/root/images/release_4_1_17/release_v2/aligned/2D/'\n", 654 | "\n", 655 | "model_opts.save_dir = model_opts.parent_dir .. '/' .. model_opts.model_name\n", 656 | "\n", 657 | "model_opts.channel_inds_in = torch.LongTensor{1,2,3}\n", 658 | "model_opts.channel_inds_out = torch.LongTensor{1,2,3}\n", 659 | "model_opts.nChOut = 3\n", 660 | "model_opts.nChIn = 3\n", 661 | "model_opts.rotate = false\n", 662 | "\n", 663 | "dataProvider = DataProvider.create(model_opts.image_dir, model_opts)\n", 664 | "\n", 665 | "print('Loading Models')\n", 666 | "\n", 667 | "decoder = torch.load(model_opts.save_dir .. '/decoder.t7')\n", 668 | "encoder = torch.load(model_opts.save_dir .. '/encoder.t7')\n", 669 | "\n", 670 | "decoder:cuda()\n", 671 | "encoder:cuda()\n", 672 | "\n", 673 | "print('Done Loading Models')\n", 674 | "\n", 675 | "decoder:evaluate()\n", 676 | "encoder:evaluate()\n", 677 | "\n", 678 | "package.loaded['utils'] = nil\n", 679 | "setup = require 'utils'\n", 680 | "\n", 681 | "print(model_opts.image_dir)\n", 682 | "\n", 683 | "image_dir = model_opts.image_dir\n", 684 | "image_paths = {}\n", 685 | "classes = {}\n", 686 | "c = 0\n", 687 | "for dir in paths.iterdirs(image_dir) do\n", 688 | " print('Loading images from ' .. image_dir .. '/' .. dir)\n", 689 | "\n", 690 | " im_dir = image_dir .. '/' .. dir .. '/'\n", 691 | " im_pattern = 'png'\n", 692 | " \n", 693 | " local p = {}\n", 694 | " local c_files = 1\n", 695 | " \n", 696 | " for f in paths.files(im_dir, im_pattern) do\n", 697 | " p[c_files] = im_dir .. 
f\n", 698 | " c_files = c_files+1\n", 699 | " end\n", 700 | " \n", 701 | " -- natural sort\n", 702 | " local image_paths_tmp = utils.alphanumsort(p)\n", 703 | " \n", 704 | " \n", 705 | " for i = 1,#image_paths_tmp do\n", 706 | " c = c+1\n", 707 | "\n", 708 | " image_paths[c] = image_paths_tmp[i]\n", 709 | "\n", 710 | " tokens = utils.split(image_paths_tmp[i], '/')\n", 711 | " classes[c] = tokens[#tokens-1] \n", 712 | " end\n", 713 | "end\n", 714 | "\n", 715 | "-- print(classes)\n", 716 | "\n", 717 | "classes, labels = utils.unique(classes)\n", 718 | "\n", 719 | "print(classes)\n", 720 | "\n", 721 | "\n" 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": null, 727 | "metadata": { 728 | "collapsed": true, 729 | "deletable": true, 730 | "editable": true 731 | }, 732 | "outputs": [], 733 | "source": [ 734 | "\n", 735 | "save_dir = model_opts.save_dir .. '/' .. 'gridded'\n", 736 | "paths.mkdir(save_dir)\n", 737 | "\n", 738 | "nsteps = 21;\n", 739 | "\n", 740 | "pos = torch.linspace(-3, 3, nsteps);\n", 741 | "\n", 742 | "im_out = {}\n", 743 | "\n", 744 | "for k = 1, 10 do\n", 745 | " imgs_i = {}\n", 746 | " for i = 1,nsteps do\n", 747 | " imgs_j = {}\n", 748 | " for j = 1,nsteps do\n", 749 | " dat = {}\n", 750 | " dat[1] = torch.zeros(2,10):fill(-26):cuda()\n", 751 | " dat[1][{{},k}] = 1\n", 752 | " \n", 753 | " dat[2] = torch.zeros(2,model_opts.nLatentDims):cuda()\n", 754 | " dat[2][{{},3}] = pos[i]\n", 755 | " dat[2][{{},4}] = pos[j]\n", 756 | " \n", 757 | " dat[3] = torch.zeros(2,model_opts.nLatentDims):cuda()\n", 758 | " \n", 759 | " imgs_j[j] = decoder:forward(dat)[1]:float()\n", 760 | " \n", 761 | " end\n", 762 | " \n", 763 | " imgs_i[i] = torch.cat(imgs_j, 2)\n", 764 | " end\n", 765 | " \n", 766 | " imgs_struct = torch.cat(imgs_i,3)\n", 767 | "-- itorch.image(imgs_struct)\n", 768 | " save_path = save_dir .. '/' .. classes[k] .. '.png'\n", 769 | " image.save(save_path, imgs_struct)\n", 770 | "end\n" 771 | ] 772 | }, 773 | { 774 | "cell_type": "code", 775 | "execution_count": null, 776 | "metadata": { 777 | "collapsed": false, 778 | "deletable": true, 779 | "editable": true 780 | }, 781 | "outputs": [], 782 | "source": [ 783 | "\n", 784 | "\n", 785 | "\n", 786 | "print_images = function(encoder, decoder, dataProvider, train_or_test, save_dir)\n", 787 | " \n", 788 | " require 'paths'\n", 789 | "\n", 790 | " paths.mkdir(save_dir)\n", 791 | "\n", 792 | " ndat = dataProvider[train_or_test].labels:size(1)\n", 793 | " nclasses = dataProvider[train_or_test].labels:size(2)\n", 794 | "\n", 795 | " for j = 1,ndat do\n", 796 | "-- print('Printing image ' .. j)\n", 797 | " img_num = torch.LongTensor{j}\n", 798 | "\n", 799 | " local label = dataProvider:getLabels(torch.LongTensor{j,j}, train_or_test)\n", 800 | " local img = dataProvider:getImages(torch.LongTensor{j,j}, train_or_test):cuda()\n", 801 | " local imsize = img:size()\n", 802 | " local out_latent = nil\n", 803 | " out_latent = encoder:forward(img)\n", 804 | "\n", 805 | " out_latent[3] = torch.zeros(2,model_opts.nLatentDims):cuda()\n", 806 | "\n", 807 | " local out_img = torch.Tensor(nclasses, 3, imsize[3], imsize[4]):cuda()\n", 808 | "\n", 809 | " im_path = dataProvider.image_paths[dataProvider[train_or_test].inds[j]]\n", 810 | " tokens = utils.split(im_path, '/|.')\n", 811 | " im_class = tokens[7]\n", 812 | " im_id = tokens[8]\n", 813 | " \n", 814 | " print('Printing images for ' .. im_class .. ' ' .. im_id)\n", 815 | " \n", 816 | " save_path = save_dir .. '/' .. im_class .. '_'.. im_id .. 
'_orig.png'\n", 817 | " image.save(save_path, img[1])\n", 818 | " \n", 819 | " for i = 1,#dataProvider.classes do\n", 820 | " save_path = save_dir .. '/' .. im_class .. '_'.. im_id .. '_pred_' .. classes[i] .. '.png'\n", 821 | " \n", 822 | " if not paths.filep(save_path) then\n", 823 | " one_hot = torch.ones(2,nclasses):fill(-25):cuda()\n", 824 | " one_hot[{{1},{i}}] = 0\n", 825 | "\n", 826 | " out_latent[1] = one_hot\n", 827 | " out_img = decoder:forward(out_latent)[1][2]\n", 828 | "\n", 829 | "\n", 830 | " image.save(save_path, out_img)\n", 831 | " end\n", 832 | " \n", 833 | " end\n", 834 | " end\n", 835 | "end\n", 836 | "\n", 837 | "print_images(encoder, decoder, dataProvider, 'train', model_opts.save_dir .. '/' .. 'pred_train')\n", 838 | "print_images(encoder, decoder, dataProvider, 'test', model_opts.save_dir .. '/' .. 'pred_test')" 839 | ] 840 | }, 841 | { 842 | "cell_type": "code", 843 | "execution_count": 28, 844 | "metadata": { 845 | "collapsed": false, 846 | "deletable": true, 847 | "editable": true 848 | }, 849 | "outputs": [ 850 | { 851 | "data": { 852 | "text/plain": [ 853 | "{\n", 854 | " 1 : Alpha actinin\n", 855 | " 2 : Alpha tubulin\n", 856 | " 3 : Beta actin\n", 857 | " 4 : Desmoplakin\n", 858 | " 5 : Fibrillarin\n", 859 | " 6 : Lamin B1\n", 860 | " 7 : Myosin IIB\n", 861 | " 8 : Sec61 beta\n", 862 | " 9 : Tom20\n", 863 | " 10 : ZO1\n", 864 | "}\n" 865 | ] 866 | }, 867 | "execution_count": 28, 868 | "metadata": {}, 869 | "output_type": "execute_result" 870 | }, 871 | { 872 | "data": { 873 | "text/plain": [ 874 | " 22 2 5 1 0 0 0 0 0 1\n", 875 | " 0 36 3 0 0 0 0 1 1 0\n", 876 | " 3 7 19 0 0 0 0 0 0 0\n", 877 | " 1 0 1 7 0 0 0 0 0 1\n", 878 | " 0 0 0 0 35 0 0 0 0 0\n", 879 | " 0 0 0 0 0 46 0 0 0 0\n", 880 | " 2 0 0 1 0 0 0 0 1 4\n", 881 | " 0 1 0 0 0 0 0 50 0 0\n", 882 | " 0 1 0 0 0 0 0 0 47 0\n", 883 | " 1 0 0 1 0 0 2 0 0 1\n", 884 | "[torch.FloatTensor of size 10x10]\n", 885 | "\n", 886 | "total dat: 6077\t\n", 887 | "\t\n", 888 | "By label:\t\n", 889 | "Alpha actinin: 493, 462, 31\t\n", 890 | "Alpha tubulin: 1043, 1002, 41\t\n" 891 | ] 892 | }, 893 | "execution_count": 28, 894 | "metadata": {}, 895 | "output_type": "execute_result" 896 | }, 897 | { 898 | "data": { 899 | "text/plain": [ 900 | "Beta actin: 542, 513, 29\t\n", 901 | "Desmoplakin: 229, 219, 10\t\n", 902 | "Fibrillarin: 988, 953, 35\t\n", 903 | "Lamin B1: 785, 739, 46\t\n", 904 | "Myosin IIB: 157, 149, 8\t\n", 905 | "Sec61 beta: 835, 784, 51\t\n", 906 | "Tom20: 771, 723, 48\t\n", 907 | "ZO1: 234, 229, 5\t\n" 908 | ] 909 | }, 910 | "execution_count": 28, 911 | "metadata": {}, 912 | "output_type": "execute_result" 913 | } 914 | ], 915 | "source": [ 916 | "\n", 917 | "classes_sort = {}\n", 918 | "for i = 1,#classes do\n", 919 | " classes_sort[i] = classes[i]\n", 920 | "end\n", 921 | " \n", 922 | "-- print(classes_sort)\n", 923 | "utils.alphanumsort(classes_sort)\n", 924 | "\n", 925 | "order = torch.zeros(#classes)\n", 926 | "order_bak = torch.zeros(#classes)\n", 927 | "for i = 1, #classes do\n", 928 | " for j = 1, #classes do\n", 929 | " if classes_sort[i] == classes[j] then\n", 930 | " order[j] = i\n", 931 | " order_bak[i] = j\n", 932 | " end\n", 933 | " end\n", 934 | "end\n", 935 | "print(classes_sort)\n", 936 | "-- print(classes)\n", 937 | "-- print(order)\n", 938 | "\n", 939 | "\n", 940 | "eval_images = function(encoder, dataProvider, train_or_test)\n", 941 | " \n", 942 | " require 'paths'\n", 943 | "\n", 944 | "\n", 945 | " ndat = dataProvider[train_or_test].labels:size(1)\n", 946 | " nclasses = 
dataProvider[train_or_test].labels:size(2)\n", 947 | "\n", 948 | " confmat = torch.zeros(nclasses, nclasses)\n", 949 | " \n", 950 | " for j = 1,ndat do\n", 951 | "-- print('Printing image ' .. j)\n", 952 | " img_num = torch.LongTensor{j}\n", 953 | "\n", 954 | " label = dataProvider:getLabels(torch.LongTensor{j,j}, train_or_test)\n", 955 | " local img = dataProvider:getImages(torch.LongTensor{j,j}, train_or_test):cuda()\n", 956 | " local imsize = img:size()\n", 957 | "-- local out_latent = nil\n", 958 | " out_latent = encoder:forward(img) \n", 959 | " out_pred = out_latent[1][1]:type('torch.FloatTensor')\n", 960 | "\n", 961 | " tmp, pred_label = torch.max(out_pred,1)\n", 962 | " \n", 963 | " tmp, label = torch.max(label[1],1)\n", 964 | "\n", 965 | " confmat[order[label[1]]][order[pred_label[1]]] = confmat[order[label[1]]][order[pred_label[1]]] + 1\n", 966 | " end\n", 967 | "end\n", 968 | "\n", 969 | "eval_images(encoder, dataProvider, 'test')\n", 970 | "print(confmat)\n", 971 | "\n", 972 | "\n", 973 | "print('total dat: ' .. labels:size(1))\n", 974 | "\n", 975 | "print('')\n", 976 | "print('By label:')\n", 977 | "\n", 978 | "labels_test = labels:index(1,dataProvider.test.inds)\n", 979 | "labels_train = labels:index(1,dataProvider.train.inds)\n", 980 | "\n", 981 | "for i = 1, #dataProvider.classes do\n", 982 | " i = order_bak[i]\n", 983 | " ntot = torch.sum(torch.eq(labels,i))\n", 984 | " ntest = torch.sum(torch.eq(labels_test,i))\n", 985 | " ntrain = torch.sum(torch.eq(labels_train,i))\n", 986 | " print(classes[i] .. ': ' .. ntot .. ', ' .. ntrain .. ', ' .. ntest)\n", 987 | "end" 988 | ] 989 | }, 990 | { 991 | "cell_type": "code", 992 | "execution_count": 34, 993 | "metadata": { 994 | "collapsed": false, 995 | "deletable": true, 996 | "editable": true 997 | }, 998 | "outputs": [ 999 | { 1000 | "data": { 1001 | "text/plain": [ 1002 | "Alpha actinin & 22 & 2 & 5 & 1 & 0 & 0 & 0 & 0 & 0 & 1 \\\\\t\n", 1003 | "Alpha tubulin & 0 & 36 & 3 & 0 & 0 & 0 & 0 & 1 & 1 & 0 \\\\\t\n", 1004 | "Beta actin & 3 & 7 & 19 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\\\\t\n", 1005 | "Desmoplakin & 1 & 0 & 1 & 7 & 0 & 0 & 0 & 0 & 0 & 1 \\\\\t\n", 1006 | "Fibrillarin & 0 & 0 & 0 & 0 & 35 & 0 & 0 & 0 & 0 & 0 \\\\\t\n", 1007 | "Lamin B1 & 0 & 0 & 0 & 0 & 0 & 46 & 0 & 0 & 0 & 0 \\\\\t\n", 1008 | "Myosin IIB & 2 & 0 & 0 & 1 & 0 & 0 & 0 & 0 & 1 & 4 \\\\\t\n", 1009 | "Sec61 beta & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 50 & 0 & 0 \\\\\t\n", 1010 | "Tom20 & 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 & 47 & 0 \\\\\t\n", 1011 | "ZO1 & 1 & 0 & 0 & 1 & 0 & 0 & 2 & 0 & 0 & 1 \\\\\t\n" 1012 | ] 1013 | }, 1014 | "execution_count": 34, 1015 | "metadata": {}, 1016 | "output_type": "execute_result" 1017 | } 1018 | ], 1019 | "source": [ 1020 | "-- latex formatted table shennanary\n", 1021 | "\n", 1022 | "for i = 1,#classes do\n", 1023 | " str = classes_sort[i]\n", 1024 | " for j =1,#classes do\n", 1025 | " str = str .. ' & ' .. confmat[i][j]\n", 1026 | " end\n", 1027 | " str = str .. 
' \\\\\\\\'\n", 1028 | " \n", 1029 | " print(str)\n", 1030 | "end" 1031 | ] 1032 | }, 1033 | { 1034 | "cell_type": "code", 1035 | "execution_count": 36, 1036 | "metadata": { 1037 | "collapsed": false, 1038 | "deletable": true, 1039 | "editable": true 1040 | }, 1041 | "outputs": [ 1042 | { 1043 | "data": { 1044 | "text/plain": [ 1045 | "Alpha actinin & 493 & 462 & 31 \\\\\t\n", 1046 | "Alpha tubulin & 1043 & 1002 & 41 \\\\\t\n", 1047 | "Beta actin & 542 & 513 & 29 \\\\\t\n", 1048 | "Desmoplakin & 229 & 219 & 10 \\\\\t\n", 1049 | "Fibrillarin & 988 & 953 & 35 \\\\\t\n", 1050 | "Lamin B1 & 785 & 739 & 46 \\\\\t\n", 1051 | "Myosin IIB & 157 & 149 & 8 \\\\\t\n", 1052 | "Sec61 beta & 835 & 784 & 51 \\\\\t\n" 1053 | ] 1054 | }, 1055 | "execution_count": 36, 1056 | "metadata": {}, 1057 | "output_type": "execute_result" 1058 | }, 1059 | { 1060 | "data": { 1061 | "text/plain": [ 1062 | "Tom20 & 771 & 723 & 48 \\\\\t\n", 1063 | "ZO1 & 234 & 229 & 5 \\\\\t\n" 1064 | ] 1065 | }, 1066 | "execution_count": 36, 1067 | "metadata": {}, 1068 | "output_type": "execute_result" 1069 | } 1070 | ], 1071 | "source": [ 1072 | "for i = 1, #dataProvider.classes do\n", 1073 | " i = order_bak[i]\n", 1074 | " ntot = torch.sum(torch.eq(labels,i))\n", 1075 | " ntest = torch.sum(torch.eq(labels_test,i))\n", 1076 | " ntrain = torch.sum(torch.eq(labels_train,i))\n", 1077 | " print(classes[i] .. ' & ' .. ntot .. ' & ' .. ntrain .. ' & ' .. ntest .. ' \\\\\\\\')\n", 1078 | "end" 1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "code", 1083 | "execution_count": 4, 1084 | "metadata": { 1085 | "collapsed": false, 1086 | "deletable": true, 1087 | "editable": true 1088 | }, 1089 | "outputs": [ 1090 | { 1091 | "data": { 1092 | "text/plain": [ 1093 | "total dat: 6077\t\n", 1094 | "\t\n", 1095 | "By label:\t\n", 1096 | "Beta actin: 542, 513, 29\t\n", 1097 | "ZO1: 234, 229, 5\t\n", 1098 | "Sec61 beta: 835, 784, 51\t\n", 1099 | "Alpha actinin: 493, 462, 31\t\n", 1100 | "Tom20: 771, 723, 48\t\n", 1101 | "Lamin B1: 785, 739, 46\t\n", 1102 | "Alpha tubulin: 1043, 1002, 41\t\n" 1103 | ] 1104 | }, 1105 | "execution_count": 4, 1106 | "metadata": {}, 1107 | "output_type": "execute_result" 1108 | }, 1109 | { 1110 | "data": { 1111 | "text/plain": [ 1112 | "Desmoplakin: 229, 219, 10\t\n", 1113 | "Myosin IIB: 157, 149, 8\t\n", 1114 | "Fibrillarin: 988, 953, 35\t\n" 1115 | ] 1116 | }, 1117 | "execution_count": 4, 1118 | "metadata": {}, 1119 | "output_type": "execute_result" 1120 | } 1121 | ], 1122 | "source": [ 1123 | "print('total dat: ' .. labels:size(1))\n", 1124 | "\n", 1125 | "print('')\n", 1126 | "print('By label:')\n", 1127 | "\n", 1128 | "labels_test = labels:index(1,dataProvider.test.inds)\n", 1129 | "labels_train = labels:index(1,dataProvider.train.inds)\n", 1130 | "\n", 1131 | "for i = 1, #dataProvider.classes do\n", 1132 | " ntot = torch.sum(torch.eq(labels,i))\n", 1133 | " ntest = torch.sum(torch.eq(labels_test,i))\n", 1134 | " ntrain = torch.sum(torch.eq(labels_train,i))\n", 1135 | " print(classes[i] .. ': ' .. ntot .. ', ' .. ntrain .. ', ' .. 
ntest)\n", 1136 | "end" 1137 | ] 1138 | }, 1139 | { 1140 | "cell_type": "code", 1141 | "execution_count": 20, 1142 | "metadata": { 1143 | "collapsed": false, 1144 | "deletable": true, 1145 | "editable": true 1146 | }, 1147 | "outputs": [ 1148 | { 1149 | "data": { 1150 | "text/plain": [ 1151 | "{\n", 1152 | " 1 : Beta actin\n", 1153 | " 2 : ZO1\n", 1154 | " 3 : Sec61 beta\n", 1155 | " 4 : Alpha actinin\n", 1156 | " 5 : Tom20\n", 1157 | " 6 : Lamin B1\n", 1158 | " 7 : Alpha tubulin\n", 1159 | " 8 : Desmoplakin\n", 1160 | " 9 : Myosin IIB\n", 1161 | " 10 : Fibrillarin\n", 1162 | "}\n", 1163 | "{\n", 1164 | " 1 : Alpha actinin\n", 1165 | " 2 : Alpha tubulin\n", 1166 | " 3 : Beta actin\n", 1167 | " 4 : Desmoplakin\n", 1168 | " 5 : Fibrillarin\n", 1169 | " 6 : Lamin B1\n", 1170 | " 7 : Myosin IIB\n", 1171 | " 8 : Sec61 beta\n", 1172 | " 9 : Tom20\n", 1173 | " 10 : ZO1\n", 1174 | "}\n", 1175 | "{\n", 1176 | " 1 : Alpha actinin\n", 1177 | " 2 : Alpha tubulin\n", 1178 | " 3 : Beta actin\n", 1179 | " 4 : Desmoplakin\n", 1180 | " 5 : Fibrillarin\n", 1181 | " 6 : Lamin B1\n", 1182 | " 7 : Myosin IIB\n", 1183 | " 8 : Sec61 beta\n", 1184 | " 9 : Tom20\n", 1185 | " " 1186 | ] 1187 | }, 1188 | "execution_count": 20, 1189 | "metadata": {}, 1190 | "output_type": "execute_result" 1191 | }, 1192 | { 1193 | "data": { 1194 | "text/plain": [ 1195 | "10 : ZO1\n", 1196 | "}\n", 1197 | "{\n", 1198 | " 1 : Beta actin\n", 1199 | " 2 : ZO1\n", 1200 | " 3 : Sec61 beta\n", 1201 | " 4 : Alpha actinin\n", 1202 | " 5 : Tom20\n", 1203 | " 6 : Lamin B1\n", 1204 | " 7 : Alpha tubulin\n", 1205 | " 8 : Desmoplakin\n", 1206 | " 9 : Myosin IIB\n", 1207 | " 10 : Fibrillarin\n", 1208 | "}\n", 1209 | " 4\n", 1210 | " 7\n", 1211 | " 1\n", 1212 | " 8\n", 1213 | " 10\n", 1214 | " 6\n", 1215 | " 9\n", 1216 | " 3\n", 1217 | " 5\n", 1218 | " 2\n", 1219 | "[torch.FloatTensor of size 10]\n", 1220 | "\n" 1221 | ] 1222 | }, 1223 | "execution_count": 20, 1224 | "metadata": {}, 1225 | "output_type": "execute_result" 1226 | } 1227 | ], 1228 | "source": [] 1229 | }, 1230 | { 1231 | "cell_type": "code", 1232 | "execution_count": 9, 1233 | "metadata": { 1234 | "collapsed": false, 1235 | "deletable": true, 1236 | "editable": true 1237 | }, 1238 | "outputs": [ 1239 | { 1240 | "data": { 1241 | "text/plain": [ 1242 | "ZO1\t\n" 1243 | ] 1244 | }, 1245 | "execution_count": 9, 1246 | "metadata": {}, 1247 | "output_type": "execute_result" 1248 | } 1249 | ], 1250 | "source": [] 1251 | }, 1252 | { 1253 | "cell_type": "code", 1254 | "execution_count": null, 1255 | "metadata": { 1256 | "collapsed": true, 1257 | "deletable": true, 1258 | "editable": true 1259 | }, 1260 | "outputs": [], 1261 | "source": [] 1262 | } 1263 | ], 1264 | "metadata": { 1265 | "kernelspec": { 1266 | "display_name": "iTorch", 1267 | "language": "lua", 1268 | "name": "itorch" 1269 | }, 1270 | "language_info": { 1271 | "name": "lua", 1272 | "version": "5.1" 1273 | } 1274 | }, 1275 | "nbformat": 4, 1276 | "nbformat_minor": 2 1277 | } 1278 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | Torch Integrated Cell 2 | =============================== 3 | 4 | ![Model Architecture](doc/images/model_arch.png?raw=true "Model Architecture") 5 | 6 | Image-driven generative cell modelling with adversarial autoencoders: https://arxiv.org/abs/1705.00092 7 | 8 | ## For the *updated 3D version*, please see: 9 | **Building a 3D Integrated Cell** 10 | Manuscript: 
https://www.biorxiv.org/content/early/2017/12/21/238378
11 | GitHub: https://github.com/AllenCellModeling/pytorch_integrated_cell
12 | 
13 | ## Installation
14 | Installing on Linux is recommended.
15 | 
16 | ### Prerequisites
17 | Running on Docker is recommended, though not required.
18 | 
19 | - Install Torch on Docker / nvidia-docker, e.g. as in this guide: https://github.com/gregjohnso/dl-docker
20 | - Download the training images: `aws s3 cp s3://aics.integrated.cell.arxiv.paper.data . --recursive --no-sign-request`
21 | 
22 | ### Steps
23 | After you clone this repository, you will need to edit the mount points for the images in `run_docker.sh` to point to where you saved them.
24 | Once those locations are properly set, you can start the docker image with
25 | 
26 | `bash run_docker.sh`
27 | 
28 | Once you're in the docker container, you can train the model with
29 | 
30 | `bash train_model_2D.sh`
31 | 
32 | This will take a while, probably about 12-18 hours.
33 | 
34 | ## Project website
35 | Example outputs of this model can be viewed at http://www.allencell.org
36 | 
37 | ## Citation
38 | If you find this code useful in your research, please consider citing the following paper:
39 | 
40 |     @article{johnson2017generative,
41 |       title={Generative Modeling with Conditional Autoencoders: Building an Integrated Cell},
42 |       author={Gregory R. Johnson and Rory M. Donovan-Maiye and Mary M. Maleckar},
43 |       journal={arXiv preprint arXiv:1705.00092},
44 |       year={2017},
45 |       url={https://arxiv.org/abs/1705.00092}
46 |     }
47 | 
48 | ## Contact
49 | Gregory Johnson
50 | E-mail: gregj@alleninstitute.org
51 | 
52 | ## License
53 | This program is free software: you can redistribute it and/or modify
54 | it under the terms of the GNU General Public License as published by
55 | the Free Software Foundation, either version 3 of the License, or
56 | (at your option) any later version.
57 | 
58 | This program is distributed in the hope that it will be useful,
59 | but WITHOUT ANY WARRANTY; without even the implied warranty of
60 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
61 | GNU General Public License for more details.
62 | 
63 | You should have received a copy of the GNU General Public License
64 | along with this program. If not, see <http://www.gnu.org/licenses/>.
65 | 
--------------------------------------------------------------------------------
/run_docker.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | nvidia-docker run -it \
4 |     -v /directory/containing/all/ten/image/class/directories:/root/images/release_4_1_17/release_v2/aligned/2D \
5 |     -v $(pwd):/root/torch_integrated_cell \
6 |     floydhub/dl-docker:gpu \
7 |     bash
8 | 
--------------------------------------------------------------------------------
/setup.lua:
--------------------------------------------------------------------------------
1 | setup = {}
2 | 
3 | 
4 | function setup.getModelOpts()
5 | 
6 |     local model_opts = {}
7 |     model_opts.cuda = cuda
8 | 
9 |     model_opts.parent_dir = 'AAE_shape_learner_v2'
10 |     model_opts.model_name = 'cell_learner_3_caae_learner'
11 |     model_opts.save_dir = model_opts.parent_dir .. '/' .. model_opts.model_name
12 |     model_opts.image_dir = '/root/images/2016_11_08_Nuc_Cell_Seg_8_cell_lines_V22/processed_aligned/2D'
13 |     model_opts.image_sub_size = 16.96875/2
14 | 
15 |     model_opts.nLatentDims = 16
16 |     model_opts.channel_inds_in = torch.LongTensor{1,3}
17 |     model_opts.channel_inds_out = torch.LongTensor{1,3}
18 | 
19 |     model_opts.rotate = true
20 |     model_opts.nChIn = model_opts.channel_inds_in:size(1)
21 |     model_opts.nChOut = model_opts.channel_inds_out:size(1)
22 |     model_opts.nClasses = 0
23 |     model_opts.nOther = 0
24 |     model_opts.dropoutRate = 0.2
25 |     model_opts.fullConv = true
26 | 
27 |     model_opts.adversarialGen = false
28 | 
29 |     model_opts.test_model = false
30 |     model_opts.verbose = true
31 | 
32 |     paths.mkdir(model_opts.parent_dir)
33 | 
34 |     return model_opts
35 | end
36 | 
37 | 
38 | 
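-- [Editor's note -- a sketch, not part of the original file] These defaults are
-- typically overridden per experiment before creating the data provider and
-- model, as the notebooks do; the paths and names here are illustrative:
--
--   model_opts = setup.getModelOpts()
--   model_opts.parent_dir = './trainAAE2D_256'
--   model_opts.model_name = 'caae-nopool_GAN_v2'
--   model_opts.image_dir = '/root/images/release_4_1_17/release_v2/aligned/2D/'
--   model_opts.save_dir = model_opts.parent_dir .. '/' .. model_opts.model_name
--   dataProvider = DataProvider.create(model_opts.image_dir, model_opts)
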
39 | function setup.getModel()
40 | 
41 |     print('Loading model type: ' .. model_opts.model_name)
42 | 
43 |     package.loaded['models/' .. model_opts.model_name] = nil
44 |     Model = require ('models/' .. model_opts.model_name)
45 |     autoencoder = nil
46 |     encoder = nil
47 |     decoder = nil
48 | 
49 |     if paths.filep(model_opts.save_dir .. '/decoder.t7') then
50 |         print('Loading model from ' .. model_opts.save_dir)
51 | 
52 |         print('Loading encoder')
53 |         encoder = torch.load(model_opts.save_dir .. '/encoder.t7')
54 |         encoder:float()
55 |         encoder:clearState()
56 |         collectgarbage()
57 | 
58 |         print('Loading adversary')
59 |         adversary = torch.load(model_opts.save_dir .. '/adversary.t7')
60 |         adversary:float()
61 |         adversary:clearState()
62 |         collectgarbage()
63 | 
64 |         if paths.filep(model_opts.save_dir .. '/adversaryGen.t7') then
65 |             print('Loading adversaryGen')
66 |             adversaryGen = torch.load(model_opts.save_dir .. '/adversaryGen.t7')
67 |             adversaryGen:float()
68 |             adversaryGen:clearState()
69 |             collectgarbage()
70 |         else
71 |             adversaryGen = nil
72 |         end
73 | 
74 |         print('Loading decoder')
75 |         decoder = torch.load(model_opts.save_dir .. '/decoder.t7')
76 |         decoder:float()
77 |         decoder:clearState()
78 |         collectgarbage()
79 | 
80 | 
81 | 
82 |         if paths.filep(model_opts.save_dir .. '/rng.t7') then
83 |             print('Loading RNG')
84 |             torch.setRNGState(torch.load(model_opts.save_dir .. '/rng.t7'))
85 |             cutorch.setRNGState(torch.load(model_opts.save_dir .. '/rng_cuda.t7'))
86 |         end
87 | 
88 | 
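-- [Editor's note -- a sketch, not part of the original file] set_gpu_id below
-- pins each submodule of a container to a device for simple model parallelism:
-- module i runs on gpuIDs[i] and ships its output to the next module's device
-- (gpu1/gpu2/gpu3 are globals set by the training script). The same pattern,
-- standalone:
--
--   require 'cunn'
--   local net = nn.Sequential()
--   net:add(nn.GPU(nn.Linear(10, 10), 1)) -- weights and forward pass on device 1
--   net:add(nn.GPU(nn.Linear(10, 10), 2)) -- weights and forward pass on device 2
--   net:cuda()
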
'/rng_cuda.t7')) 86 | end 87 | 88 | 89 | -- print(gpuIDs) 90 | set_gpu_id = function(model, gpuIDs) 91 | for i = 1, #model.modules do 92 | model.modules[i].device = gpuIDs[i] 93 | 94 | if i == #model.modules then 95 | model.modules[i].outdevice = gpu1 96 | else 97 | model.modules[i].outdevice = gpuIDs[i+1] 98 | end 99 | end 100 | return model 101 | end 102 | 103 | gpuIDs = torch.zeros(#encoder.modules):fill(gpu1) 104 | -- gpuIDs:sub(9,#encoder.modules):fill(gpu2) 105 | encoder = set_gpu_id(encoder, gpuIDs) 106 | 107 | gpuIDs = torch.zeros(#decoder.modules):fill(gpu2) 108 | decoder = set_gpu_id(decoder, gpuIDs) 109 | 110 | gpuIDs = torch.zeros(#adversary.modules):fill(gpu2) 111 | adversary = set_gpu_id(adversary, gpuIDs) 112 | 113 | if adversaryGen ~= nil then 114 | gpuIDs = torch.zeros(#adversaryGen.modules):fill(gpu3) 115 | adversaryGen = set_gpu_id(adversaryGen, gpuIDs) 116 | end 117 | 118 | print('Done loading model') 119 | else 120 | print('Creating new model') 121 | print(model_opts) 122 | Model:create(model_opts) 123 | paths.mkdir(model_opts.save_dir) 124 | print('Done creating model') 125 | 126 | decoder = Model.decoder 127 | encoder = Model.encoder 128 | adversary = Model.adversary 129 | adversaryGen = Model.adversaryGen 130 | 131 | 132 | -- print(gpuIDs) 133 | set_gpu_id = function(model, gpuIDs) 134 | for i = 1, #model.modules do 135 | model.modules[i] = nn.GPU(model.modules[i], gpuIDs[i]) 136 | -- model.modules[i].device = gpuIDs[i] 137 | if i == #model.modules then 138 | model.modules[i].outdevice = gpu1 139 | else 140 | model.modules[i].outdevice = gpuIDs[i+1] 141 | end 142 | end 143 | return model 144 | end 145 | 146 | gpuIDs = torch.zeros(#encoder.modules):fill(gpu1) 147 | -- gpuIDs:sub(4,#encoder.modules):fill(gpu2) 148 | encoder = set_gpu_id(encoder, gpuIDs) 149 | 150 | gpuIDs = torch.zeros(#decoder.modules):fill(gpu2) 151 | decoder = set_gpu_id(decoder, gpuIDs) 152 | 153 | gpuIDs = torch.zeros(#adversary.modules):fill(gpu2) 154 | adversary = set_gpu_id(adversary, gpuIDs) 155 | 156 | gpuIDs = torch.zeros(#adversaryGen.modules):fill(gpu3) 157 | adversaryGen = set_gpu_id(adversaryGen, gpuIDs) 158 | 159 | opts_optnet = {inplace=true, mode='training'} 160 | optnet.optimizeMemory(encoder, dataProvider:getImages(torch.LongTensor{1,2}, 'train'), opts_optnet) 161 | optnet.optimizeMemory(decoder, encoder.output, opts_optnet) 162 | optnet.optimizeMemory(adversary, encoder.output[#encoder.output], opts_optnet) 163 | optnet.optimizeMemory(adversaryGen, dataProvider:getImages(torch.LongTensor{1,2}, 'train'), opts_optnet) 164 | 165 | Model = nil 166 | end 167 | 168 | criterion_out = nn.BCECriterion() 169 | criterion_label = nn.ClassNLLCriterion() 170 | criterion_other = nn.MSECriterion() 171 | criterion_latent = nn.MSECriterion() 172 | 173 | criterion_adv = nn.BCECriterion() 174 | criterionAdvGen = nn.BCECriterion() 175 | 176 | if cuda then 177 | -- set which parts of the model are on which gpu 178 | print('Converting to cuda') 179 | 180 | encoder:cuda() 181 | decoder:cuda() 182 | adversary:cuda() 183 | 184 | if adversaryGen ~= nil then 185 | adversaryGen:cuda() 186 | end 187 | 188 | criterion_label:cuda() 189 | criterion_other:cuda() 190 | criterion_latent:cuda() 191 | criterion_out:cuda() 192 | criterion_adv:cuda() 193 | criterionAdvGen:cuda() 194 | 195 | print('Done converting to cuda') 196 | end 197 | 198 | if model_opts.test_model then 199 | 200 | print('Data size') 201 | print(dataProvider.train.inds:size()) 202 | 203 | encoder:evaluate() 204 | decoder:evaluate() 205 | 
adversary:evaluate()
206 | 
207 |         local im_in, im_out = dataProvider:getImages(torch.LongTensor{1}, 'train')
208 |         local label = dataProvider:getLabels(torch.LongTensor{1}, 'train'):clone()
209 | 
210 |         print('Testing encoder')
211 |         local code = encoder:forward(im_in:cuda());
212 | 
213 |         print(im_in:type())
214 |         print(im_in:size())
215 | 
216 |         print('Code size:')
217 |         print(code)
218 | 
219 |         print(label)
220 | 
221 |         print(torch.cat(code, 2):size())
222 | 
223 |         print('Testing decoder')
224 | 
225 |         local im_out_hat = decoder:forward(code)
226 | 
227 |         print('Out size:')
228 |         print(im_out_hat:size())
229 | 
230 |         print(criterion_out:forward(im_out_hat, im_out:cuda()))
231 | 
232 |         -- itorch.image(imtools.im2projection(im_out))
233 |         -- itorch.image(imtools.im2projection(im_out_hat))
234 | 
235 |         print('Testing adversary')
236 |         print(adversary:forward(code[#code]))
237 | 
238 |         encoder:training()
239 |         decoder:training()
240 |         adversary:training()
241 |         if adversaryGen ~= nil then adversaryGen:training() end
242 |     end
243 | 
244 |     cudnn.benchmark = false
245 |     cudnn.fastest = false
246 | 
247 |     cudnn.convert(encoder, cudnn)
248 |     cudnn.convert(decoder, cudnn)
249 |     cudnn.convert(adversary, cudnn)
250 | 
251 |     if adversaryGen ~= nil then
252 |         cudnn.convert(adversaryGen, cudnn)
253 |     end
254 | 
255 |     print('Done setting up the model')
256 | end
257 | 
258 | function setup.getLearnOpts(model_opts)
259 | 
260 |     opt_path = model_opts.save_dir .. '/opt.t7'
261 |     optEnc_path = model_opts.save_dir .. '/optEnc.t7'
262 |     optDec_path = model_opts.save_dir .. '/optDec.t7'
263 |     optAdv_path = model_opts.save_dir .. '/optD.t7'
264 |     optAdvGen_path = model_opts.save_dir .. '/optAdvGen.t7'
265 | 
266 |     stateEnc_path = model_opts.save_dir .. '/stateEnc.t7'
267 |     stateDec_path = model_opts.save_dir .. '/stateDec.t7'
268 |     stateAdv_path = model_opts.save_dir .. '/stateD.t7'
269 |     stateAdvGen_path = model_opts.save_dir .. '/stateAdvGen.t7'
270 | 
271 |     plots_path = model_opts.save_dir .. '/plots.t7'
272 | 
273 |     if paths.filep(opt_path) then
274 |         print('Loading previous optimizer state')
275 | 
276 |         opt = torch.load(opt_path)
277 | 
278 |         optEnc = torch.load(optEnc_path)
279 |         optDec = torch.load(optDec_path)
280 |         optAdv = torch.load(optAdv_path)
281 |         if paths.filep(optAdvGen_path) then optAdvGen = torch.load(optAdvGen_path) else optAdvGen = torch.load(optAdv_path) end
282 | 
283 | 
284 |         stateEnc = utils.table2cuda(torch.load(stateEnc_path))
285 |         stateDec = utils.table2cuda(torch.load(stateDec_path))
286 |         stateAdv = utils.table2cuda(torch.load(stateAdv_path))
287 | 
288 | 
289 |         plots = torch.load(plots_path)
290 | 
291 |         if paths.filep(stateAdvGen_path) then
292 |             stateAdvGen = utils.table2cuda(torch.load(stateAdvGen_path))
293 | 
294 |             losses = plots[1][3]:totable()
295 |             latentlosses = plots[2][3]:totable()
296 |             advlosses = plots[3][3]:totable()
297 |             advGenLosses = plots[4][3]:totable()
298 |             advMinimaxLoss = plots[5][3]:totable()
299 |             advGenMinimaxLoss = plots[6][3]:totable()
300 |             reencodelosses = plots[7][3]:totable()
301 | 
302 |         else
303 |             print('Could not find stateAdvGen_path: ' .. stateAdvGen_path)
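            -- Fallback for checkpoints without an adversaryGen (presumably from
            -- older runs): plots then holds only five series, so the indices
            -- below differ from the seven-series layout in the branch above.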
304 |             stateAdvGen = {}
305 | 
306 |             losses = plots[1][3]:totable()
307 |             latentlosses = plots[2][3]:totable()
308 |             advlosses = plots[3][3]:totable()
309 |             advMinimaxLoss = plots[4][3]:totable()
310 |             reencodelosses = plots[5][3]:totable()
311 |         end
312 | 
313 |     else
314 |         opt = {}
315 | 
316 |         opt.epoch = 0
317 |         opt.nepochs = 2000
318 |         opt.adversarial = true
319 |         opt.adversarialGen = model_opts.adversarialGen
320 | 
321 |         opt.learningRateA = 0.01
322 |         opt.learningRateAdv = 0.01
323 | 
324 |         opt.min_rateA = 0.0000001
325 |         opt.min_rateD = 0.0001
326 |         opt.learningRateDecay = 0.999
327 |         opt.optimizer = 'adam'
328 |         opt.batchSize = 64
329 |         opt.verbose = model_opts.verbose
330 |         -- opt.updateD = 0.50
331 |         -- opt.updateA = 0.40
332 |         opt.update_thresh = 0.58
333 |         opt.saveProgressIter = 5
334 |         opt.saveStateIter = 50
335 | 
336 | 
337 |         optEnc = {}
338 |         optEnc.optimizer = opt.optimizer
339 |         optEnc.learningRate = opt.learningRateA
340 |         optEnc.min_rateA = 0.0000001
341 |         optEnc.beta1 = 0.5
342 |         -- optEnc.momentum = 0
343 |         -- optEnc.numUpdates = 0
344 |         -- optEnc.coefL2 = 0
345 |         -- optEnc.coefL1 = 0
346 | 
347 |         optDec = {}
348 |         optDec.optimizer = opt.optimizer
349 |         optDec.learningRate = opt.learningRateA
350 |         optDec.min_rateA = 0.0000001
351 |         optDec.beta1 = 0.5
352 |         -- optDec.momentum = 0
353 |         -- optDec.numUpdates = 0
354 |         -- optDec.coefL2 = 0
355 |         -- optDec.coefL1 = 0
356 | 
357 |         optAdv = {}
358 |         optAdv.optimizer = opt.optimizer
359 |         optAdv.learningRate = opt.learningRateAdv
360 |         optAdv.min_rateA = 0.0000001
361 |         optAdv.beta1 = 0.5
362 |         -- optAdv.momentum = 0
363 |         -- optAdv.numUpdates = 0
364 |         -- optAdv.coefL2 = 0
365 |         -- optAdv.coefL1 = 0
366 | 
367 |         optAdvGen = {}
368 |         optAdvGen.optimizer = opt.optimizer
369 |         optAdvGen.learningRate = opt.learningRateAdv
370 |         optAdvGen.min_rateA = 0.0000001
371 |         optAdvGen.beta1 = 0.5
372 |         -- optAdvGen.momentum = 0
373 |         -- optAdvGen.numUpdates = 0
374 |         -- optAdvGen.coefL2 = 0
375 |         -- optAdvGen.coefL1 = 0
376 | 
377 |         stateEnc = {}
378 |         stateDec = {}
379 |         stateAdv = {}
380 |         stateAdvGen = {}
381 | 
382 |         losses, latentlosses, advlosses, advGenLosses, advMinimaxLoss, advGenMinimaxLoss, advlossesGen, reencodelosses = {}, {}, {}, {}, {}, {}, {}, {}
383 |     end
384 | 
385 |     plots = {}
386 |     loss, advloss, advlossGen = {}, {}, {}
387 | end
388 | 
389 | return setup
390 | 
391 | 
--------------------------------------------------------------------------------
/train.lua:
--------------------------------------------------------------------------------
1 | -- require 'optim2'
2 | require 'optnet'
3 | 
4 | local learner = {}
5 | 
6 | 
7 | 
8 | 
9 | 
10 | function optim_step(net, loss, optParam, optStates)
11 |     -- this function assumes that all modules are nn.GPU-decorated
12 |     local function feval_dummy(param)
13 |         if thisparam ~= param then
14 |             thisparam:copy(param)
15 |         end
16 |         return loss, thisgrad
17 |     end
18 | 
19 |     local c = 1
20 |     for i = 1, #net.modules do
21 |         local gpu = net.modules[i].device
22 |         cutorch.setDevice(gpu)
23 | 
24 |         local theta, gradTheta = net.modules[i]:parameters()
25 | 
26 |         for j = 1,#theta do
27 |             thisparam = theta[j]
28 |             thisgrad = gradTheta[j]
29 | 
30 |             local optState = optStates[c] or {}
31 | 
32 |             for k, v in pairs(optState) do
33 |                 if type(v) == 'userdata' and v:getDevice() ~= gpu then
34 |                     optState[k] = v:clone()
35 |                 end
36 |             end
37 | 
38 |             optim[optParam.optimizer](feval_dummy, thisparam, optParam, optState)
39 |             optStates[c] = optState
40 |             c = c+1
41 |         end
42 | 
43 |         cutorch.setDevice(gpu1)
44 |     end
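    -- Note: optStates holds one optimizer state per parameter tensor, indexed
    -- by the running counter c, so the Adam moments persist across successive
    -- calls to optim_step for the same network.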
45 | end 46 | 47 | function updateL1(net, deltaL1) 48 | if net.modules ~= nil then 49 | 50 | for i, module in ipairs(net.modules) do 51 | updateL1(module, deltaL1) 52 | end 53 | elseif torch.type(net):find('L1Penalty') then 54 | newL1 = net.l1weight + deltaL1 55 | if newL1 < 0 then 56 | newL1 = 0 57 | end 58 | net.l1weight = newL1 59 | end 60 | end 61 | 62 | function learner.loop(nIn) 63 | local x -- Minibatch 64 | local label 65 | 66 | local advLoss = 0 67 | local advGenLoss = 0 68 | local minimaxLoss = 0 69 | local reconLoss = 0 70 | local latentLoss = 0 71 | local reencodeLoss = 0 72 | local minimaxDecLoss = 0 73 | local minimaxDecLoss2 = 0 74 | -- Create optimiser function evaluation 75 | 76 | local dAdvDz = function(params) 77 | adversary:zeroGradParameters() 78 | 79 | local input = torch.Tensor(opt.batchSize, model_opts.nLatentDims):normal(0, 1):typeAs(x_in) 80 | local label = torch.ones(opt.batchSize):typeAs(x_in) -- Labels for real samples 81 | 82 | local output = adversary:forward(input) 83 | local errEnc_real = criterion_adv:forward(output, label) 84 | local df_do = criterion_adv:backward(output, label) 85 | adversary:backward(input, df_do) 86 | 87 | codes = encoder:forward(x_in) 88 | local input = codes[#codes] 89 | label = torch.zeros(opt.batchSize):typeAs(x_in) -- Labels for generated samples 90 | 91 | output = adversary:forward(input) 92 | local errEnc_fake = criterion_adv:forward(output, label) 93 | local df_do = criterion_adv:backward(output, label) 94 | adversary:backward(input, df_do) 95 | 96 | advLoss = (errEnc_real + errEnc_fake)/2 97 | 98 | return advLoss, gradParametersAdv 99 | end 100 | 101 | local dAdvGenDx = function(params) 102 | adversaryGen:zeroGradParameters() 103 | 104 | input = x_out 105 | if model_opts.nClasses > 0 then 106 | label = classLabel 107 | else 108 | label = torch.ones(opt.batchSize):typeAs(x_in) 109 | end 110 | 111 | local output = adversaryGen:forward(input) 112 | local errD_real = criterion:forward(output, label) 113 | local df_do = criterion:backward(output, label) 114 | adversaryGen:backward(input, df_do) 115 | 116 | zFake = {} 117 | c = 1 118 | if model_opts.nClasses > 0 then 119 | zFake[c] = classLabelOneHot 120 | c = c+1 121 | end 122 | if model_opts.nOther > 0 then 123 | zFake[c] = code 124 | c = c+1 125 | end 126 | 127 | zFake[c] = torch.Tensor(opt.batchSize, model_opts.nLatentDims):normal(0, 1):typeAs(x_in) 128 | 129 | input = decoder:forward(zFake) 130 | 131 | if model_opts.nClasses > 0 then 132 | label = torch.Tensor(opt.batchSize):typeAs(x_in):fill(model_opts.nClasses+1) 133 | else 134 | label = torch.zeros(opt.batchSize):typeAs(x_in) 135 | end 136 | 137 | local output = adversaryGen:forward(input) 138 | local errD_fake = criterion:forward(output, label) 139 | local df_do = criterion:backward(output, label) 140 | adversaryGen:backward(input, df_do) 141 | 142 | advGenLoss = (errD_real + errD_fake)/2 143 | 144 | return advGenLoss, gradParametersAdvGen 145 | end 146 | 147 | local dAutoencoderDx = function(params) 148 | -- if thetaEnc ~= params then 149 | -- thetaEnc:copy(params) 150 | -- end 151 | 152 | encoder:zeroGradParameters() 153 | -- decoder:zeroGradParameters() 154 | -- adversary:zeroGradParameters() 155 | 156 | -- the encoder has already gone forward 157 | xHat = decoder:forward(codes) 158 | 159 | -- Versus criterion 160 | -- x_out = nn.ReLU(true):cuda():forward(x_out) 161 | 162 | reconLoss = criterion_out:forward(xHat, x_out) 163 | loss = reconLoss 164 | local gradLoss = criterion_out:backward(xHat, x_out) 165 | 166 | -- 
Backwards pass
167 |         encoder:backward(x_in, decoder:backward(codes, gradLoss))
168 | 
169 |         -- Now the regularization pass
170 |         encoder_out = encoder.output
171 | 
172 |         c = 1
173 |         gradLosses = {}
174 | 
175 |         labelLoss = 0
176 |         shapeLoss = 0
177 | 
178 |         if model_opts.nClasses > 0 then
179 |             local labelHat = encoder_out[c]
180 | 
181 |             labelLoss = criterion_label:forward(labelHat, classLabel)
182 |             local labelGradLoss = criterion_label:backward(labelHat, classLabel)
183 | 
184 |             -- loss = loss + labelLoss
185 |             gradLosses[c] = labelGradLoss
186 | 
187 |             c = c+1
188 |         end
189 | 
190 |         if model_opts.nOther > 0 then
191 |             local shapeHat = encoder_out[c]
192 |             shapeLoss = criterion_other:forward(shapeHat, code)
193 |             local shapeGradLoss = criterion_other:backward(shapeHat, code)
194 | 
195 |             -- loss = loss + shapeLoss
196 |             gradLosses[c] = shapeGradLoss
197 | 
198 |             c = c+1
199 |         end
200 | 
201 |         local yReal = torch.ones(opt.batchSize):typeAs(x_in)
202 |         -- Train autoencoder (generator) to play a minimax game with the adversary (discriminator): min_G max_D log(1 - D(G(x)))
203 |         local predFake = adversary:forward(encoder.output[c])
204 |         minimaxLoss = criterion_adv:forward(predFake, yReal)
205 |         local gradMinimaxLoss = criterion_adv:backward(predFake, yReal)
206 |         local gradMinimax = adversary:updateGradInput(encoder.output[c], gradMinimaxLoss) -- Do not calculate gradient wrt adversary parameters
207 |         gradLosses[c] = gradMinimax*model_opts.advLatentRatio
208 | 
209 |         -- gradLosses[c] = nn.Clip(-1, 1):cuda():forward(gradLosses[c])
210 | 
211 |         encoder:backward(x_in, gradLosses)
212 | 
213 |         latentLoss = minimaxLoss + shapeLoss + labelLoss
214 | 
215 |         loss = reconLoss + latentLoss
216 |         -- if fakeLoss < 0.79 and realLoss < 0.79 then -- -ln(0.45)
217 |         cutorch.synchronizeAll()
218 | 
219 |         return loss, gradThetaEnc
220 |     end
221 | 
222 |     local dDecdAdvGen = function(params)
223 |         decoder:zeroGradParameters()
224 | 
225 |         --[[ the three lines below were already executed in fDx, so save computation
226 |         noise:uniform(-1, 1) -- regenerate random noise
227 |         local fake = netG:forward(noise)
228 |         input:copy(fake) ]]--
229 |         label = nil
230 |         if model_opts.nClasses > 0 then
231 |             label = classLabel
232 |         else
233 |             label = torch.ones(opt.batchSize):typeAs(x_in)
234 |         end
235 | 
236 |         local output = adversaryGen.output -- netD:forward(input) was already executed in fDx, so save computation
237 |         minimaxDecLoss = criterion:forward(output, label)
238 |         local df_do = criterion:backward(output, label)
239 |         local df_dg = adversaryGen:updateGradInput(input, df_do):clone()
240 | 
241 |         adversaryGen:clearState()
242 | 
243 |         decoder:backward(zFake, df_dg*model_opts.advGenRatio)
244 |         return minimaxDecLoss, gradParametersG
245 | 
246 |     end
247 | 
248 |     local dDecdAdvGen2 = function(params)
249 | 
250 |         --[[ the three lines below were already executed in fDx, so save computation
251 |         noise:uniform(-1, 1) -- regenerate random noise
252 |         local fake = netG:forward(noise)
253 |         input:copy(fake) ]]--
254 |         if model_opts.nClasses > 0 then
255 |             label = classLabel
256 |         else
257 |             label = torch.ones(opt.batchSize):typeAs(x_in)
258 |         end
259 | 
260 |         -- local xHat = decoder:forward(codes)
261 | 
262 |         local output = adversaryGen:forward(xHat) -- netD:forward(input) was already executed in fDx, so save computation
263 |         minimaxDecLoss2 = criterion:forward(output, label)
264 |         local df_do = criterion:backward(output, label)
265 |         df_dg = adversaryGen:updateGradInput(xHat, df_do)
266 | 
267 |         decoder:backward(codes, df_dg*model_opts.advGenRatio)
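        -- The GAN gradient df_dg is scaled by advGenRatio (1E-4 by default via
        -- the -advgenratio flag in train_model_2D.lua) so the adversarial term
        -- stays small relative to the reconstruction gradient.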
268 |         return minimaxDecLoss2, gradParametersG
269 | 
270 |     end
271 | 
272 | 
273 |     local ndat = nIn or dataProvider.train.inds:size()[1]
274 |     -- main loop
275 |     print("Starting learning")
276 |     while opt.epoch < opt.nepochs do
277 |         local tic = torch.tic()
278 | 
279 |         opt.epoch = opt.epoch+1
280 | 
281 |         local indices = torch.randperm(ndat):long():split(opt.batchSize)
282 |         indices[#indices] = nil
283 |         local N = #indices * opt.batchSize
284 | 
285 |         for t,v in ipairs(indices) do
286 |             collectgarbage()
287 | 
288 |             x_in, x_out = dataProvider:getImages(v, 'train')
289 |             -- Forward pass
290 |             x_in = x_in:cuda()
291 |             x_out = x_out:cuda()
292 | 
293 | 
294 |             if model_opts.nOther > 0 then
295 |                 classLabelOneHot = dataProvider:getLabels(v, 'train'):cuda()
296 |                 __, classLabel = torch.max(classLabelOneHot, 2)
297 |                 classLabel = torch.squeeze(classLabel:typeAs(x_in))
298 | 
299 |                 classLabelOneHot = torch.log(classLabelOneHot)
300 |                 classLabelOneHot:maskedFill(classLabelOneHot:eq(-math.huge), -25)
301 | 
302 |             end
303 | 
304 |             if model_opts.nClasses > 0 then
305 |                 code = dataProvider:getCodes(v, 'train')
306 |             end
307 | 
308 |             -- decoder:zeroGradParameters()
309 |             -- update the decoder's adversary
310 | 
311 |             if model_opts.useGanD then
312 |                 dAdvGenDx()
313 |                 optim_step(adversaryGen, advGenLoss, optAdvGen, stateAdvGen)
314 |             end
315 | 
316 |             -- update the encoder's adversary
317 |             dAdvDz()
318 |             optim_step(adversary, advLoss, optAdv, stateAdv)
319 |             adversary:clearState()
320 | 
321 |             if model_opts.useGanD then
322 |                 dDecdAdvGen()
323 |             end
324 |             dAutoencoderDx()
325 |             -- codes = utils.shallowcopy(codes)
326 |             -- encoder:clearState()
327 | 
328 |             if model_opts.useGanD then
329 |                 dDecdAdvGen2()
330 |             end
331 | 
332 |             optim_step(encoder, loss, optEnc, stateEnc)
333 |             optim_step(decoder, reconLoss+minimaxDecLoss+minimaxDecLoss2, optDec, stateDec)
334 | 
335 | 
336 | 
337 |             losses[#losses + 1] = reconLoss
338 |             latentlosses[#latentlosses+1] = latentLoss
339 |             reencodelosses[#reencodelosses+1] = reencodeLoss
340 |             advlosses[#advlosses + 1] = advLoss
341 |             advGenLosses[#advGenLosses + 1] = advGenLoss
342 | 
343 |             advMinimaxLoss[#advMinimaxLoss + 1] = minimaxLoss
344 |             advGenMinimaxLoss[#advGenMinimaxLoss + 1] = minimaxDecLoss+minimaxDecLoss2
345 | 
346 | 
347 |             encoder:clearState()
348 |             decoder:clearState()
349 |             if adversaryGen ~= nil then adversaryGen:clearState() end
350 |             adversary:clearState()
351 |         end
352 |         collectgarbage()
353 | 
354 |         x_in, x_out = nil, nil
355 | 
356 |         recon_loss = torch.mean(torch.Tensor(losses)[{{-#indices,-1}}]);
357 |         latent_loss = torch.mean(torch.Tensor(latentlosses)[{{-#indices,-1}}])
358 |         reencode_loss = torch.mean(torch.Tensor(reencodelosses)[{{-#indices,-1}}])
359 | 
360 |         adv_loss = torch.mean(torch.Tensor(advlosses)[{{-#indices,-1}}]);
361 |         advGen_loss = torch.mean(torch.Tensor(advGenLosses)[{{-#indices,-1}}]);
362 |         minimax_latent_loss = torch.mean(torch.Tensor(advMinimaxLoss)[{{-#indices,-1}}]);
363 |         minimax_gen_loss = torch.mean(torch.Tensor(advGenMinimaxLoss)[{{-#indices,-1}}]);
364 | 
365 |         print('Epoch ' .. opt.epoch .. '/' .. opt.nepochs .. ' Recon loss: ' .. recon_loss .. ' Adv loss: ' .. adv_loss .. ' AdvGen loss: ' .. advGen_loss .. ' time: ' .. torch.toc(tic))
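        -- Each of the means printed here averages the last #indices entries of
        -- its loss table, i.e. the minibatches of the epoch that just finished.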
366 |         print(minimax_latent_loss)
367 |         print(minimax_gen_loss)
368 | 
369 |         if recon_loss == math.huge or recon_loss ~= recon_loss or latent_loss == math.huge or latent_loss ~= latent_loss then
370 |             print('Exiting')
371 |             break
372 |         end
373 | 
374 |         -- Plot training curve(s)
375 |         local plots = {{'Reconstruction', torch.linspace(1, #losses, #losses), torch.Tensor(losses), '-'}}
376 |         plots[#plots + 1] = {'Latent', torch.linspace(1, #latentlosses, #latentlosses), torch.Tensor(latentlosses), '-'}
377 |         plots[#plots + 1] = {'Adversary', torch.linspace(1, #advlosses, #advlosses), torch.Tensor(advlosses), '-'}
378 |         plots[#plots + 1] = {'AdversaryGen', torch.linspace(1, #advGenLosses, #advGenLosses), torch.Tensor(advGenLosses), '-'}
379 |         plots[#plots + 1] = {'MinimaxAdvLatent', torch.linspace(1, #advMinimaxLoss, #advMinimaxLoss), torch.Tensor(advMinimaxLoss), '-'}
380 |         plots[#plots + 1] = {'MinimaxAdvGen', torch.linspace(1, #advGenMinimaxLoss, #advGenMinimaxLoss), torch.Tensor(advGenMinimaxLoss), '-'}
381 |         plots[#plots + 1] = {'Reencode', torch.linspace(1, #reencodelosses, #reencodelosses), torch.Tensor(reencodelosses), '-'}
382 | 
383 |         if opt.epoch % opt.saveProgressIter == 0 then
384 | 
385 |             encoder:evaluate()
386 |             decoder:evaluate()
387 |             rotate_tmp = dataProvider.opts.rotate
388 |             dataProvider.opts.rotate = false
389 | 
390 |             local x_in, x_out = dataProvider:getImages(torch.linspace(1,10,10):long(), 'train')
391 |             recon_train = evalIm(x_in,x_out)
392 | 
393 |             local x_in, x_out = dataProvider:getImages(torch.linspace(1,10,10):long(), 'test')
394 |             recon_test = evalIm(x_in,x_out)
395 | 
396 |             local reconstructions = torch.cat(recon_train, recon_test,2)
397 | 
398 |             image.save(model_opts.save_dir .. '/progress.png', reconstructions)
399 | 
400 |             embeddings = {}
401 |             embeddings.train = torch.zeros(ndat, model_opts.nLatentDims)
402 | 
403 |             indices = torch.linspace(1,ndat,ndat):long():split(opt.batchSize)
404 | 
405 |             start = 1
406 |             for t,v in ipairs(indices) do
407 |                 collectgarbage()
408 |                 stop = start + v:size(1) - 1
409 | 
410 |                 x_in = dataProvider:getImages(v, 'train')
411 |                 -- Forward pass
412 |                 x_in = x_in:cuda()
413 | 
414 |                 codes = encoder:forward(x_in)
415 |                 embeddings.train:sub(start, stop, 1,model_opts.nLatentDims):copy(codes[#codes])
416 | 
417 |                 start = stop + 1
418 |             end
419 | 
420 |             ntest = dataProvider.test.inds:size()[1]
421 |             embeddings.test = torch.zeros(ntest, model_opts.nLatentDims)
422 |             indices = torch.linspace(1,ntest,ntest):long():split(opt.batchSize)
423 | 
424 |             start = 1
425 |             for t,v in ipairs(indices) do
426 |                 collectgarbage()
427 |                 stop = start + v:size(1) - 1
428 | 
429 |                 x_in = dataProvider:getImages(v, 'test')
430 |                 -- Forward pass
431 |                 x_in = x_in:cuda()
432 | 
433 |                 codes = encoder:forward(x_in)
434 | 
435 |                 embeddings.test:sub(start, stop, 1,model_opts.nLatentDims):copy(codes[#codes])
436 | 
437 |                 start = stop + 1
438 |             end
439 | 
440 |             x_in = nil
441 | 
442 |             dataProvider.opts.rotate = rotate_tmp
443 | 
444 |             encoder:training()
445 |             decoder:training()
446 | 
447 |             torch.save(model_opts.save_dir .. '/progress_embeddings.t7', embeddings, 'binary', false)
448 |             embeddings = nil
449 | 
450 |             torch.save(model_opts.save_dir .. '/plots_tmp.t7', plots, 'binary', false)
451 |             torch.save(model_opts.save_dir .. 
'/epoch_tmp.t7', opt.epoch, 'binary', false)
452 |         end
453 | 
454 |         if opt.epoch % opt.saveStateIter == 0 then
455 |             print('Saving model.')
456 | 
457 |             -- save the optimizer states
458 |             torch.save(stateAdv_path, utils.table2float(stateAdv), 'binary', false)
459 |             torch.save(stateAdvGen_path, utils.table2float(stateAdvGen), 'binary', false)
460 | 
461 |             torch.save(stateEnc_path, utils.table2float(stateEnc), 'binary', false)
462 |             torch.save(stateDec_path, utils.table2float(stateDec), 'binary', false)
463 | 
464 | 
465 |             -- save the options
466 |             torch.save(opt_path, opt, 'binary', false)
467 |             torch.save(optEnc_path, optEnc, 'binary', false)
468 |             torch.save(optDec_path, optDec, 'binary', false)
469 |             torch.save(optAdv_path, optAdv, 'binary', false)
470 |             torch.save(optAdvGen_path, optAdvGen, 'binary', false)
471 | 
472 |             decoder:clearState()
473 |             encoder:clearState()
474 |             adversary:clearState()
475 |             adversaryGen:clearState()
476 | 
477 |             torch.save(model_opts.save_dir .. '/plots.t7', plots, 'binary', false)
478 |             torch.save(model_opts.save_dir .. '/epoch.t7', opt.epoch, 'binary', false)
479 | 
480 |             torch.save(model_opts.save_dir .. '/decoder.t7', decoder:float(), 'binary', false)
481 |             torch.save(model_opts.save_dir .. '/encoder.t7', encoder:float(), 'binary', false)
482 |             torch.save(model_opts.save_dir .. '/adversary.t7', adversary:float(), 'binary', false)
483 |             torch.save(model_opts.save_dir .. '/adversaryGen.t7', adversaryGen:float(), 'binary', false)
484 | 
485 |             torch.save(model_opts.save_dir .. '/rng.t7', torch.getRNGState(), 'binary', false)
486 |             torch.save(model_opts.save_dir .. '/rng_cuda.t7', cutorch.getRNGState(), 'binary', false)
487 | 
488 |             decoder:cuda()
489 |             encoder:cuda()
490 |             adversary:cuda()
491 |             adversaryGen:cuda()
492 | 
493 |         end
494 | 
495 |         plots = nil
496 |     end
497 | end
498 | 
499 | return learner
--------------------------------------------------------------------------------
/train_model_2D.lua:
--------------------------------------------------------------------------------
1 | require 'cutorch' -- Use CUDA if available
2 | require 'cudnn'
3 | require 'cunn'
4 | require 'paths'
5 | require 'imtools'
6 | 
7 | 
8 | optim = require 'optim'
9 | gnuplot = require 'gnuplot'
10 | image = require 'image'
11 | nn = require 'nn'
12 | optnet = require 'optnet'
13 | 
14 | 
15 | cuda = true
16 | hasCudnn = true
17 | 
18 | print('Has cudnn: ')
19 | print(hasCudnn)
20 | 
21 | -- package.loaded['modelTools'] = nil
22 | -- require 'modelTools'
23 | 
24 | 
25 | DataProvider = require('DataProvider2D')
26 | learner = require 'train'
27 | setup = require 'setup'
28 | 
29 | unpack = table.unpack or unpack
30 | 
31 | count = cutorch.getDeviceCount()
32 | -- torch.setdefaulttensortype('torch.CudaTensor')
33 | print('GPU is ON')
34 | for i = 1, count do
35 |     print('Device ' .. i .. ':')
36 |     freeMemory, totalMemory = cutorch.getMemoryUsage(i)
37 |     print('\t Free memory ' .. freeMemory)
38 |     print('\t Total memory ' .. 
totalMemory)
39 | end
40 | 
41 | 
42 | cmd = torch.CmdLine()
43 | cmd:option('-imsize', 256, 'desired size of images in dataset')
44 | cmd:option('-gpu1', 1, 'gpu for the encoder')
45 | cmd:option('-gpu2', 2, 'gpu for the decoder and adversary')
46 | cmd:option('-gpu3', 3, 'gpu for the image adversary (adversaryGen)')
47 | cmd:option('-gpu', -1, 'gpu for everything (overrides other gpu settings)')
48 | cmd:option('-imdir', '/root/images/release_4_1_17/release_v2/aligned/2D', 'parent directory for images')
49 | cmd:option('-learnrate', 0.0002, 'learning rate')
50 | cmd:option('-advgenratio', 1E-4, 'ratio for advGen update')
51 | cmd:option('-advlatentratio', 1E-4, 'ratio for advLatent update')
52 | cmd:option('-suffix', '', 'string suffix')
53 | cmd:option('-ganNoise', 0.01, 'injection noise for the GAN')
54 | cmd:option('-ganNoiseAllLayers', false, 'add noise on all GAN layers')
55 | cmd:option('-nepochs', 150, 'number of epochs')
56 | cmd:option('-nepochspt2', 300, 'number of epochs for pt2')
57 | cmd:option('-useGanD', 1, 'use a GAN on the decoder')
58 | cmd:option('-beta1', 0.5, 'beta1 parameter for ADAM descent')
59 | cmd:option('-ndat', -1, 'number of training data to use')
60 | 
61 | params = cmd:parse(arg)
62 | params.useGanD = params.useGanD > 0
63 | 
64 | print(params)
65 | 
66 | if params.gpu == -1 then
67 |     gpu1 = params.gpu1
68 |     gpu2 = params.gpu2
69 |     gpu3 = params.gpu3
70 | else
71 |     gpu1 = params.gpu
72 |     gpu2 = params.gpu
73 |     gpu3 = params.gpu
74 | end
75 | print('Setting default GPU to ' .. gpu1)
76 | 
77 | cutorch.setDevice(gpu1)
78 | torch.setnumthreads(12)
79 | 
80 | -- Set up Torch
81 | print('Setting up')
82 | torch.setdefaulttensortype('torch.FloatTensor')
83 | torch.manualSeed(1)
84 | cutorch.manualSeed(torch.random())
85 | 
86 | 
87 | 
88 | 
89 | evalIm = function(x_in, x_out)
90 |     local xHat, latentHat, latentHat_var = decoder:forward(encoder:forward(x_in:cuda()))
91 | 
92 |     xHat = imtools.mat2img(xHat)
93 |     x_out = imtools.mat2img(x_out)
94 | 
95 |     -- Plot reconstructions
96 |     recon = torch.cat(image.toDisplayTensor(x_out, 1, 10), image.toDisplayTensor(xHat, 1, 10), 2)
97 | 
98 |     return recon
99 | end
100 | 
101 | -- Now set up the probabilistic autoencoder
102 | 
103 | 
104 | imsize = params.imsize
105 | 
106 | model_opts = setup.getModelOpts()
107 | model_opts.parent_dir = 'trainAAE2D_' .. imsize
108 | model_opts.model_name = 'caae-nopool_GAN_v2'
109 | model_opts.save_dir = model_opts.parent_dir .. '/' .. model_opts.model_name .. '_AAEGAN' .. 
params.suffix
110 | model_opts.image_dir = params.imdir
111 | model_opts.ganNoise = params.ganNoise
112 | model_opts.useGanD = params.useGanD
113 | 
114 | model_opts.verbose = false
115 | 
116 | model_opts.imsize = imsize
117 | model_opts.image_sub_size = 1.4707*(512/imsize)
118 | 
119 | print(model_opts.save_dir)
120 | 
121 | setup.getLearnOpts(model_opts)
122 | model_opts.rotate = false
123 | 
124 | dataProvider = DataProvider.create(model_opts.image_dir, model_opts)
125 | 
126 | x = dataProvider:getImages(torch.LongTensor{1}, 'train')
127 | print(x:size())
128 | 
129 | 
130 | opt.batchSize = 32
131 | opt.nepochs = params.nepochs
132 | 
133 | nLatentDims = 16
134 | model_opts.nLatentDims = nLatentDims
135 | 
136 | 
137 | optEnc.learningRate = params.learnrate
138 | optEnc.beta1 = params.beta1
139 | optDec.learningRate = params.learnrate
140 | optDec.beta1 = params.beta1
141 | optAdvGen.learningRate = params.learnrate
142 | optAdvGen.beta1 = params.beta1
143 | optAdv.learningRate = params.learnrate
144 | optAdv.beta1 = params.beta1
145 | 
146 | model_opts.advGenRatio = params.advgenratio
147 | model_opts.advLatentRatio = params.advlatentratio
148 | 
149 | opt.saveStateIter = 25
150 | opt.saveProgressIter = 5
151 | 
152 | criterion = nn.BCECriterion():cuda()
153 | 
154 | 
155 | if params.ndat == -1 then
156 |     ndat = dataProvider.train.inds:size()[1]
157 | else
158 |     ndat = params.ndat
159 | end
160 | 
161 | if opt.epoch ~= opt.nepochs then
162 | 
163 |     setup.getModel()
164 |     learner.loop(ndat)
165 | 
166 |     decoder = nil
167 |     adversary = nil
168 |     adversaryGen = nil
169 | end
170 | 
171 | 
172 | -- Now set up the conditional probabilistic autoencoder
173 | 
174 | setup.getModel()
175 | 
176 | 
177 | shapeEncoder = encoder
178 | shapeEncoder:evaluate()
179 | 
180 | 
181 | for i = 1,#shapeEncoder.modules do
182 |     shapeEncoder.modules[i]:clearState()
183 | end
184 | 
185 | collectgarbage()
186 | 
187 | shapeDataProvider = dataProvider
188 | 
189 | model_opts.channel_inds_in = torch.LongTensor{1,2,3}
190 | model_opts.channel_inds_out = torch.LongTensor{1,2,3}
191 | model_opts.nChOut = 3
192 | model_opts.nChIn = 3
193 | 
194 | embedding_file = model_opts.save_dir .. '/progress_embeddings.t7'
195 | model_opts.save_dir = model_opts.parent_dir .. '/' .. model_opts.model_name .. '_AAEGAN' .. params.suffix .. 
'_pattern' 196 | print(model_opts.save_dir) 197 | 198 | model_opts.nLatentDims = nLatentDims 199 | model_opts.nClasses = shapeDataProvider:getLabels(torch.LongTensor{1}, 'train'):size()[2] 200 | model_opts.nOther = nLatentDims 201 | 202 | 203 | setup.getLearnOpts(model_opts) 204 | 205 | 206 | optEnc.learningRate = params.learnrate 207 | optEnc.beta1 = params.beta1 208 | optDec.learningRate = params.learnrate 209 | optDec.beta1 = params.beta1 210 | optAdvGen.learningRate = params.learnrate 211 | optAdvGen.beta1 = params.beta1 212 | optAdv.learningRate = params.learnrate 213 | optAdv.beta1 = params.beta1 214 | 215 | 216 | 217 | model_opts.advGenRatio = params.advgenratio 218 | model_opts.advLatentRatio = params.advlatentratio 219 | 220 | 221 | criterion = criterion_label 222 | 223 | opt.saveStateIter = 25 224 | opt.nepochs = params.nepochspt2 225 | 226 | 227 | shape_embeddings = torch.load(embedding_file) 228 | 229 | 230 | dataProvider = DataProvider.create(model_opts.image_dir, model_opts) 231 | function dataProvider:getCodes(indices, train_or_test) 232 | local codes = shape_embeddings[train_or_test]:index(1, indices):cuda() 233 | return codes 234 | end 235 | setup.getModel() 236 | 237 | if opt.epoch ~= opt.nepochs then 238 | 239 | learner.loop(ndat) 240 | 241 | end 242 | 243 | encoder:evaluate() 244 | decoder:evaluate() 245 | 246 | print_images = function(encoder, decoder, dataProvider, train_or_test, save_dir) 247 | 248 | require 'paths' 249 | 250 | paths.mkdir(save_dir) 251 | 252 | ndat = dataProvider[train_or_test].labels:size(1) 253 | nclasses = dataProvider[train_or_test].labels:size(2) 254 | 255 | for j = 1,ndat do 256 | -- print('Printing image ' .. j) 257 | img_num = torch.LongTensor{j} 258 | 259 | local label = dataProvider:getLabels(torch.LongTensor{j,j}, train_or_test) 260 | local img = dataProvider:getImages(torch.LongTensor{j,j}, train_or_test):cuda() 261 | local imsize = img:size() 262 | local out_latent = nil 263 | out_latent = encoder:forward(img) 264 | 265 | out_latent[3] = torch.zeros(2,model_opts.nLatentDims):cuda() 266 | 267 | local out_img = torch.Tensor(nclasses, 3, imsize[3], imsize[4]):cuda() 268 | 269 | im_path = dataProvider.image_paths[dataProvider[train_or_test].inds[j]] 270 | tokens = utils.split(im_path, '/|.') 271 | im_class = tokens[7] 272 | im_id = tokens[8] 273 | 274 | print('Printing images for ' .. im_class .. ' ' .. im_id) 275 | 276 | save_path = save_dir .. '/' .. im_class .. '_'.. im_id .. '_orig.png' 277 | image.save(save_path, img[1]) 278 | 279 | for i = 1,#dataProvider.classes do 280 | save_path = save_dir .. '/' .. im_class .. '_'.. im_id .. '_pred_' .. dataProvider.classes[i] .. '.png' 281 | 282 | if not paths.filep(save_path) then 283 | one_hot = torch.ones(2,nclasses):fill(-25):cuda() 284 | one_hot[{{1},{i}}] = 0 285 | 286 | out_latent[1] = one_hot 287 | out_img = decoder:forward(out_latent)[1][2] 288 | 289 | 290 | image.save(save_path, out_img) 291 | end 292 | 293 | end 294 | end 295 | end 296 | 297 | print_images(encoder, decoder, dataProvider, 'train', model_opts.save_dir .. '/' .. 'pred_train') 298 | print_images(encoder, decoder, dataProvider, 'test', model_opts.save_dir .. '/' .. 
'pred_test')
--------------------------------------------------------------------------------
/train_model_2D.sh:
--------------------------------------------------------------------------------
1 | th train_model_2D.lua -imsize 256 -gpu 1 -imdir "/root/images/release_4_1_17/release_v2/aligned/2D" -learnrate 0.0002 -advgenratio 0.00001 -ganNoise 0.05 -nepochs 150 -nepochspt2 220 -suffix "v1_4_alldat"
2 | 
3 | 
4 | 
--------------------------------------------------------------------------------
/utils.lua:
--------------------------------------------------------------------------------
1 | 
2 | require 'torch'
3 | 
4 | utils = {}
5 | 
6 | function utils.ind2sub(matsize, ndx)
7 |     -- matsize is assumed to be a LongStorage, e.g. the result of tensor:size()
8 |     matsize_tmp = torch.LongTensor(matsize:size())
9 |     for i = 1,matsize:size() do
10 |         matsize_tmp[i] = matsize[i]
11 |     end
12 | 
13 |     matsize = matsize_tmp
14 | 
15 |     cp = torch.cumprod(matsize)
16 | 
17 |     sub = torch.zeros(matsize:size())
18 | 
19 |     for i = cp:size()[1], 2, -1 do
20 |         vi = ((ndx-1) % cp[i-1]) + 1
21 |         vj = (ndx - vi)/cp[i-1] + 1
22 |         sub[i] = vj
23 |         ndx = vi
24 |     end
25 |     sub[1] = ndx;
26 | 
27 |     return sub
28 | end
29 | 
30 | function utils.alphanumsort(o)
31 |     -- Shamelessly lifted from
32 |     -- http://notebook.kulchenko.com/algorithms/alphanumeric-natural-sorting-for-humans-in-lua
33 |     -- grj 5/16/16
34 |     local function padnum(d) return ("%012d"):format(d) end
35 |     table.sort(o, function(a,b)
36 |         return tostring(a):gsub("%d+",padnum) < tostring(b):gsub("%d+",padnum) end)
37 |     return o
38 | end
39 | 
40 | function utils.meshgrid(x, y)
41 |     local xx = torch.repeatTensor(x, y:size(1),1)
42 |     local yy = torch.repeatTensor(y:view(-1,1), 1, x:size(1))
43 |     return xx, yy
44 | end
45 | 
46 | function utils.unique(input)
47 | 
48 |     local b = {}
49 |     local range
50 | 
51 |     -- print(type(input))
52 | 
53 |     if type(input) == 'table' then
54 |         range = #input
55 |     else
56 |         range = input:numel()
57 |     end
58 | 
59 |     local c = 0
60 |     for i = 1, range do
61 |         if b[input[i]] == nil then
62 |             c = c+1
63 |             b[input[i]] = c
64 |         end
65 |     end
66 | 
67 |     local u_vals = {}
68 |     for i in pairs(b) do
69 |         table.insert(u_vals,i)
70 |     end
71 | 
72 |     local inds = torch.zeros(range)
73 |     for i = 1, range do
74 |         inds[i] = b[input[i]]
75 |     end
76 | 
77 |     u_vals_tmp = {}
78 | 
79 |     for i = 1,#u_vals do
80 |         ind = torch.nonzero(torch.eq(inds, torch.FloatTensor(inds:size()[1]):fill(i)))[1][1]
81 |         u_vals_tmp[i] = input[ind]
82 | 
83 |     end
84 |     u_vals = u_vals_tmp
85 | 
86 |     return u_vals, inds
87 | end
88 | 
89 | function utils.table2float(orig)
90 |     local orig_type = type(orig)
91 |     local copy
92 |     if orig_type == 'table' then
93 |         copy = {}
94 |         for orig_key, orig_value in next, orig, nil do
95 |             copy[orig_key] = utils.table2float(orig_value)
96 |         end
97 |         setmetatable(copy, utils.table2float(getmetatable(orig)))
98 |     elseif orig_type == 'userdata' then -- torch tensor
99 |         copy = orig:float()
100 |     end
101 |     return copy
102 | end
103 | 
104 | function utils.table2cuda(orig)
105 |     local orig_type = type(orig)
106 |     local copy
107 |     if orig_type == 'table' then
108 |         copy = {}
109 |         for orig_key, orig_value in next, orig, nil do
110 |             copy[orig_key] = utils.table2cuda(orig_value)
111 |         end
112 |         setmetatable(copy, utils.table2cuda(getmetatable(orig)))
113 |     elseif orig_type == 'userdata' then -- torch tensor
114 |         copy = orig:cuda()
115 |     end
116 |     return copy
117 | end
118 | 
119 | 
120 | 
121 | function utils.split(inputstr, sep)
122 |     -- liberated from http://stackoverflow.com/questions/1426954/split-string-in-lua
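    -- Every character in sep acts as a delimiter via the character class below.
    -- Usage sketch (hypothetical path): utils.split('pred_train/tub_01.png', '/|.')
    -- returns {'pred_train', 'tub_01', 'png'}, which is how print_images in
    -- train_model_2D.lua pulls the class and id out of an image path.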
123 |     if sep == nil then
124 |         sep = "%s"
125 |     end
126 | 
127 |     local t, i = {}, 1
128 |     for str in string.gmatch(inputstr, "([^"..sep.."]+)") do
129 |         t[i] = str
130 |         i = i + 1
131 |     end
132 | 
133 |     return t
134 | end
135 | 
136 | function utils.normalize(tensor)
137 |     local mu = torch.mean(tensor)
138 |     local std = torch.std(tensor)
139 |     local tensor_out = (torch.Tensor(tensor:size()):copy(tensor)-mu)/std
140 | 
141 |     return tensor_out
142 | end
143 | 
144 | function utils.shallowcopy(orig)
145 |     local orig_type = type(orig)
146 |     local copy
147 |     if orig_type == 'table' then
148 |         copy = {}
149 |         for orig_key, orig_value in pairs(orig) do
150 |             copy[orig_key] = orig_value:clone()
151 |         end
152 |     else -- torch tensor
153 |         copy = orig:clone()
154 |     end
155 |     return copy
156 | end
157 | 
158 | function utils.deepcopy(orig)
159 |     local orig_type = type(orig)
160 |     local copy
161 |     if orig_type == 'table' then
162 |         copy = {}
163 |         for orig_key, orig_value in next, orig, nil do
164 |             copy[utils.deepcopy(orig_key)] = utils.deepcopy(orig_value)
165 |         end
166 |         setmetatable(copy, utils.deepcopy(getmetatable(orig)))
167 |     else -- torch tensor
168 |         copy = orig:clone()
169 |     end
170 |     return copy
171 | end
172 | 
173 | 
--------------------------------------------------------------------------------