├── README.md
├── VAE.lua
├── dcgan_vae.lua
├── discriminator.lua
└── tiled_images.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Deep Convolutional Variational Autoencoder w/ Generative Adversarial Network

A combination of the [DCGAN implementation](https://github.com/soumith/dcgan.torch) by soumith and the [variational autoencoder](https://github.com/Kaixhin/Autoencoders) by Kaixhin.

The model produces 64x64 images from inputs of any size via center cropping. You can modify the code relatively easily to produce different sized outputs (by adding more convolutional layers, for instance), or to rescale images instead of cropping them. Images are randomly flipped horizontally to augment the training data.

I have added white noise to the real inputs that pass through the discriminator, after reading this [post on stabilizing GANs](http://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/). The noise level is annealed over time to help the generator and discriminator converge; a short sketch of the trick appears at the end of this README.

# Results on Wikimedia Paintings Dataset

![](https://github.com/staturecrane/dcgan_vae_torch/blob/master/tiled_images.png)

## Prerequisites
1. Torch7
2. CUDA
3. CUDNN
4. DPNN
5. Lua File System
6. optim
7. xlua

To run, execute the script using

```
th dcgan_vae.lua -i [input folder] -o [output folder] -c [checkpoint folder] -r [reconstructions folder]
```

where the input folder is expected to contain color images, and every folder path should end in a trailing slash (the script concatenates filenames onto the paths directly). The model resamples a subset of the training set after every epoch, so the data fits on a GPU while still (eventually) covering all of it. The output folder receives random samples generated by the model, and the reconstructions folder receives periodic reconstructions of training images, so you can see how the VAE is doing (it won't do particularly well on its own, but that's okay; it's there to assist the GAN).
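
## Instance noise, in brief

The trick itself amounts to only a few lines. Below is a minimal, self-contained sketch of the idea as it is used in `dcgan_vae.lua`; the helper names and the stand-in minibatch are illustrative, not part of this repository:

```lua
require 'nn'
require 'dpnn' -- provides nn.WhiteNoise

-- Noise level starts at 0.1 and decays by 1% every 10 epochs,
-- matching the schedule in dcgan_vae.lua.
local dNoise = 0.1

-- Corrupt a real batch with zero-mean Gaussian noise before the
-- discriminator sees it, so the real and fake distributions overlap
-- early in training.
local function noisyInput(batch)
    return nn.WhiteNoise(0, dNoise):forward(batch)
end

-- Called once per epoch to anneal the noise level.
local function annealNoise(epoch)
    if epoch % 10 == 0 then dNoise = dNoise * 0.99 end
end

local batch = torch.rand(50, 3, 64, 64) -- stand-in for a real minibatch
local corrupted = noisyInput(batch)
annealNoise(10) -- dNoise is now 0.099
```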
--------------------------------------------------------------------------------
/VAE.lua:
--------------------------------------------------------------------------------
require 'nn'
require 'dpnn'

local VAE = {}

-- Encoder: four stride-2 convolutions take a 64x64 input down to a 4x4
-- feature map, then two parallel 4x4 convolutions produce the mean and
-- log-variance of q(z|x), each of size z_dim x 1 x 1.
function VAE.get_encoder(channels, naf, z_dim)
    local encoder = nn.Sequential()
    encoder:add(nn.SpatialConvolution(channels, naf, 4, 4, 2, 2, 1, 1))
    encoder:add(nn.ReLU())
    encoder:add(nn.SpatialConvolution(naf, naf * 2, 4, 4, 2, 2, 1, 1))
    encoder:add(nn.SpatialBatchNormalization(naf * 2)):add(nn.ReLU())
    encoder:add(nn.SpatialConvolution(naf * 2, naf * 4, 4, 4, 2, 2, 1, 1))
    encoder:add(nn.SpatialBatchNormalization(naf * 4)):add(nn.ReLU())
    encoder:add(nn.SpatialConvolution(naf * 4, naf * 8, 4, 4, 2, 2, 1, 1))
    encoder:add(nn.SpatialBatchNormalization(naf * 8)):add(nn.ReLU())

    local zLayer = nn.ConcatTable()
    zLayer:add(nn.SpatialConvolution(naf * 8, z_dim, 4, 4)) -- mean μ
    zLayer:add(nn.SpatialConvolution(naf * 8, z_dim, 4, 4)) -- log σ^2
    encoder:add(zLayer)

    return encoder
end

-- Sampler: the reparameterization trick, z = μ + σ * ε.
function VAE.get_sampler()
    -- ε branch: zero out the incoming log-variance, then add Gaussian
    -- noise, yielding ε ~ N(0, 0.01)
    local epsilonModule = nn.Sequential()
    epsilonModule:add(nn.MulConstant(0))
    epsilonModule:add(nn.WhiteNoise(0, 0.01))

    local noiseModule = nn.Sequential()
    local noiseModuleInternal = nn.ConcatTable()
    local stdModule = nn.Sequential()
    stdModule:add(nn.MulConstant(0.5)) -- Compute 1/2 log σ^2 = log σ
    stdModule:add(nn.Exp()) -- Compute σ
    noiseModuleInternal:add(stdModule) -- Standard deviation σ
    noiseModuleInternal:add(epsilonModule) -- Sample noise ε
    noiseModule:add(noiseModuleInternal)
    noiseModule:add(nn.CMulTable()) -- σ * ε

    local sampler = nn.Sequential()
    local samplerInternal = nn.ParallelTable()
    samplerInternal:add(nn.Identity()) -- pass μ through unchanged
    samplerInternal:add(noiseModule)
    sampler:add(samplerInternal)
    sampler:add(nn.CAddTable()) -- μ + σ * ε

    return sampler
end

-- Decoder: mirror of the encoder, built from full (transposed)
-- convolutions; maps a z_dim x 1 x 1 code back up to a channels x 64 x 64
-- image with values in (0, 1).
function VAE.get_decoder(channels, ngf, z_dim)
    local decoder = nn.Sequential()
    decoder:add(nn.SpatialFullConvolution(z_dim, ngf * 8, 4, 4))
    decoder:add(nn.SpatialBatchNormalization(ngf * 8)):add(nn.ReLU(true))
    decoder:add(nn.SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1))
    decoder:add(nn.SpatialBatchNormalization(ngf * 4)):add(nn.ReLU(true))
    decoder:add(nn.SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1))
    decoder:add(nn.SpatialBatchNormalization(ngf * 2)):add(nn.ReLU(true))
    decoder:add(nn.SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1))
    decoder:add(nn.SpatialBatchNormalization(ngf)):add(nn.ReLU(true))
    decoder:add(nn.SpatialFullConvolution(ngf, channels, 4, 4, 2, 2, 1, 1))
    decoder:add(nn.Sigmoid())

    return decoder
end
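
-- Illustrative usage (an assumption-laden sketch, not part of the original
-- file): chaining the three pieces reproduces the generator built in
-- dcgan_vae.lua, and calling the sampler directly shows the
-- reparameterization at work.
--
--   local VAE = require 'VAE'
--   local encoder = VAE.get_encoder(3, 64, 100)
--   local sampler = VAE.get_sampler()
--   local decoder = VAE.get_decoder(3, 64, 100)
--   local meanAndLogVar = encoder:forward(torch.rand(1, 3, 64, 64))
--   local z = sampler:forward(meanAndLogVar) -- z = μ + σ * ε, ε ~ N(0, 0.01)
--   local reconstruction = decoder:forward(z) -- 1 x 3 x 64 x 64, in (0, 1)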
return VAE
--------------------------------------------------------------------------------
/dcgan_vae.lua:
--------------------------------------------------------------------------------
require 'image'
require 'xlua'
require 'nn'
require 'dpnn'
require 'optim'
require 'lfs'
require 'cunn'
local VAE = require 'VAE'
local discriminator = require 'discriminator'

hasCudnn, cudnn = pcall(require, 'cudnn')
assert(hasCudnn) -- make sure you have a CUDA- and cuDNN-enabled GPU

local argparse = require 'argparse'
local parser = argparse('dcgan_vae', 'a Torch implementation of the deep convolutional generative adversarial network, with variational autoencoder')
parser:option('-i --input', 'input directory for image dataset')
parser:option('-o --output', 'output directory for generated images')
parser:option('-c --checkpoints', 'directory for saving checkpoints')
parser:option('-r --reconstruction', 'directory to put samples of reconstructions')

args = parser:parse()

input = args.input
output_folder = args.output
checkpoints = args.checkpoints
reconstruct_folder = args.reconstruction

-- ensure tensors are of the correct type
torch.setdefaulttensortype('torch.FloatTensor')
torch.setnumthreads(1)

-- collect every filename in the input directory
function getFilenames()
    local queue = {}
    local count = 1
    for file in lfs.dir(input) do
        if file ~= '.' and file ~= '..' then
            queue[count] = file
            count = count + 1
        end
    end
    return queue
end

-- zero-pad a number to six digits, e.g. getNumber(42) -> '000042'
function getNumber(num)
    local length = #tostring(num)
    local filename = ""
    for i = 1, (6 - length) do
        filename = filename .. 0
    end
    filename = filename .. num
    return filename
end

train_size = 200
batch_size = 50
channels = 3
dim = 64

train = torch.Tensor(train_size, channels, dim, dim)
train = train:cuda()

filenames = getFilenames()

-- fill the training tensor with randomly sampled, randomly flipped,
-- center-cropped images
function fillTensor(tensor)
    for i = 1, train_size do
        local image_x = image.load(input .. filenames[torch.random(1, #filenames)])
        local flip_or_not = torch.random(1, 2)
        if flip_or_not == 1 then
            image_x = image.hflip(image_x)
        end
        local image_ok, image_crop = pcall(image.crop, image_x, 'c', dim, dim)
        if image_ok then
            tensor[i] = image_crop
        else
            print('image cannot be cropped to ' .. dim .. 'x' .. dim .. '. Skipping...')
        end
    end
    return tensor
end

train = fillTensor(train)

feature_size = channels * dim * dim

-- initialize the weights to non-zero, as in the DCGAN paper
function weights_init(m)
    local name = torch.type(m)
    if name:find('Convolution') then
        m.weight:normal(0.0, 0.02)
        m:noBias()
    elseif name:find('BatchNormalization') then
        if m.weight then m.weight:normal(1.0, 0.02) end
        if m.bias then m.bias:fill(0) end
    end
end

z_dim = 100
ndf = 64
ngf = 64
naf = 64

encoder = VAE.get_encoder(channels, naf, z_dim)
sampler = VAE.get_sampler()
decoder = VAE.get_decoder(channels, ngf, z_dim)

-- the generator is the full VAE: encode, sample, decode
netG = nn.Sequential()
netG:add(encoder)
netG:add(sampler)
netG:add(decoder)
netG:apply(weights_init)

netD = discriminator.get_discriminator(channels, ndf)
netD:apply(weights_init)

netG = netG:cuda()
netD = netD:cuda()
cudnn.convert(netG, cudnn)
cudnn.convert(netD, cudnn)

criterion = nn.BCECriterion() -- GAN real/fake loss
criterion = criterion:cuda()

m_criterion = nn.MSECriterion() -- VAE reconstruction loss
m_criterion = m_criterion:cuda()

optimStateG = {
    learningRate = 0.0002,
    beta1 = 0.5
}

optimStateD = {
    learningRate = 0.0002,
    beta1 = 0.5
}

-- noise to pass through the decoder to generate random samples from Z
noise_x = torch.Tensor(batch_size, z_dim, 1, 1)
noise_x = noise_x:cuda()
noise_x:normal(0, 0.01)

-- labels, real or fake, for our GAN
label = torch.Tensor(batch_size)
label = label:cuda()

real_label = 1
fake_label = 0

-- standard deviation of the instance noise fed to the discriminator
dNoise = 0.1

epoch_tm = torch.Timer()
tm = torch.Timer()
data_tm = torch.Timer()
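
-- The three closures below are each passed to optim.adam once per
-- minibatch: fAx trains the whole VAE (reconstruction plus KL divergence),
-- fDx trains the discriminator, and fGx trains the decoder against the
-- discriminator. Backward passes for D and G are gated on each other's
-- loss (a net skips its update while its opponent is losing badly) to
-- keep the adversarial game roughly balanced; the 0.7 and 1.0 thresholds
-- appear to be empirical choices.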
-- to keep track of our reconstructions
reconstruct_count = 1

parametersD, gradParametersD = netD:getParameters()
parametersG, gradParametersG = netG:getParameters()

errD = 0
errG = 0
errA = 0

-- training evaluation closure for the discriminator
fDx = function(x)
    if x ~= parametersD then
        parametersD:copy(x)
    end
    gradParametersD:zero()

    -- train with real
    label:fill(real_label)
    -- slightly noise the inputs to help stabilize the GAN and allow for convergence
    input_x = nn.WhiteNoise(0, dNoise):cuda():forward(input_x)
    output = netD:forward(input_x)
    errD_real = criterion:forward(output, label)
    df_do = criterion:backward(output, label)
    -- only update while the discriminator isn't overpowering the generator
    if (errG < 0.7 or errD > 1.0) then netD:backward(input_x, df_do) end

    -- train with fake
    noise_x:normal(0, 0.01)
    fake = decoder:forward(noise_x)
    label:fill(fake_label)
    output = netD:forward(fake)
    errD_fake = criterion:forward(output, label)
    df_do = criterion:backward(output, label)
    if (errG < 0.7 or errD > 1.0) then netD:backward(fake, df_do) end

    errD = errD_real + errD_fake
    gradParametersD:clamp(-5, 5)
    return errD, gradParametersD
end

-- training evaluation closure for the variational autoencoder
fAx = function(x)
    if x ~= parametersG then
        parametersG:copy(x)
    end
    gradParametersG:zero()

    -- reconstruction loss
    output = netG:forward(input_x)
    errA = m_criterion:forward(output, input_x)
    df_do = m_criterion:backward(output, input_x)
    netG:backward(input_x, df_do)

    -- KL divergence between q(z|x) = N(μ, σ^2) and the unit-Gaussian prior:
    -- KL = -0.5 * Σ(1 + log σ^2 - μ^2 - σ^2), with gradients
    -- ∂KL/∂μ = μ and ∂KL/∂(log σ^2) = 0.5 * (σ^2 - 1)
    nElements = output:nElement()
    mean, log_var = table.unpack(encoder.output)
    var = torch.exp(log_var)
    KLLoss = -0.5 * torch.sum(1 + log_var - torch.pow(mean, 2) - var)
    KLLoss = KLLoss / nElements
    errA = errA + KLLoss
    gradKLLoss = {mean / nElements, 0.5 * (var - 1) / nElements}
    encoder:backward(input_x, gradKLLoss)

    -- periodically save a reconstruction to monitor the VAE's progress
    if reconstruct_count % 10 == 0 and reconstruct_folder then
        image.save(reconstruct_folder .. 'reconstruction' .. getNumber(reconstruct_count) .. '.png', output[1])
    end
    return errA, gradParametersG
end

-- training evaluation closure for the generator
fGx = function(x)
    if x ~= parametersG then
        parametersG:copy(x)
    end
    gradParametersG:zero()
    label:fill(real_label)
    -- reuse the discriminator's output from the fake pass in fDx
    output = netD.output
    errG = criterion:forward(output, label)
    df_do = criterion:backward(output, label)
    df_dg = netD:updateGradInput(fake, df_do)
    -- only update while the generator isn't overpowering the discriminator
    if (errD < 0.7 or errG > 1.0) then decoder:backward(noise_x, df_dg) end
    gradParametersG:clamp(-5, 5)
    return errG, gradParametersG
end
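
-- Sampling from a saved checkpoint outside this script (an illustrative
-- sketch; the checkpoint path is hypothetical):
--   local netG = torch.load('checkpoints/1000_net_G.t7')
--   local dec = netG:get(3) -- netG is Sequential(encoder, sampler, decoder)
--   local z = torch.CudaTensor(1, 100, 1, 1):normal(0, 0.01)
--   image.save('sample.png', dec:forward(z)[1]:float())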
-- generate samples from Z
generate = function(epoch)
    noise_x:normal(0, 0.01)
    local generations = decoder:forward(noise_x)
    image.save(output_folder .. getNumber(epoch) .. '.png', generations[1])
end

total_epochs = 50000

for epoch = 1, total_epochs do
    epoch_tm:reset()
    -- main training loop
    for i = 1, train_size, batch_size do
        -- split the training data into mini-batches
        input_x = train:narrow(1, i, math.min(batch_size, train_size - i + 1))
        tm:reset()
        -- optim takes an evaluation closure, the parameters of the model you
        -- wish to train, and optimization options such as learning rate and momentum
        optim.adam(fAx, parametersG, optimStateG) -- VAE
        optim.adam(fDx, parametersD, optimStateD) -- discriminator
        optim.adam(fGx, parametersG, optimStateG) -- generator
        collectgarbage('collect')
    end
    reconstruct_count = reconstruct_count + 1
    if errG then
        print("Generator loss: " .. errG .. ", Autoencoder loss: " .. errA .. ", Discriminator loss: " .. errD)
    else
        print("Discriminator loss: " .. errD)
    end
    parametersD, gradParametersD = nil, nil
    parametersG, gradParametersG = nil, nil
    -- save and/or clear the model state for the next training batch
    if epoch % 1000 == 0 then
        torch.save(checkpoints .. epoch .. '_net_G.t7', netG:clearState())
        torch.save(checkpoints .. epoch .. '_net_D.t7', netD:clearState())
    else
        netG:clearState()
        netD:clearState()
    end
    generate(epoch)
    train = fillTensor(train)
    parametersD, gradParametersD = netD:getParameters()
    parametersG, gradParametersG = netG:getParameters()
    -- simulated annealing of the discriminator's noise parameter
    if epoch % 10 == 0 then dNoise = dNoise * 0.99 end
    print(('End of epoch %d / %d \t Time Taken: %.3f'):format(
        epoch, total_epochs, epoch_tm:time().real))
end
--------------------------------------------------------------------------------
/discriminator.lua:
--------------------------------------------------------------------------------
require 'nn'

local discriminator = {}

-- Discriminator: a standard DCGAN discriminator; four stride-2
-- convolutions take a 64x64 image down to a 4x4 feature map, and a final
-- 4x4 convolution plus sigmoid produce a single real/fake probability.
function discriminator.get_discriminator(channels, ndf)
    local netD = nn.Sequential()
    netD:add(nn.SpatialConvolution(channels, ndf, 4, 4, 2, 2, 1, 1))
    netD:add(nn.LeakyReLU(0.2, true))
    netD:add(nn.SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1))
    netD:add(nn.SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
    netD:add(nn.SpatialConvolution(ndf * 2, ndf * 4, 4, 4, 2, 2, 1, 1))
    netD:add(nn.SpatialBatchNormalization(ndf * 4)):add(nn.LeakyReLU(0.2, true))
    netD:add(nn.SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1))
    netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
    netD:add(nn.SpatialConvolution(ndf * 8, 1, 4, 4))
    netD:add(nn.Sigmoid())
    netD:add(nn.View(1):setNumInputDims(3))

    return netD
end

return discriminator
--------------------------------------------------------------------------------
/tiled_images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/staturecrane/dcgan_vae_torch/63bf753a388781c42f4e01a2a6c9139e381f7fca/tiled_images.png
--------------------------------------------------------------------------------