├── CONTRIBUTING.md ├── LICENSE ├── PATENTS ├── README.md ├── cifar ├── README.md ├── datasets │ ├── coarse_to_fine_cifar10.lua │ └── scaled_cifar10.lua ├── layers │ └── SpatialConvolutionUpsample.lua ├── scripts │ ├── train_cifar.lua │ ├── train_cifar_classcond.lua │ ├── train_cifar_coarse_to_fine.lua │ └── train_cifar_coarse_to_fine_classcond.lua ├── train │ ├── adversarial.lua │ ├── conditional_adversarial.lua │ └── double_conditional_adversarial.lua └── utils │ └── image.lua └── lsun ├── README.md ├── data.lua ├── donkey_imagenet.lua ├── donkey_lsun.lua ├── imagenet.lua ├── layers ├── SpatialConvolutionUpsample.lua ├── cudnnSpatialConvolutionUpsample.lua ├── test_cudnn.lua └── test_upsampler.lua ├── main.lua ├── model.lua ├── modelGenerator ├── gentest.lua └── modelGen.lua └── train.lua /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to eyescream 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | ... (in particular how this is synced with internal changes to the project) 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 1. Fork the repo and create your branch from `master`. 11 | 2. If you've added code that should be tested, add tests 12 | 3. If you've changed APIs, update the documentation. 13 | 4. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## Coding Style 26 | * 3 spaces for indentation rather than tabs 27 | * 80 character line length 28 | 29 | ## License 30 | By contributing to eyescream, you agree that your contributions will be licensed 31 | under its BSD license. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For eyescream software 4 | 5 | Copyright (c) 2015-present, Facebook, Inc. All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Facebook nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /PATENTS: -------------------------------------------------------------------------------- 1 | Additional Grant of Patent Rights Version 2 2 | 3 | "Software" means the eyescream software distributed by Facebook, Inc. 4 | 5 | Facebook, Inc. ("Facebook") hereby grants to each recipient of the Software 6 | ("you") a perpetual, worldwide, royalty-free, non-exclusive, irrevocable 7 | (subject to the termination provision below) license under any Necessary 8 | Claims, to make, have made, use, sell, offer to sell, import, and otherwise 9 | transfer the Software. For avoidance of doubt, no license is granted under 10 | Facebook’s rights in any patent claims that are infringed by (i) modifications 11 | to the Software made by you or any third party or (ii) the Software in 12 | combination with any software or other technology. 13 | 14 | The license granted hereunder will terminate, automatically and without notice, 15 | if you (or any of your subsidiaries, corporate affiliates or agents) initiate 16 | directly or indirectly, or take a direct financial interest in, any Patent 17 | Assertion: (i) against Facebook or any of its subsidiaries or corporate 18 | affiliates, (ii) against any party if such Patent Assertion arises in whole or 19 | in part from any software, technology, product or service of Facebook or any of 20 | its subsidiaries or corporate affiliates, or (iii) against any party relating 21 | to the Software. Notwithstanding the foregoing, if Facebook or any of its 22 | subsidiaries or corporate affiliates files a lawsuit alleging patent 23 | infringement against you in the first instance, and you respond by filing a 24 | patent infringement counterclaim in that lawsuit against that party that is 25 | unrelated to the Software, the license granted hereunder will not terminate 26 | under section (i) of this paragraph due to such counterclaim. 27 | 28 | A "Necessary Claim" is a claim of a patent owned by Facebook that is 29 | necessarily infringed by the Software standing alone. 30 | 31 | A "Patent Assertion" is any lawsuit or other action alleging direct, indirect, 32 | or contributory infringement or inducement to infringe any patent, including a 33 | cross-claim or counterclaim. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Eyescream Project 2 | 3 | Generating Natural Images using Neural Networks. 4 | 5 | For our research summary on this work, please read the Arxiv paper: [http://arxiv.org/abs/1506.05751](http://arxiv.org/abs/1506.05751) 6 | 7 | For a high-level blog post with a live demo, please go to this website: [http://soumith.ch/eyescream](http://soumith.ch/eyescream) 8 | 9 | This repository contains the code to train neural networks and reproduce our results from scratch. 10 | 11 | ## Requirements 12 | Eyescream requires or works with 13 | * Mac OS X or Linux 14 | * NVIDIA GPU with compute capability of 3.5 or above. 15 | 16 | ## Installing Dependencies 17 | * Install [Torch](http://torch.ch) 18 | * Install the nngraph and tds packages: 19 | 20 | ``` 21 | luarocks install tds 22 | luarocks install nngraph 23 | ``` 24 | 25 | ## Training your neural networks 26 | 27 | * If you want to train the CIFAR-10 image generators, read the README in the cifar/ folder 28 | * If you want to train the LSUN/Imagenet image generators, read the README in the lsun/ folder 29 | 30 | 31 | ## Discuss the paper/code at 32 | * groups.google.com/forum/#!forum/torch7 33 | 34 | See the CONTRIBUTING file for how to help out. 35 | 36 | ## License 37 | Eyescream is BSD-licensed. We also provide an additional patent grant. -------------------------------------------------------------------------------- /cifar/README.md: -------------------------------------------------------------------------------- 1 | The pre-packaged CIFAR-10 dataset can be downloaded (to the same folder as this README) via this command: 2 | ``` 3 | wget -c http://torch7.s3-website-us-east-1.amazonaws.com/data/cifar-10-torch.tar.gz 4 | tar -xvf cifar-10-torch.tar.gz 5 | ``` 6 | 7 | Training a standard GAN [Goodfellow et. al. 2014](http://arxiv.org/abs/1406.2661) on 32x32 cifar images: 8 | ``` 9 | th scripts/train_cifar.lua --scale 32 10 | ``` 11 | 12 | Training a coarse-to-fine GAN on 16 -> 32 images: 13 | ``` 14 | th scripts/train_cifar_coarse_to_fine.lua --coarseSize 16 --fineSize 32 15 | ``` 16 | 17 | 18 | Training a class conditional GAN on 32x32 cifar images: 19 | ``` 20 | th scripts/train_cifar_classcond.lua --scale 32 21 | ``` 22 | 23 | Training a class conditional coarse-to-fine GAN on 16 -> 32 images: 24 | ``` 25 | th scripts/train_cifar_coarse_to_fine_classcond.lua --coarseSize 16 --fineSize 32 26 | ``` 27 | 28 | The default training hyperparameters should suffice for all cases. 29 | Add -g to run on a specific gpu (default is cpu). 30 | If you have display (https://github.com/szym/display) installed, -p will plot sample generations and training curves after every epoch. 31 | The number of channels in the hidden layers of G and D can be specified with --hidden_G --hidden_D . 32 | To see model parameters for decent models at every scale please see https://gist.github.com/soumith/e3f722173ea16c1ea0d9. 33 | -------------------------------------------------------------------------------- /cifar/datasets/coarse_to_fine_cifar10.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'paths' 3 | require 'image' 4 | image_utils = require 'utils.image' 5 | 6 | cifar = {} 7 | 8 | cifar.path_dataset = 'cifar-10-batches-t7/' 9 | 10 | cifar.coarseSize = 16 11 | cifar.fineSize = 32 12 | 13 | function cifar.init(fineSize, coarseSize) 14 | cifar.fineSize = fineSize 15 | cifar.coarseSize = coarseSize 16 | end 17 | 18 | function cifar.loadTrainSet(start, stop, augment, crop) 19 | return cifar.loadDataset(true, start, stop, augment, crop) 20 | end 21 | 22 | function cifar.loadTestSet(crop) 23 | return cifar.loadDataset(false, nil, nil, nil, crop) 24 | end 25 | 26 | function cifar.loadDataset(isTrain, start, stop, augment, crop) 27 | local data 28 | local labels 29 | local defaultType = torch.getdefaulttensortype() 30 | if isTrain then -- load train data 31 | data = torch.FloatTensor(50000, 3, 32, 32) 32 | labels = torch.FloatTensor(50000) 33 | for i = 0,4 do 34 | local subset = torch.load(cifar.path_dataset .. 'data_batch_' .. (i+1) .. '.t7', 'ascii') 35 | data[{ {i*10000+1, (i+1)*10000} }] = subset.data:t():reshape(10000, 3, 32, 32) 36 | labels[{ {i*10000+1, (i+1)*10000} }] = subset.labels 37 | end 38 | else -- load test data 39 | subset = torch.load(cifar.path_dataset .. 'test_batch.t7', 'ascii') 40 | data = subset.data:t():reshape(10000, 3, 32, 32):type('torch.FloatTensor') 41 | labels = subset.labels:t():type(defaultType) 42 | end 43 | 44 | local start = start or 1 45 | local stop = stop or data:size(1) 46 | 47 | -- select chunk 48 | data = data[{ {start, stop} }] 49 | labels = labels[{ {start, stop} }] 50 | labels:add(1) -- becasue indexing is 1-based 51 | local N = stop - start + 1 52 | print(' loaded ' .. N .. ' examples') 53 | 54 | local dataset = {} 55 | dataset.data = data -- on cpu 56 | dataset.labels = labels 57 | 58 | dataset.coarseData = torch.FloatTensor(N, 3, cifar.fineSize, cifar.fineSize) 59 | dataset.fineData = torch.FloatTensor(N, 3, cifar.fineSize, cifar.fineSize) 60 | dataset.diffData = torch.FloatTensor(N, 3, cifar.fineSize, cifar.fineSize) 61 | 62 | -- Coarse data 63 | function dataset:makeCoarse() 64 | for i = 1,N do 65 | local tmp = image.scale(self.data[i], cifar.coarseSize, cifar.coarseSize) 66 | self.coarseData[i] = image.scale(tmp, cifar.fineSize, cifar.fineSize) 67 | end 68 | end 69 | 70 | -- Fine data 71 | function dataset:makeFine() 72 | for i = 1,N do 73 | self.fineData[i] = image.scale(self.data[i], cifar.fineSize, cifar.fineSize) 74 | end 75 | end 76 | 77 | -- Diff (coarse - fine) 78 | function dataset:makeDiff() 79 | for i=1,N do 80 | self.diffData[i] = torch.add(self.fineData[i], -1, self.coarseData[i]) 81 | end 82 | end 83 | 84 | function dataset:size() 85 | return N 86 | end 87 | 88 | function dataset:numClasses() 89 | return 10 90 | end 91 | 92 | local labelvector = torch.zeros(10) 93 | 94 | setmetatable(dataset, {__index = function(self, index) 95 | local diff = self.diffData[index] 96 | local cond = self.coarseData[index] 97 | local fine = self.fineData[index] 98 | labelvector:zero() 99 | labelvector[self.labels[index]] = 1 100 | local example = {diff, labelvector, cond, fine} 101 | return example 102 | end}) 103 | 104 | return dataset 105 | end 106 | -------------------------------------------------------------------------------- /cifar/datasets/scaled_cifar10.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'image' 3 | require 'paths' 4 | 5 | cifar = {} 6 | 7 | cifar.path_dataset = 'cifar-10-batches-t7/' 8 | cifar.scale = 32 9 | 10 | function cifar.setScale(scale) 11 | cifar.scale = scale 12 | end 13 | 14 | function cifar.loadTrainSet(start, stop) 15 | return cifar.loadDataset(true, start, stop) 16 | end 17 | 18 | function cifar.loadTestSet() 19 | return cifar.loadDataset(false) 20 | end 21 | 22 | function cifar.loadDataset(isTrain, start, stop) 23 | local data 24 | local labels 25 | local defaultType = torch.getdefaulttensortype() 26 | if isTrain then -- load train data 27 | data = torch.FloatTensor(50000, 3, 32, 32) 28 | labels = torch.FloatTensor(50000) 29 | for i = 0,4 do 30 | local subset = torch.load(cifar.path_dataset .. 'data_batch_' .. (i+1) .. '.t7', 'ascii') 31 | data[{ {i*10000+1, (i+1)*10000} }] = subset.data:t():reshape(10000, 3, 32, 32) 32 | labels[{ {i*10000+1, (i+1)*10000} }] = subset.labels 33 | end 34 | else -- load test data 35 | subset = torch.load(cifar.path_dataset .. 'test_batch.t7', 'ascii') 36 | data = subset.data:t():reshape(10000, 3, 32, 32) 37 | labels = subset.labels:reshape(10000) 38 | end 39 | 40 | local start = start or 1 41 | local stop = stop or data:size(1) 42 | 43 | -- select chunk 44 | data = data[{ {start, stop} }] 45 | labels = labels[{ {start, stop} }] 46 | labels:add(1) -- becasue indexing is 1-based 47 | local N = stop - start + 1 48 | print(' loaded ' .. N .. ' examples') 49 | 50 | local dataset = {} 51 | dataset.data = data 52 | dataset.labels = labels 53 | dataset.scaled = torch.Tensor(N, 3, cifar.scale, cifar.scale) 54 | 55 | function dataset:scaleData() 56 | for n = 1,N do 57 | dataset.scaled[n] = image.scale(dataset.data[n], cifar.scale, cifar.scale) 58 | end 59 | end 60 | 61 | 62 | function dataset:size() 63 | return N 64 | end 65 | 66 | function dataset:numClasses() 67 | return 10 68 | end 69 | 70 | local labelvector = torch.zeros(10) 71 | 72 | setmetatable(dataset, {__index = function(self, index) 73 | local input = self.scaled[index] 74 | local class = self.labels[index] 75 | local label = labelvector:zero() 76 | label[class] = 1 77 | local example = {input, class, label} 78 | return example 79 | end}) 80 | 81 | return dataset 82 | end 83 | -------------------------------------------------------------------------------- /cifar/layers/SpatialConvolutionUpsample.lua: -------------------------------------------------------------------------------- 1 | local SpatialConvolutionUpsample, parent = torch.class('nn.SpatialConvolutionUpsample','nn.SpatialConvolution') 2 | 3 | function SpatialConvolutionUpsample:__init(nInputPlane, nOutputPlane, kW, kH, factor) 4 | factor = factor or 2 5 | assert(kW and kH and nInputPlane and nOutputPlane) 6 | assert(kW % 2 == 1, 'kW has to be odd') 7 | assert(kH % 2 == 1, 'kH has to be odd') 8 | self.factor = factor 9 | self.kW = kW 10 | self.kH = kH 11 | self.nInputPlaneU = nInputPlane 12 | self.nOutputPlaneU = nOutputPlane 13 | parent.__init(self, nInputPlane, nOutputPlane * factor * factor, kW, kH, 1, 1, (kW-1)/2) 14 | end 15 | 16 | function SpatialConvolutionUpsample:updateOutput(input) 17 | self.output = parent.updateOutput(self, input) 18 | if input:dim() == 4 then 19 | self.h = input:size(3) 20 | self.w = input:size(4) 21 | self.output = self.output:view(input:size(1), self.nOutputPlaneU, self.h*self.factor, self.w*self.factor) 22 | else 23 | self.h = input:size(2) 24 | self.w = input:size(3) 25 | self.output = self.output:view(self.nOutputPlaneU, self.h*self.factor, self.w*self.factor) 26 | end 27 | return self.output 28 | end 29 | 30 | function SpatialConvolutionUpsample:updateGradInput(input, gradOutput) 31 | if input:dim() == 4 then 32 | gradOutput = gradOutput:view(input:size(1), self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 33 | else 34 | gradOutput = gradOutput:view(self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 35 | end 36 | self.gradInput = parent.updateGradInput(self, input, gradOutput) 37 | return self.gradInput 38 | end 39 | 40 | function SpatialConvolutionUpsample:accGradParameters(input, gradOutput, scale) 41 | if input:dim() == 4 then 42 | gradOutput = gradOutput:view(input:size(1), self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 43 | else 44 | gradOutput = gradOutput:view(self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 45 | end 46 | parent.accGradParameters(self, input, gradOutput, scale) 47 | end 48 | 49 | function SpatialConvolutionUpsample:accUpdateGradParameters(input, gradOutput, scale) 50 | if input:dim() == 4 then 51 | gradOutput = gradOutput:view(input:size(1), self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 52 | else 53 | gradOutput = gradOutput:view(self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 54 | end 55 | parent.accUpdateGradParameters(self, input, gradOutput, scale) 56 | end 57 | -------------------------------------------------------------------------------- /cifar/scripts/train_cifar.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'cunn' 4 | require 'optim' 5 | require 'image' 6 | require 'datasets.scaled_cifar10' 7 | require 'pl' 8 | require 'paths' 9 | image_utils = require 'utils.image' 10 | ok, disp = pcall(require, 'display') 11 | if not ok then print('display not found. unable to plot') end 12 | adversarial = require 'train.adversarial' 13 | 14 | 15 | ---------------------------------------------------------------------- 16 | -- parse command-line options 17 | opt = lapp[[ 18 | -s,--save (default "logs") subdirectory to save logs 19 | --saveFreq (default 10) save every saveFreq epochs 20 | -n,--network (default "") reload pretrained network 21 | -p,--plot plot while training 22 | -r,--learningRate (default 0.02) learning rate 23 | -b,--batchSize (default 100) batch size 24 | -m,--momentum (default 0) momentum, for SGD only 25 | --coefL1 (default 0) L1 penalty on the weights 26 | --coefL2 (default 0) L2 penalty on the weights 27 | -t,--threads (default 4) number of threads 28 | -g,--gpu (default -1) gpu to run on (default cpu) 29 | -d,--noiseDim (default 100) dimensionality of noise vector 30 | --K (default 1) number of iterations to optimize D for 31 | -w, --window (default 3) windsow id of sample image 32 | --hidden_G (default 8000) number of units in hidden layers of G 33 | --hidden_D (default 1600) number of units in hidden layers of D 34 | --scale (default 32) scale of images to train on 35 | ]] 36 | 37 | if opt.gpu < 0 or opt.gpu > 3 then opt.gpu = false end 38 | print(opt) 39 | 40 | -- fix seed 41 | torch.manualSeed(1) 42 | 43 | -- threads 44 | torch.setnumthreads(opt.threads) 45 | print(' set nb of threads to ' .. torch.getnumthreads()) 46 | 47 | if opt.gpu then 48 | cutorch.setDevice(opt.gpu + 1) 49 | print(' using device ' .. opt.gpu) 50 | torch.setdefaulttensortype('torch.CudaTensor') 51 | else 52 | torch.setdefaulttensortype('torch.FloatTensor') 53 | end 54 | 55 | 56 | classes = {'0','1'} 57 | opt.geometry = {3, opt.scale, opt.scale} 58 | 59 | function setWeights(weights, std) 60 | weights:randn(weights:size()) 61 | weights:mul(std) 62 | end 63 | 64 | local input_sz = opt.geometry[1] * opt.geometry[2] * opt.geometry[3] 65 | 66 | if opt.network == '' then 67 | ---------------------------------------------------------------------- 68 | -- define D network to train 69 | local numhid = opt.hidden_D 70 | model_D = nn.Sequential() 71 | model_D:add(nn.Reshape(input_sz)) 72 | model_D:add(nn.Linear(input_sz, numhid)) 73 | model_D:add(nn.ReLU()) 74 | model_D:add(nn.Dropout()) 75 | model_D:add(nn.Linear(numhid, numhid)) 76 | model_D:add(nn.ReLU()) 77 | model_D:add(nn.Dropout()) 78 | model_D:add(nn.Linear(numhid,1)) 79 | model_D:add(nn.Sigmoid()) 80 | 81 | -- Init weights 82 | setWeights(model_D.modules[2].weight, 0.005) 83 | setWeights(model_D.modules[5].weight, 0.005) 84 | setWeights(model_D.modules[8].weight, 0.005) 85 | setWeights(model_D.modules[2].bias, 0) 86 | setWeights(model_D.modules[5].bias, 0) 87 | setWeights(model_D.modules[8].bias, 0) 88 | 89 | ---------------------------------------------------------------------- 90 | -- define G network to train 91 | local numhid = opt.hidden_G 92 | model_G = nn.Sequential() 93 | model_G:add(nn.Linear(opt.noiseDim, numhid)) 94 | model_G:add(nn.ReLU()) 95 | model_G:add(nn.Linear(numhid, numhid)) 96 | model_G:add(nn.Sigmoid()) 97 | model_G:add(nn.Linear(numhid, input_sz)) 98 | model_G:add(nn.Reshape(opt.geometry[1], opt.geometry[2], opt.geometry[3])) 99 | 100 | -- Init weights 101 | setWeights(model_G.modules[1].weight, 0.05) 102 | setWeights(model_G.modules[3].weight, 0.05) 103 | setWeights(model_G.modules[5].weight, 0.05) 104 | setWeights(model_G.modules[1].bias, 0) 105 | setWeights(model_G.modules[3].bias, 0) 106 | setWeights(model_G.modules[5].bias, 0) 107 | else 108 | print(' reloading previously trained network: ' .. opt.network) 109 | tmp = torch.load(opt.network) 110 | model_D = tmp.D 111 | model_G = tmp.G 112 | end 113 | 114 | -- loss function: negative log-likelihood 115 | criterion = nn.BCECriterion() 116 | 117 | -- retrieve parameters and gradients 118 | parameters_D,gradParameters_D = model_D:getParameters() 119 | parameters_G,gradParameters_G = model_G:getParameters() 120 | 121 | -- print networks 122 | print('Discriminator network:') 123 | print(model_D) 124 | print('Generator network:') 125 | print(model_G) 126 | 127 | 128 | ---------------------------------------------------------------------- 129 | -- get/create dataset 130 | -- 131 | ntrain = 45000 132 | nval = 5000 133 | 134 | cifar.setScale(opt.scale) 135 | 136 | -- create training set and normalize 137 | trainData = cifar.loadTrainSet(1, ntrain) 138 | mean, std = image_utils.normalize(trainData.data) 139 | trainData:scaleData() 140 | 141 | -- create validation set and normalize 142 | valData = cifar.loadTrainSet(ntrain+1, ntrain+nval) 143 | image_utils.normalize(valData.data, mean, std) 144 | valData:scaleData() 145 | 146 | 147 | -- this matrix records the current confusion across classes 148 | confusion = optim.ConfusionMatrix(classes) 149 | 150 | -- log results to files 151 | trainLogger = optim.Logger(paths.concat(opt.save, 'train.log')) 152 | testLogger = optim.Logger(paths.concat(opt.save, 'test.log')) 153 | 154 | if opt.gpu then 155 | print('Copy model to gpu') 156 | model_D:cuda() 157 | model_G:cuda() 158 | end 159 | 160 | -- Training parameters 161 | sgdState_D = { 162 | learningRate = opt.learningRate, 163 | momentum = opt.momentum 164 | } 165 | 166 | sgdState_G = { 167 | learningRate = opt.learningRate, 168 | momentum = opt.momentum 169 | } 170 | 171 | -- Get examples to plot 172 | function getSamples(dataset, N) 173 | local numperclass = numperclass or 10 174 | local N = N or 8 175 | local noise_inputs = torch.Tensor(N, opt.noiseDim) 176 | 177 | -- Generate samples 178 | noise_inputs:uniform(-1, 1) 179 | local samples = model_G:forward(noise_inputs) 180 | 181 | local to_plot = {} 182 | for i=1,N do 183 | to_plot[#to_plot+1] = samples[i]:float() 184 | end 185 | 186 | return to_plot 187 | end 188 | 189 | 190 | -- training loop 191 | while true do 192 | -- train/test 193 | adversarial.train(trainData) 194 | adversarial.test(valData) 195 | 196 | sgdState_D.momentum = math.min(sgdState_D.momentum + 0.0008, 0.7) 197 | sgdState_D.learningRate = math.max(sgdState_D.learningRate / 1.000004, 0.000001) 198 | sgdState_G.momentum = math.min(sgdState_G.momentum + 0.0008, 0.7) 199 | sgdState_G.learningRate = math.max(sgdState_G.learningRate / 1.000004, 0.000001) 200 | 201 | -- plot errors 202 | if opt.plot and epoch and epoch % 1 == 0 then 203 | local to_plot = getSamples(valData, 100) 204 | torch.setdefaulttensortype('torch.FloatTensor') 205 | 206 | trainLogger:style{['% mean class accuracy (train set)'] = '-'} 207 | testLogger:style{['% mean class accuracy (test set)'] = '-'} 208 | trainLogger:plot() 209 | testLogger:plot() 210 | 211 | disp.image(to_plot, {win=opt.window, width=700, title=opt.save}) 212 | if opt.gpu then 213 | torch.setdefaulttensortype('torch.CudaTensor') 214 | else 215 | torch.setdefaulttensortype('torch.FloatTensor') 216 | end 217 | end 218 | end 219 | -------------------------------------------------------------------------------- /cifar/scripts/train_cifar_classcond.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nngraph' 3 | require 'cunn' 4 | require 'optim' 5 | require 'image' 6 | require 'datasets.scaled_cifar10' 7 | require 'pl' 8 | require 'paths' 9 | image_utils = require 'utils.image' 10 | ok, disp = pcall(require, 'display') 11 | if not ok then print('display not found. unable to plot') end 12 | adversarial = require 'train.conditional_adversarial' 13 | 14 | 15 | ---------------------------------------------------------------------- 16 | -- parse command-line options 17 | opt = lapp[[ 18 | -s,--save (default "logs") subdirectory to save logs 19 | --saveFreq (default 2) save every saveFreq epochs 20 | -n,--network (default "") reload pretrained network 21 | -p,--plot plot while training 22 | -r,--learningRate (default 0.02) learning rate 23 | -b,--batchSize (default 128) batch size 24 | -m,--momentum (default 0) momentum 25 | --coefL1 (default 0) L1 penalty on the weights 26 | --coefL2 (default 0) L2 penalty on the weights 27 | -t,--threads (default 4) number of threads 28 | -g,--gpu (default -1) gpu to run on (default cpu) 29 | -d,--noiseDim (default 100) dimensionality of noise vector 30 | --K (default 1) number of iterations to optimize D for 31 | -w, --window (default 3) windsow id of sample image 32 | --hidden_G (default 8000) number of units in hidden layers of G 33 | --hidden_D (default 1600) number of units in hidden layers of D 34 | --scale (default 32) scale of images to train on 35 | ]] 36 | 37 | if opt.gpu < 0 or opt.gpu > 3 then opt.gpu = false end 38 | print(opt) 39 | 40 | -- fix seed 41 | torch.manualSeed(1) 42 | 43 | -- threads 44 | torch.setnumthreads(opt.threads) 45 | print(' set nb of threads to ' .. torch.getnumthreads()) 46 | 47 | if opt.gpu then 48 | cutorch.setDevice(opt.gpu + 1) 49 | print(' using device ' .. opt.gpu) 50 | torch.setdefaulttensortype('torch.CudaTensor') 51 | else 52 | torch.setdefaulttensortype('torch.FloatTensor') 53 | end 54 | 55 | classes = {'0','1'} 56 | opt.geometry = {3, opt.scale, opt.scale} 57 | opt.condDim = 10 58 | cifar_classes = {'airplane', 'automobile', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck'} 59 | 60 | function setWeights(weights, std) 61 | weights:randn(weights:size()) 62 | weights:mul(std) 63 | end 64 | 65 | local input_sz = opt.geometry[1] * opt.geometry[2] * opt.geometry[3] 66 | 67 | if opt.network == '' then 68 | ---------------------------------------------------------------------- 69 | -- define D network to train 70 | local numhid = opt.hidden_D 71 | x_I = nn.Identity()() 72 | x_C = nn.Identity()() 73 | c1 = nn.JoinTable(2, 2)({nn.Reshape(input_sz)(x_I), x_C}) 74 | h1 = nn.Linear(input_sz + opt.condDim, numhid)(c1) 75 | h2 = nn.Linear(numhid, numhid)(nn.Dropout()(nn.ReLU()(h1))) 76 | h3 = nn.Linear(numhid, 1)(nn.Dropout()(nn.ReLU()(h2))) 77 | out = nn.Sigmoid()(h3) 78 | model_D = nn.gModule({x_I, x_C}, {out}) 79 | 80 | ---------------------------------------------------------------------- 81 | -- define G network to train 82 | local numhid = opt.hidden_G 83 | model_G = nn.Sequential() 84 | model_G:add(nn.JoinTable(2, 2)) 85 | model_G:add(nn.Linear(opt.noiseDim + opt.condDim, numhid)) 86 | model_G:add(nn.ReLU()) 87 | model_G:add(nn.Linear(numhid, numhid)) 88 | model_G:add(nn.ReLU()) 89 | model_G:add(nn.Linear(numhid, opt.geometry[1]*opt.geometry[2]*opt.geometry[3])) 90 | model_G:add(nn.View(opt.geometry[1], opt.geometry[2], opt.geometry[3])) 91 | else 92 | print(' reloading previously trained network: ' .. opt.network) 93 | tmp = torch.load(opt.network) 94 | model_D = tmp.D 95 | model_G = tmp.G 96 | end 97 | 98 | -- loss function: negative log-likelihood 99 | criterion = nn.BCECriterion() 100 | 101 | -- retrieve parameters and gradients 102 | parameters_D,gradParameters_D = model_D:getParameters() 103 | parameters_G,gradParameters_G = model_G:getParameters() 104 | 105 | -- print networks 106 | print('Discriminator network:') 107 | print(model_D) 108 | print('Generator network:') 109 | print(model_G) 110 | 111 | local nparams = 0 112 | for i=1,#model_D.forwardnodes do 113 | if model_D.forwardnodes[i].data ~= nil and model_D.forwardnodes[i].data.module ~= nil and model_D.forwardnodes[i].data.module.weight ~= nil then 114 | nparams = nparams + model_D.forwardnodes[i].data.module.weight:nElement() 115 | end 116 | end 117 | print('\nNumber of free parameters in D: ' .. nparams) 118 | 119 | local nparams = 0 120 | for i=1,#model_G.modules do 121 | if model_G.modules[i].weight ~= nil then 122 | nparams = nparams + model_G.modules[i].weight:nElement() 123 | end 124 | end 125 | print('Number of free parameters in G: ' .. nparams .. '\n') 126 | 127 | ---------------------------------------------------------------------- 128 | -- get/create dataset 129 | -- 130 | ntrain = 45000 131 | nval = 5000 132 | 133 | cifar.setScale(opt.scale) 134 | 135 | -- create training set and normalize 136 | trainData = cifar.loadTrainSet(1, ntrain) 137 | mean, std = image_utils.normalize(trainData.data) 138 | trainData:scaleData() 139 | 140 | -- create validation set and normalize 141 | valData = cifar.loadTrainSet(ntrain+1, ntrain+nval) 142 | image_utils.normalize(valData.data, mean, std) 143 | valData:scaleData() 144 | 145 | -- this matrix records the current confusion across classes 146 | confusion = optim.ConfusionMatrix(classes) 147 | 148 | -- log results to files 149 | trainLogger = optim.Logger(paths.concat(opt.save, 'train.log')) 150 | testLogger = optim.Logger(paths.concat(opt.save, 'test.log')) 151 | 152 | if opt.gpu then 153 | print('Copy model to gpu') 154 | model_D:cuda() 155 | model_G:cuda() 156 | end 157 | 158 | -- Training parameters 159 | sgdState_D = { 160 | learningRate = opt.learningRate, 161 | momentum = opt.momentum 162 | } 163 | 164 | sgdState_G = { 165 | learningRate = opt.learningRate, 166 | momentum = opt.momentum 167 | } 168 | 169 | -- Get examples to plot 170 | function getSamples(dataset, N, numperclass) 171 | local numperclass = numperclass or 10 172 | local N = N or 8 173 | local noise_inputs = torch.Tensor(N, opt.noiseDim) 174 | local cond_inputs = torch.zeros(N, opt.condDim) 175 | 176 | -- Generate samples 177 | noise_inputs:uniform(-1, 1) 178 | local class = math.random(10) 179 | local classes = {} 180 | for n = 1,N do 181 | cond_inputs[n][class] = 1 182 | classes[n] = cifar_classes[class] 183 | if n % numperclass ==0 then class = math.random(10) end 184 | end 185 | local samples = model_G:forward({noise_inputs, cond_inputs}) 186 | 187 | local to_plot = {} 188 | for i=1,N do 189 | to_plot[#to_plot+1] = samples[i]:float() 190 | end 191 | 192 | return to_plot 193 | end 194 | 195 | -- training loop 196 | while true do 197 | -- train/test 198 | adversarial.train(trainData) 199 | adversarial.test(valData) 200 | 201 | sgdState_D.momentum = math.min(sgdState_D.momentum + 0.0008, 0.7) 202 | sgdState_D.learningRate = math.max(sgdState_D.learningRate / 1.000004, 0.000001) 203 | sgdState_G.momentum = math.min(sgdState_G.momentum + 0.0008, 0.7) 204 | sgdState_G.learningRate = math.max(sgdState_G.learningRate / 1.000004, 0.000001) 205 | 206 | -- plot errors 207 | if opt.plot and epoch and epoch % 1 == 0 then 208 | local to_plot, labels = getSamples(valData, 100, 10) 209 | torch.setdefaulttensortype('torch.FloatTensor') 210 | 211 | trainLogger:style{['% mean class accuracy of D (train set)'] = '-'} 212 | testLogger:style{['% mean class accuracy of D (test set)'] = '-'} 213 | trainLogger:plot() 214 | testLogger:plot() 215 | 216 | disp.image(to_plot, {win=opt.window, width=600, title=opt.save}) 217 | if opt.gpu then 218 | torch.setdefaulttensortype('torch.CudaTensor') 219 | else 220 | torch.setdefaulttensortype('torch.FloatTensor') 221 | end 222 | end 223 | end 224 | -------------------------------------------------------------------------------- /cifar/scripts/train_cifar_coarse_to_fine.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nngraph' 3 | require 'cunn' 4 | require 'optim' 5 | require 'image' 6 | require 'datasets.coarse_to_fine_cifar10' 7 | require 'pl' 8 | require 'paths' 9 | image_utils = require 'utils.image' 10 | ok, disp = pcall(require, 'display') 11 | if not ok then print('display not found. unable to plot') end 12 | adversarial = require 'train.conditional_adversarial' 13 | require 'layers.SpatialConvolutionUpsample' 14 | 15 | 16 | ---------------------------------------------------------------------- 17 | -- parse command-line options 18 | opt = lapp[[ 19 | -s,--save (default "logs") subdirectory to save logs 20 | --saveFreq (default 5) save every saveFreq epochs 21 | -n,--network (default "") reload pretrained network 22 | -p,--plot plot while training 23 | -r,--learningRate (default 0.02) learning rate 24 | -b,--batchSize (default 128) batch size 25 | -m,--momentum (default 0.5) momentum 26 | --coefL1 (default 0) L1 penalty on the weights 27 | --coefL2 (default 0) L2 penalty on the weights 28 | -t,--threads (default 4) number of threads 29 | -g,--gpu (default -1) gpu to run on (default cpu) 30 | -d,--noiseDim (default 100) dimensionality of noise vector 31 | --K (default 1) number of iterations to optimize D for 32 | -w, --window (default 3) windsow id of sample image 33 | --hidden_G (default 64) number of channels in hidden layers of G 34 | --hidden_D (default 64) number of channels in hidden layers of D 35 | --coarseSize (default 16) coarse scale 36 | --fineSize (default 32) fine scale 37 | ]] 38 | 39 | if opt.gpu < 0 or opt.gpu > 3 then opt.gpu = false end 40 | print(opt) 41 | 42 | -- fix seed 43 | torch.manualSeed(1) 44 | 45 | -- threads 46 | torch.setnumthreads(opt.threads) 47 | print(' set nb of threads to ' .. torch.getnumthreads()) 48 | 49 | if opt.gpu then 50 | cutorch.setDevice(opt.gpu + 1) 51 | print(' using device ' .. opt.gpu) 52 | torch.setdefaulttensortype('torch.CudaTensor') 53 | else 54 | torch.setdefaulttensortype('torch.FloatTensor') 55 | end 56 | 57 | opt.noiseDim = {1, opt.fineSize, opt.fineSize} 58 | classes = {'0','1'} 59 | opt.geometry = {3, opt.fineSize, opt.fineSize} 60 | opt.condDim = {3, opt.fineSize, opt.fineSize} 61 | 62 | local input_sz = opt.geometry[1] * opt.geometry[2] * opt.geometry[3] 63 | 64 | if opt.network == '' then 65 | ---------------------------------------------------------------------- 66 | -- define D network to train 67 | local nplanes = opt.hidden_D 68 | model_D = nn.Sequential() 69 | model_D:add(nn.CAddTable()) 70 | model_D:add(nn.SpatialConvolution(3, nplanes, 5, 5)) --28 x 28 71 | model_D:add(nn.ReLU()) 72 | model_D:add(nn.SpatialConvolution(nplanes, nplanes, 5, 5, 2, 2)) 73 | local sz =math.floor( ( (opt.fineSize - 5 + 1) - 5) / 2 + 1) 74 | model_D:add(nn.Reshape(nplanes*sz*sz)) 75 | model_D:add(nn.ReLU()) 76 | model_D:add(nn.Dropout()) 77 | model_D:add(nn.Linear(nplanes*sz*sz, 1)) 78 | model_D:add(nn.Sigmoid()) 79 | 80 | ---------------------------------------------------------------------- 81 | -- define G network to train 82 | local nplanes = opt.hidden_G 83 | model_G = nn.Sequential() 84 | model_G:add(nn.JoinTable(2, 2)) 85 | model_G:add(nn.SpatialConvolutionUpsample(3+1, nplanes, 7, 7, 1)) -- 3 color channels + conditional 86 | model_G:add(nn.ReLU()) 87 | model_G:add(nn.SpatialConvolutionUpsample(nplanes, nplanes, 7, 7, 1)) -- 3 color channels + conditional 88 | model_G:add(nn.ReLU()) 89 | model_G:add(nn.SpatialConvolutionUpsample(nplanes, 3, 5, 5, 1)) -- 3 color channels + conditional 90 | model_G:add(nn.View(opt.geometry[1], opt.geometry[2], opt.geometry[3])) 91 | 92 | else 93 | print(' reloading previously trained network: ' .. opt.network) 94 | tmp = torch.load(opt.network) 95 | model_D = tmp.D 96 | model_G = tmp.G 97 | end 98 | 99 | -- loss function: negative log-likelihood 100 | criterion = nn.BCECriterion() 101 | 102 | -- retrieve parameters and gradients 103 | parameters_D,gradParameters_D = model_D:getParameters() 104 | parameters_G,gradParameters_G = model_G:getParameters() 105 | 106 | -- print networks 107 | print('Discriminator network:') 108 | print(model_D) 109 | print('Evaluator network:') 110 | print(model_G) 111 | 112 | local nparams = 0 113 | for i=1,#model_D.modules do 114 | if model_D.modules[i].weight ~= nil then 115 | nparams = nparams + model_D.modules[i].weight:nElement() 116 | end 117 | end 118 | print('\nNumber of free parameters in D: ' .. nparams) 119 | 120 | 121 | local nparams = 0 122 | for i=1,#model_G.modules do 123 | if model_G.modules[i].weight ~= nil then 124 | nparams = nparams + model_G.modules[i].weight:nElement() 125 | end 126 | end 127 | print('Number of free parameters in G: ' .. nparams .. '\n') 128 | 129 | ---------------------------------------------------------------------- 130 | -- get/create dataset 131 | -- 132 | ntrain = 45000 133 | nval = 5000 134 | 135 | cifar.init(opt.fineSize, opt.coarseSize) 136 | 137 | -- create training set and normalize 138 | trainData = cifar.loadTrainSet(1, ntrain) 139 | mean, std = image_utils.normalize(trainData.data) 140 | trainData:makeFine() 141 | trainData:makeCoarse() 142 | trainData:makeDiff() 143 | 144 | -- create validation set and normalize 145 | valData = cifar.loadTrainSet(ntrain+1, ntrain+nval) 146 | image_utils.normalize(valData.data, mean, std) 147 | valData:makeFine() 148 | valData:makeCoarse() 149 | valData:makeDiff() 150 | 151 | -- this matrix records the current confusion across classes 152 | confusion = optim.ConfusionMatrix(classes) 153 | 154 | -- log results to files 155 | trainLogger = optim.Logger(paths.concat(opt.save, 'train.log')) 156 | testLogger = optim.Logger(paths.concat(opt.save, 'test.log')) 157 | 158 | if opt.gpu then 159 | print('Copying model to gpu') 160 | model_D:cuda() 161 | model_G:cuda() 162 | end 163 | 164 | -- Training parameters 165 | sgdState_D = { 166 | learningRate = opt.learningRate, 167 | momentum = opt.momentum 168 | } 169 | 170 | sgdState_G = { 171 | learningRate = opt.learningRate, 172 | momentum = opt.momentum 173 | } 174 | 175 | -- Get examples to plot 176 | function getSamples(dataset, N) 177 | local N = N or 8 178 | local noise_inputs = torch.Tensor(N, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3]) 179 | local cond_inputs = torch.Tensor(N, opt.condDim[1], opt.condDim[2], opt.condDim[3]) 180 | local gt_diff = torch.Tensor(N, opt.geometry[1], opt.geometry[2], opt.geometry[3]) 181 | local gt = torch.Tensor(N, 3, opt.fineSize, opt.fineSize) 182 | 183 | -- Generate samples 184 | noise_inputs:uniform(-1, 1) 185 | for n = 1,N do 186 | local rand = math.random(dataset:size()) 187 | local sample = dataset[rand] 188 | cond_inputs[n] = sample[3]:clone() 189 | gt[n] = sample[4]:clone() 190 | gt_diff[n] = sample[1]:clone() 191 | end 192 | local samples = model_G:forward({noise_inputs, cond_inputs}) 193 | local preds_D = model_D:forward({samples, cond_inputs}) 194 | 195 | local to_plot = {} 196 | for i=1,N do 197 | local pred = torch.add(cond_inputs[i]:float(), samples[i]:float()) 198 | to_plot[#to_plot+1] = gt[i]:float() 199 | to_plot[#to_plot+1] = pred 200 | to_plot[#to_plot+1] = cond_inputs[i]:float() 201 | to_plot[#to_plot+1] = samples[i]:float() 202 | end 203 | return to_plot 204 | end 205 | 206 | 207 | -- training loop 208 | while true do 209 | -- train/test 210 | adversarial.train(trainData, ntrain) 211 | adversarial.test(valData, nval) 212 | 213 | adversarial.approxParzen(valData, 200, opt.batchSize) 214 | 215 | sgdState_D.momentum = math.min(sgdState_D.momentum + 0.0008, 0.7) 216 | sgdState_D.learningRate = math.max(sgdState_D.learningRate / 1.000004, 0.000001) 217 | sgdState_G.momentum = math.min(sgdState_G.momentum + 0.0008, 0.7) 218 | sgdState_G.learningRate = math.max(sgdState_G.learningRate / 1.000004, 0.000001) 219 | 220 | -- plot errors 221 | if opt.plot and epoch and epoch % 1 == 0 then 222 | local to_plot = getSamples(valData, 16) 223 | torch.setdefaulttensortype('torch.FloatTensor') 224 | 225 | trainLogger:style{['% mean class accuracy of D (train set)'] = '-'} 226 | testLogger:style{['% mean class accuracy of D (test set)'] = '-'} 227 | trainLogger:plot() 228 | testLogger:plot() 229 | 230 | disp.image(to_plot, {win=opt.window, width=700, title=opt.save}) 231 | if opt.gpu then 232 | torch.setdefaulttensortype('torch.CudaTensor') 233 | else 234 | torch.setdefaulttensortype('torch.FloatTensor') 235 | end 236 | end 237 | end 238 | -------------------------------------------------------------------------------- /cifar/scripts/train_cifar_coarse_to_fine_classcond.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'cunn' 3 | require 'nngraph' 4 | require 'optim' 5 | require 'image' 6 | require 'datasets.coarse_to_fine_cifar10' 7 | require 'pl' 8 | require 'paths' 9 | image_utils = require 'utils.image' 10 | ok, disp = pcall(require, 'display') 11 | if not ok then print('display not found. unable to plot') end 12 | adversarial = require 'train.double_conditional_adversarial' 13 | require 'layers.SpatialConvolutionUpsample' 14 | 15 | 16 | ---------------------------------------------------------------------- 17 | -- parse command-line options 18 | opt = lapp[[ 19 | -s,--save (default "logs") subdirectory to save logs 20 | --saveFreq (default 2) save every saveFreq epochs 21 | -n,--network (default "") reload pretrained network 22 | -p,--plot plot while training 23 | -r,--learningRate (default 0.02) learning rate 24 | -b,--batchSize (default 128) batch size 25 | -m,--momentum (default 0) momentum 26 | --coefL1 (default 0) L1 penalty on the weights 27 | --coefL2 (default 0) L2 penalty on the weights 28 | -t,--threads (default 4) number of threads 29 | -g,--gpu (default -1) gpu to run on (default cpu) 30 | -d,--noiseDim (default 100) dimensionality of noise vector 31 | --K (default 1) number of iterations to optimize D for 32 | -w, --window (default 3) windsow id of sample image 33 | --hidden_G (default 64) number of channels in hidden layers of G 34 | --hidden_D (default 64) number of channels in hidden layers of D 35 | --coarseSize (default 16) coarse scale 36 | --fineSize (default 32) fine scale 37 | ]] 38 | 39 | if opt.gpu < 0 or opt.gpu > 3 then opt.gpu = false end 40 | print(opt) 41 | 42 | -- fix seed 43 | torch.manualSeed(1) 44 | 45 | -- threads 46 | torch.setnumthreads(opt.threads) 47 | print(' set nb of threads to ' .. torch.getnumthreads()) 48 | 49 | if opt.gpu then 50 | cutorch.setDevice(opt.gpu + 1) 51 | print(' using device ' .. opt.gpu) 52 | torch.setdefaulttensortype('torch.CudaTensor') 53 | else 54 | torch.setdefaulttensortype('torch.FloatTensor') 55 | end 56 | 57 | opt.noiseDim = {1, opt.fineSize, opt.fineSize} 58 | classes = {'0','1'} 59 | opt.geometry = {3, opt.fineSize, opt.fineSize} 60 | opt.condDim1 = 10 61 | opt.condDim2 = {3, opt.fineSize, opt.fineSize} 62 | cifar_classes = {'airplane', 'automobile', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck'} 63 | 64 | local input_sz = opt.geometry[1] * opt.geometry[2] * opt.geometry[3] 65 | 66 | if opt.network == '' then 67 | ---------------------------------------------------------------------- 68 | -- define D network to train 69 | local nplanes = opt.hidden_D 70 | x_d = nn.Identity()() 71 | x_c1 = nn.Identity()() 72 | x_c2 = nn.Identity()() 73 | d1 = nn.CAddTable()({x_d, x_c2}) 74 | c1 = nn.Linear(opt.condDim1, opt.condDim2[2]*opt.condDim2[3])(x_c1) 75 | c2 = nn.Reshape(1, opt.condDim2[2], opt.condDim2[3])(nn.ReLU()(c1)) 76 | d2 = nn.JoinTable(2, 2)({d1, c2}) 77 | d3 = nn.SpatialConvolution(3+1, nplanes, 5, 5)(d2) 78 | d4 = nn.SpatialConvolution(nplanes, nplanes, 5, 5, 2, 2)(nn.ReLU()(d3)) 79 | local sz =math.floor( ( (opt.fineSize - 5 + 1) - 5) / 2 + 1) 80 | d5 = nn.Reshape(nplanes*sz*sz)(d4) 81 | d6 = nn.Linear(nplanes*sz*sz, 1)(nn.Dropout()(nn.ReLU()(d5))) 82 | d7 = nn.Sigmoid()(d6) 83 | model_D = nn.gModule({x_d, x_c1, x_c2}, {d7}) 84 | 85 | ---------------------------------------------------------------------- 86 | -- define G network to train 87 | local nplanes = opt.hidden_G 88 | x_n = nn.Identity()() -- noise (shaped as coarse map) 89 | g_c1 = nn.Identity()() -- class vector 90 | g_c2 = nn.Identity()() -- coarse map 91 | class1 = nn.Linear(opt.condDim1, opt.condDim2[2]*opt.condDim2[3])(g_c1) 92 | class2 = nn.Reshape(1, opt.condDim2[2], opt.condDim2[3])(nn.ReLU()(class1)) --convert class vector into map 93 | g1 = nn.JoinTable(2, 2)({x_n, class2, g_c2}) -- combine maps into 5 channels 94 | g2 = nn.SpatialConvolutionUpsample(5, nplanes, 7, 7, 1)(g1) 95 | g3 = nn.SpatialConvolutionUpsample(nplanes, nplanes, 7, 7, 1)(nn.ReLU()(g2)) 96 | g4 = nn.SpatialConvolutionUpsample(nplanes, 3, 5, 5, 1)(nn.ReLU()(g3)) 97 | model_G = nn.gModule({x_n, g_c1, g_c2}, {g4}) 98 | 99 | else 100 | print(' reloading previously trained network: ' .. opt.network) 101 | tmp = torch.load(opt.network) 102 | model_D = tmp.D 103 | model_G = tmp.G 104 | end 105 | 106 | -- loss function: negative log-likelihood 107 | criterion = nn.BCECriterion() 108 | 109 | -- retrieve parameters and gradients 110 | parameters_D,gradParameters_D = model_D:getParameters() 111 | parameters_G,gradParameters_G = model_G:getParameters() 112 | 113 | -- print networks 114 | print('Discriminator network:') 115 | print(model_D) 116 | print('Generator network:') 117 | print(model_G) 118 | 119 | local nparams = 0 120 | for i=1,#model_D.forwardnodes do 121 | if model_D.forwardnodes[i].data.module ~= nil and model_D.forwardnodes[i].data.module.weight ~= nil then 122 | nparams = nparams + model_D.forwardnodes[i].data.module.weight:nElement() 123 | end 124 | end 125 | print('\nNumber of free parameters in D: ' .. nparams) 126 | 127 | local nparams = 0 128 | for i=1,#model_G.forwardnodes do 129 | if model_G.forwardnodes[i].data.module ~= nil and model_G.forwardnodes[i].data.module.weight ~= nil then 130 | nparams = nparams + model_G.forwardnodes[i].data.module.weight:nElement() 131 | end 132 | end 133 | print('Number of free parameters in G: ' .. nparams .. '\n') 134 | 135 | ---------------------------------------------------------------------- 136 | -- get/create dataset 137 | -- 138 | ntrain = 45000 139 | nval = 5000 140 | 141 | cifar.init(opt.fineSize, opt.coarseSize) 142 | 143 | -- create training set and normalize 144 | trainData = cifar.loadTrainSet(1, ntrain) 145 | mean, std = image_utils.normalize(trainData.data) 146 | trainData:makeFine() 147 | trainData:makeCoarse() 148 | trainData:makeDiff() 149 | 150 | -- create validation set and normalize 151 | valData = cifar.loadTrainSet(ntrain+1, ntrain+nval) 152 | image_utils.normalize(valData.data, mean, std) 153 | valData:makeFine() 154 | valData:makeCoarse() 155 | valData:makeDiff() 156 | 157 | -- this matrix records the current confusion across classes 158 | confusion = optim.ConfusionMatrix(classes) 159 | 160 | -- log results to files 161 | trainLogger = optim.Logger(paths.concat(opt.save, 'train.log')) 162 | testLogger = optim.Logger(paths.concat(opt.save, 'test.log')) 163 | 164 | if opt.gpu then 165 | print('Copy model to gpu') 166 | model_D:cuda() 167 | model_G:cuda() 168 | end 169 | 170 | -- Training parameters 171 | sgdState_D = { 172 | learningRate = opt.learningRate, 173 | momentum = opt.momentum 174 | } 175 | 176 | sgdState_G = { 177 | learningRate = opt.learningRate, 178 | momentum = opt.momentum 179 | } 180 | 181 | -- Get examples to plot 182 | function getSamples(dataset, N, perclass) 183 | local N = N or 8 184 | local perclass = perclass or 10 185 | local noise_inputs = torch.Tensor(N, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3]) 186 | local cond_inputs1 = torch.Tensor(N, opt.condDim1) 187 | local cond_inputs2 = torch.Tensor(N, opt.condDim2[1], opt.condDim2[2], opt.condDim2[3]) 188 | local gt_diff = torch.Tensor(N, opt.geometry[1], opt.geometry[2], opt.geometry[3]) 189 | local gt = torch.Tensor(N, 3, opt.fineSize, opt.fineSize) 190 | 191 | -- Generate samples 192 | noise_inputs:uniform(-1, 1) 193 | local class = 1 194 | local classes = {} 195 | for n = 1,N do 196 | classes[n] = cifar_classes[class] 197 | local rand 198 | local sample 199 | while true do 200 | rand = math.random(dataset:size()) 201 | sample = dataset[rand] 202 | local max, ind = torch.max(sample[2], 1) 203 | if ind[1] == class then 204 | break 205 | end 206 | end 207 | cond_inputs1[n] = sample[2]:clone() 208 | cond_inputs2[n] = sample[3]:clone() 209 | gt[n] = sample[4]:clone() 210 | gt_diff[n] = sample[1]:clone() 211 | if n % perclass == 0 then class = class + 1 end 212 | if class > #cifar_classes then class = 1 end 213 | end 214 | local samples = model_G:forward({noise_inputs, cond_inputs1, cond_inputs2}) 215 | local preds_D = model_D:forward({samples, cond_inputs1, cond_inputs2}) 216 | 217 | local to_plot = {} 218 | for i=1,N do 219 | local pred = torch.add(cond_inputs2[i]:float(), samples[i]:float()) 220 | to_plot[#to_plot+1] = gt[i]:float() 221 | to_plot[#to_plot+1] = pred 222 | to_plot[#to_plot+1] = cond_inputs2[i]:float() 223 | to_plot[#to_plot+1] = samples[i]:float() 224 | end 225 | 226 | return to_plot 227 | end 228 | 229 | -- training loop 230 | while true do 231 | -- train/test 232 | adversarial.train(trainData) 233 | adversarial.test(valData) 234 | 235 | sgdState_D.momentum = math.min(sgdState_D.momentum + 0.0008, 0.7) 236 | sgdState_D.learningRate = math.max(sgdState_D.learningRate / 1.00004, 0.000001) 237 | sgdState_G.momentum = math.min(sgdState_G.momentum + 0.0008, 0.7) 238 | sgdState_G.learningRate = math.max(sgdState_G.learningRate / 1.00004, 0.000001) 239 | 240 | -- plot errors 241 | if opt.plot and epoch and epoch % 1 == 0 then 242 | local to_plot = getSamples(valData, 16, 2) 243 | torch.setdefaulttensortype('torch.FloatTensor') 244 | 245 | trainLogger:style{['% mean class accuracy of D (train set)'] = '-'} 246 | testLogger:style{['% mean class accuracy of D (test set)'] = '-'} 247 | trainLogger:plot() 248 | testLogger:plot() 249 | 250 | disp.image(to_plot, {win=opt.window, width=600, title=opt.save}) 251 | if opt.gpu then 252 | torch.setdefaulttensortype('torch.CudaTensor') 253 | else 254 | torch.setdefaulttensortype('torch.FloatTensor') 255 | end 256 | end 257 | end 258 | -------------------------------------------------------------------------------- /cifar/train/adversarial.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'cunn' 4 | require 'optim' 5 | require 'pl' 6 | 7 | local adversarial = {} 8 | 9 | -- training function 10 | function adversarial.train(dataset, N) 11 | epoch = epoch or 1 12 | local N = N or dataset:size() 13 | local dataBatchSize = opt.batchSize / 2 14 | local time = sys.clock() 15 | 16 | -- do one epoch 17 | print('\n on training set:') 18 | print(" online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ' lr = ' .. sgdState_D.learningRate .. ', momentum = ' .. sgdState_D.momentum .. ']') 19 | for t = 1,N,dataBatchSize do 20 | 21 | local inputs = torch.Tensor(opt.batchSize, opt.geometry[1], opt.geometry[2], opt.geometry[3]) 22 | local targets = torch.Tensor(opt.batchSize) 23 | local noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim) 24 | 25 | ---------------------------------------------------------------------- 26 | -- create closure to evaluate f(X) and df/dX of discriminator 27 | local fevalD = function(x) 28 | collectgarbage() 29 | if x ~= parameters_D then -- get new parameters 30 | parameters_D:copy(x) 31 | end 32 | 33 | gradParameters_D:zero() -- reset gradients 34 | 35 | -- forward pass 36 | local outputs = model_D:forward(inputs) 37 | local f = criterion:forward(outputs, targets) 38 | 39 | -- backward pass 40 | local df_do = criterion:backward(outputs, targets) 41 | model_D:backward(inputs, df_do) 42 | 43 | -- penalties (L1 and L2): 44 | if opt.coefL1 ~= 0 or opt.coefL2 ~= 0 then 45 | local norm,sign= torch.norm,torch.sign 46 | -- Loss: 47 | f = f + opt.coefL1 * norm(parameters_D,1) 48 | f = f + opt.coefL2 * norm(parameters_D,2)^2/2 49 | -- Gradients: 50 | gradParameters_D:add( sign(parameters_D):mul(opt.coefL1) + parameters_D:clone():mul(opt.coefL2) ) 51 | end 52 | -- update confusion (add 1 since targets are binary) 53 | for i = 1,opt.batchSize do 54 | local c 55 | if outputs[i][1] > 0.5 then c = 2 else c = 1 end 56 | confusion:add(c, targets[i]+1) 57 | end 58 | 59 | return f,gradParameters_D 60 | end 61 | 62 | ---------------------------------------------------------------------- 63 | -- create closure to evaluate f(X) and df/dX of generator 64 | local fevalG = function(x) 65 | collectgarbage() 66 | if x ~= parameters_G then -- get new parameters 67 | parameters_G:copy(x) 68 | end 69 | 70 | gradParameters_G:zero() -- reset gradients 71 | 72 | -- forward pass 73 | local samples = model_G:forward(noise_inputs) 74 | local outputs = model_D:forward(samples) 75 | local f = criterion:forward(outputs, targets) 76 | 77 | -- backward pass 78 | local df_samples = criterion:backward(outputs, targets) 79 | model_D:backward(samples, df_samples) 80 | local df_do = model_D.modules[1].gradInput 81 | model_G:backward(noise_inputs, df_do) 82 | 83 | -- penalties (L1 and L2): 84 | if opt.coefL1 ~= 0 or opt.coefL2 ~= 0 then 85 | local norm,sign= torch.norm,torch.sign 86 | -- Loss: 87 | f = f + opt.coefL1 * norm(parameters_G,1) 88 | f = f + opt.coefL2 * norm(parameters_G,2)^2/2 89 | -- Gradients: 90 | gradParameters_G:add( sign(parameters_G):mul(opt.coefL1) + parameters_G:clone():mul(opt.coefL2) ) 91 | end 92 | 93 | return f,gradParameters_G 94 | end 95 | 96 | ---------------------------------------------------------------------- 97 | -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 98 | -- Get half a minibatch of real, half fake 99 | for k=1,opt.K do 100 | -- (1.1) Real data 101 | local k = 1 102 | for i = t,math.min(t+dataBatchSize-1,dataset:size()) do 103 | local idx = math.random(dataset:size()) 104 | local sample = dataset[idx] 105 | inputs[k] = sample[1]:clone() 106 | k = k + 1 107 | end 108 | targets[{{1,dataBatchSize}}]:fill(1) 109 | -- (1.2) Sampled data 110 | noise_inputs:uniform(-1, 1) 111 | local samples = model_G:forward(noise_inputs[{{dataBatchSize+1,opt.batchSize}}]) 112 | for i = 1, dataBatchSize do 113 | inputs[k] = samples[i]:clone() 114 | k = k + 1 115 | end 116 | targets[{{dataBatchSize+1,opt.batchSize}}]:fill(0) 117 | 118 | optim.sgd(fevalD, parameters_D, sgdState_D) 119 | 120 | end -- end for K 121 | 122 | ---------------------------------------------------------------------- 123 | -- (2) Update G network: maximize log(D(G(z))) 124 | noise_inputs:uniform(-1, 1) 125 | targets:fill(1) 126 | optim.sgd(fevalG, parameters_G, sgdState_G) 127 | 128 | -- display progress 129 | xlua.progress(t, dataset:size()) 130 | end -- end for loop over dataset 131 | 132 | -- time taken 133 | time = sys.clock() - time 134 | time = time / dataset:size() 135 | print(" time to learn 1 sample = " .. (time*1000) .. 'ms') 136 | 137 | -- print confusion matrix 138 | print(confusion) 139 | trainLogger:add{['% mean class accuracy (train set)'] = confusion.totalValid * 100} 140 | confusion:zero() 141 | 142 | -- save/log current net 143 | if epoch % opt.saveFreq == 0 then 144 | local filename = paths.concat(opt.save, 'adversarial.net') 145 | os.execute('mkdir -p ' .. sys.dirname(filename)) 146 | if paths.filep(filename) then 147 | os.execute('mv ' .. filename .. ' ' .. filename .. '.old') 148 | end 149 | print(' saving network to '..filename) 150 | torch.save(filename, {D = model_D, G = model_G, opt = opt}) 151 | end 152 | 153 | -- next epoch 154 | epoch = epoch + 1 155 | end 156 | 157 | -- test function 158 | function adversarial.test(dataset) 159 | local time = sys.clock() 160 | local N = N or dataset:size() 161 | 162 | print('\n on testing Set:') 163 | for t = 1,N,opt.batchSize do 164 | -- display progress 165 | xlua.progress(t, dataset:size()) 166 | 167 | ---------------------------------------------------------------------- 168 | --(1) Real data 169 | local inputs = torch.Tensor(opt.batchSize,opt.geometry[1],opt.geometry[2], opt.geometry[3]) 170 | local targets = torch.ones(opt.batchSize) 171 | local k = 1 172 | for i = t,math.min(t+opt.batchSize-1,dataset:size()) do 173 | local idx = math.random(dataset:size()) 174 | local sample = dataset[idx] 175 | local input = sample[1]:clone() 176 | inputs[k] = input 177 | k = k + 1 178 | end 179 | local preds = model_D:forward(inputs) -- get predictions from D 180 | -- add to confusion matrix 181 | for i = 1,opt.batchSize do 182 | local c 183 | if preds[i][1] > 0.5 then c = 2 else c = 1 end 184 | confusion:add(c, targets[i] + 1) 185 | end 186 | 187 | ---------------------------------------------------------------------- 188 | -- (2) Generated data (don't need this really, since no 'validation' generations) 189 | local noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim):uniform(-1, 1) 190 | local inputs = model_G:forward(noise_inputs) 191 | local targets = torch.zeros(opt.batchSize) 192 | local preds = model_D:forward(inputs) -- get predictions from D 193 | -- add to confusion matrix 194 | for i = 1,opt.batchSize do 195 | local c 196 | if preds[i][1] > 0.5 then c = 2 else c = 1 end 197 | confusion:add(c, targets[i] + 1) 198 | end 199 | end -- end loop over dataset 200 | 201 | -- timing 202 | time = sys.clock() - time 203 | time = time / dataset:size() 204 | print(" time to test 1 sample = " .. (time*1000) .. 'ms') 205 | 206 | -- print confusion matrix 207 | print(confusion) 208 | testLogger:add{['% mean class accuracy (test set)'] = confusion.totalValid * 100} 209 | confusion:zero() 210 | end 211 | 212 | return adversarial 213 | -------------------------------------------------------------------------------- /cifar/train/conditional_adversarial.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'optim' 3 | require 'pl' 4 | require 'paths' 5 | 6 | local adversarial = {} 7 | 8 | -- training function 9 | function adversarial.train(dataset, N) 10 | epoch = epoch or 1 11 | local N = N or dataset:size() 12 | local time = sys.clock() 13 | local dataBatchSize = opt.batchSize / 2 14 | 15 | local inputs = torch.Tensor(opt.batchSize, opt.geometry[1], opt.geometry[2], opt.geometry[3]) 16 | local targets = torch.Tensor(opt.batchSize) 17 | local noise_inputs 18 | if type(opt.noiseDim) == 'number' then 19 | noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim) 20 | else 21 | noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3]) 22 | end 23 | local cond_inputs 24 | if type(opt.condDim) == 'number' then 25 | cond_inputs = torch.Tensor(opt.batchSize, opt.condDim) 26 | else 27 | cond_inputs = torch.Tensor(opt.batchSize, opt.condDim[1], opt.condDim[2], opt.condDim[3]) 28 | end 29 | 30 | -- do one epoch 31 | print('\n on training set:') 32 | print(" online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ' lr = ' .. sgdState_D.learningRate .. ', momentum = ' .. sgdState_D.momentum .. ']') 33 | for t = 1,N,dataBatchSize*opt.K do 34 | 35 | 36 | 37 | ---------------------------------------------------------------------- 38 | -- create closure to evaluate f(X) and df/dX of discriminator 39 | local fevalD = function(x) 40 | collectgarbage() 41 | if x ~= parameters_D then -- get new parameters 42 | parameters_D:copy(x) 43 | end 44 | 45 | gradParameters_D:zero() -- reset gradients 46 | 47 | -- forward pass 48 | local outputs = model_D:forward({inputs, cond_inputs}) 49 | local f = criterion:forward(outputs, targets) 50 | 51 | -- backward pass 52 | local df_do = criterion:backward(outputs, targets) 53 | model_D:backward({inputs, cond_inputs}, df_do) 54 | 55 | -- penalties (L1 and L2): 56 | if opt.coefL1 ~= 0 or opt.coefL2 ~= 0 then 57 | local norm,sign= torch.norm,torch.sign 58 | -- Loss: 59 | f = f + opt.coefL1 * norm(parameters_D,1) 60 | f = f + opt.coefL2 * norm(parameters_D,2)^2/2 61 | -- Gradients: 62 | gradParameters_D:add( sign(parameters_D):mul(opt.coefL1) + parameters_D:clone():mul(opt.coefL2) ) 63 | end 64 | -- update confusion (add 1 since classes are binary) 65 | for i = 1,opt.batchSize do 66 | local c 67 | if outputs[i][1] > 0.5 then c = 2 else c = 1 end 68 | confusion:add(c, targets[i]+1) 69 | end 70 | 71 | return f,gradParameters_D 72 | end 73 | 74 | ---------------------------------------------------------------------- 75 | -- create closure to evaluate f(X) and df/dX of generator 76 | local fevalG = function(x) 77 | collectgarbage() 78 | if x ~= parameters_G then -- get new parameters 79 | parameters_G:copy(x) 80 | end 81 | 82 | gradParameters_G:zero() -- reset gradients 83 | -- debugger.enter() 84 | 85 | -- forward pass 86 | local samples = model_G:forward({noise_inputs, cond_inputs}) 87 | local outputs = model_D:forward({samples, cond_inputs}) 88 | local f = criterion:forward(outputs, targets) 89 | 90 | -- backward pass 91 | local df_samples = criterion:backward(outputs, targets) 92 | model_D:backward({samples, cond_inputs}, df_samples) 93 | local df_do = model_D.gradInput[1] 94 | model_G:backward({noise_inputs, cond_inputs}, df_do) 95 | 96 | -- penalties (L1 and L2): 97 | if opt.coefL1 ~= 0 or opt.coefL2 ~= 0 then 98 | local norm,sign= torch.norm,torch.sign 99 | -- Loss: 100 | f = f + opt.coefL1 * norm(parameters_D,1) 101 | f = f + opt.coefL2 * norm(parameters_D,2)^2/2 102 | -- Gradients: 103 | gradParameters_G:add( sign(parameters_G):mul(opt.coefL1) + parameters_G:clone():mul(opt.coefL2) ) 104 | end 105 | 106 | return f,gradParameters_G 107 | end 108 | 109 | ---------------------------------------------------------------------- 110 | -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 111 | -- Get half a minibatch of real, half fake 112 | for k=1,opt.K do 113 | -- (1.1) Real data 114 | local k = 1 115 | for i = t,math.min(t+dataBatchSize-1,dataset:size()) do 116 | -- load new sample 117 | local idx = math.random(dataset:size()) 118 | local sample = dataset[idx] 119 | inputs[k] = sample[1]:clone() 120 | cond_inputs[k] = sample[3]:clone() 121 | k = k + 1 122 | end 123 | targets[{{1,dataBatchSize}}]:fill(1) 124 | -- (1.2) Sampled data 125 | noise_inputs:uniform(-1, 1) 126 | for i = dataBatchSize+1,opt.batchSize do 127 | local idx = math.random(dataset:size()) 128 | local sample = dataset[idx] 129 | cond_inputs[i] = sample[3]:clone() 130 | end 131 | local samples = model_G:forward({noise_inputs[{{dataBatchSize+1,opt.batchSize}}], cond_inputs[{{dataBatchSize+1,opt.batchSize}}]}) 132 | for i = 1, dataBatchSize do 133 | inputs[k] = samples[i]:clone() 134 | k = k + 1 135 | end 136 | targets[{{dataBatchSize+1,opt.batchSize}}]:fill(0) 137 | 138 | optim.sgd(fevalD, parameters_D, sgdState_D) 139 | end -- end for K 140 | 141 | ---------------------------------------------------------------------- 142 | -- (2) Update G network: maximize log(D(G(z))) 143 | noise_inputs:uniform(-1, 1) 144 | for i = 1,opt.batchSize do 145 | local idx = math.random(dataset:size()) 146 | local sample = dataset[idx] 147 | cond_inputs[i] = sample[3]:clone() 148 | end 149 | targets:fill(1) 150 | optim.sgd(fevalG, parameters_G, sgdState_G) 151 | 152 | -- disp progress 153 | xlua.progress(t, N) 154 | end -- end for loop over dataset 155 | 156 | -- time taken 157 | time = sys.clock() - time 158 | time = time / dataset:size() 159 | print(" time to learn 1 sample = " .. (time*1000) .. 'ms') 160 | 161 | -- print confusion matrix 162 | print(confusion) 163 | trainLogger:add{['% mean class accuracy (train set)'] = confusion.totalValid * 100} 164 | confusion:zero() 165 | 166 | -- save/log current net 167 | if epoch % opt.saveFreq == 0 then 168 | local filename = paths.concat(opt.save, 'conditional_adversarial.net') 169 | os.execute('mkdir -p ' .. sys.dirname(filename)) 170 | if paths.filep(filename) then 171 | os.execute('mv ' .. filename .. ' ' .. filename .. '.old') 172 | end 173 | print(' saving network to '..filename) 174 | torch.save(filename, {D = model_D, G = model_G, E = model_E, opt = opt}) 175 | end 176 | 177 | -- next epoch 178 | epoch = epoch + 1 179 | end 180 | 181 | -- test function 182 | function adversarial.test(dataset, N) 183 | local time = sys.clock() 184 | local N = N or dataset:size() 185 | 186 | local inputs = torch.Tensor(opt.batchSize, opt.geometry[1], opt.geometry[2], opt.geometry[3]) 187 | local targets = torch.Tensor(opt.batchSize) 188 | local noise_inputs 189 | if type(opt.noiseDim) == 'number' then 190 | noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim) 191 | else 192 | noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3]) 193 | end 194 | local cond_inputs 195 | if type(opt.condDim) == 'number' then 196 | cond_inputs = torch.Tensor(opt.batchSize, opt.condDim) 197 | else 198 | cond_inputs = torch.Tensor(opt.batchSize, opt.condDim[1], opt.condDim[2], opt.condDim[3]) 199 | end 200 | 201 | print('\n on testing set:') 202 | for t = 1,N,opt.batchSize do 203 | -- display progress 204 | xlua.progress(t, N) 205 | 206 | ---------------------------------------------------------------------- 207 | -- (1) Real data 208 | local targets = torch.ones(opt.batchSize) 209 | local k = 1 210 | for i = t,math.min(t+opt.batchSize-1,dataset:size()) do 211 | local idx = math.random(dataset:size()) 212 | local sample = dataset[idx] 213 | inputs[k] = sample[1]:clone() 214 | cond_inputs[k] = sample[3]:clone() 215 | k = k + 1 216 | end 217 | local preds = model_D:forward({inputs, cond_inputs}) -- get predictions from D 218 | -- add to confusion matrix 219 | for i = 1,opt.batchSize do 220 | local c 221 | if preds[i][1] > 0.5 then c = 2 else c = 1 end 222 | confusion:add(c, targets[i] + 1) 223 | end 224 | 225 | ---------------------------------------------------------------------- 226 | -- (2) Generated data (don't need this really, since no 'validation' generations) 227 | noise_inputs:uniform(-1, 1) 228 | local c = 1 229 | for i = 1,opt.batchSize do 230 | sample = dataset[math.random(dataset:size())] 231 | cond_inputs[i] = sample[3]:clone() 232 | end 233 | local samples = model_G:forward({noise_inputs, cond_inputs}) 234 | local targets = torch.zeros(opt.batchSize) 235 | local preds = model_D:forward({samples, cond_inputs}) -- get predictions from D 236 | -- add to confusion matrix 237 | for i = 1,opt.batchSize do 238 | local c 239 | if preds[i][1] > 0.5 then c = 2 else c = 1 end 240 | confusion:add(c, targets[i] + 1) 241 | end 242 | end -- end loop over dataset 243 | 244 | -- timing 245 | time = sys.clock() - time 246 | time = time / dataset:size() 247 | print(" time to test 1 sample = " .. (time*1000) .. 'ms') 248 | 249 | -- print confusion matrix 250 | print(confusion) 251 | testLogger:add{['% mean class accuracy (test set)'] = confusion.totalValid * 100} 252 | confusion:zero() 253 | 254 | return cond_inputs 255 | end 256 | 257 | -- Unnormalized parzen window type estimate (used to track performance during training) 258 | -- Really just a nearest neighbours of ground truth to multiple generations 259 | function adversarial.approxParzen(dataset, nsamples, nneighbors) 260 | best_dist = best_dist or 1e10 261 | print(' evaluating approximate parzen ') 262 | local noise_inputs 263 | if type(opt.noiseDim) == 'number' then 264 | noise_inputs = torch.Tensor(nneighbors, opt.noiseDim) 265 | else 266 | noise_inputs = torch.Tensor(nneighbors, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3]) 267 | end 268 | local cond_inputs 269 | if type(opt.condDim) == 'number' then 270 | cond_inputs = torch.Tensor(nneighbors, opt.condDim) 271 | else 272 | cond_inputs = torch.Tensor(nneighbors, opt.condDim[1], opt.condDim[2], opt.condDim[3]) 273 | end 274 | local distances = torch.Tensor(nsamples) 275 | for n = 1,nsamples do 276 | xlua.progress(n, nsamples) 277 | local sample = dataset[math.random(dataset:size())] 278 | local cond_input = sample[3] 279 | local fine = sample[4]:type(torch.getdefaulttensortype()) 280 | noise_inputs:uniform(-1, 1) 281 | for i = 1,nneighbors do 282 | cond_inputs[i] = cond_input:clone() 283 | end 284 | neighbors = model_G:forward({noise_inputs, cond_inputs}) 285 | neighbors:add(cond_inputs) 286 | -- compute distance 287 | local dist = 1e10 288 | for i = 1,nneighbors do 289 | dist = math.min(torch.dist(neighbors[i], fine), dist) 290 | end 291 | distances[n] = dist 292 | end 293 | print('average || x_' .. opt.fineSize .. ' - G(x_' .. opt.coarseSize .. ') || = ' .. distances:mean()) 294 | 295 | -- save/log current net 296 | if distances:mean() < best_dist then 297 | best_dist = distances:mean() 298 | 299 | local filename = paths.concat(opt.save, 'conditional_adversarial.bestnet') 300 | os.execute('mkdir -p ' .. sys.dirname(filename)) 301 | if paths.filep(filename) then 302 | os.execute('mv ' .. filename .. ' ' .. filename .. '.old') 303 | end 304 | print(' saving network to '..filename) 305 | torch.save(filename, {D = model_D, G = model_G, E = model_E, opt = opt}) 306 | end 307 | return distances 308 | end 309 | 310 | return adversarial 311 | -------------------------------------------------------------------------------- /cifar/train/double_conditional_adversarial.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'optim' 3 | require 'pl' 4 | require 'paths' 5 | require 'image' 6 | 7 | local adversarial = {} 8 | 9 | -- training function 10 | function adversarial.train(dataset, N) 11 | epoch = epoch or 1 12 | local N = N or dataset:size() 13 | local time = sys.clock() 14 | local dataBatchSize = opt.batchSize / 2 15 | 16 | local inputs = torch.Tensor(opt.batchSize, opt.geometry[1], opt.geometry[2], opt.geometry[3]) 17 | local targets = torch.Tensor(opt.batchSize) 18 | local noise_inputs 19 | if type(opt.noiseDim) == 'number' then 20 | noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim) 21 | else 22 | noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3]) 23 | end 24 | local cond_inputs1 25 | local cond_inputs2 26 | if type(opt.condDim1) == 'number' then 27 | cond_inputs1 = torch.Tensor(opt.batchSize, opt.condDim1) 28 | else 29 | cond_inputs1 = torch.Tensor(opt.batchSize, opt.condDim1[1], opt.condDim1[2], opt.condDim1[3]) 30 | end 31 | if type(opt.condDim2) == 'number' then 32 | cond_inputs2 = torch.Tensor(opt.batchSize, opt.condDim2) 33 | else 34 | cond_inputs2 = torch.Tensor(opt.batchSize, opt.condDim2[1], opt.condDim2[2], opt.condDim2[3]) 35 | end 36 | 37 | -- do one epoch 38 | print('\n on training set:') 39 | print(" online epoch # " .. epoch .. ' [batchSize = ' .. opt.batchSize .. ' lr = ' .. sgdState_D.learningRate .. ', momentum = ' .. sgdState_D.momentum .. ']') 40 | for t = 1,N,dataBatchSize*opt.K do 41 | 42 | ---------------------------------------------------------------------- 43 | -- create closure to evaluate f(X) and df/dX of discriminator 44 | local fevalD = function(x) 45 | collectgarbage() 46 | if x ~= parameters_D then -- get new parameters 47 | parameters_D:copy(x) 48 | end 49 | 50 | gradParameters_D:zero() -- reset gradients 51 | 52 | -- forward pass 53 | local outputs = model_D:forward({inputs, cond_inputs1, cond_inputs2}) 54 | local f = criterion:forward(outputs, targets) 55 | 56 | -- backward pass 57 | local df_do = criterion:backward(outputs, targets) 58 | model_D:backward({inputs, cond_inputs1, cond_inputs2}, df_do) 59 | 60 | -- penalties (L1 and L2): 61 | if opt.coefL1 ~= 0 or opt.coefL2 ~= 0 then 62 | local norm,sign= torch.norm,torch.sign 63 | -- Loss: 64 | f = f + opt.coefL1 * norm(parameters_D,1) 65 | f = f + opt.coefL2 * norm(parameters_D,2)^2/2 66 | -- Gradients: 67 | gradParameters_D:add( sign(parameters_D):mul(opt.coefL1) + parameters_D:clone():mul(opt.coefL2) ) 68 | end 69 | -- update confusion (add 1 since classes are binary) 70 | for i = 1,opt.batchSize do 71 | local c 72 | if outputs[i][1] > 0.5 then c = 2 else c = 1 end 73 | confusion:add(c, targets[i]+1) 74 | end 75 | 76 | return f,gradParameters_D 77 | end 78 | 79 | ---------------------------------------------------------------------- 80 | -- create closure to evaluate f(X) and df/dX of generator 81 | local fevalG = function(x) 82 | collectgarbage() 83 | if x ~= parameters_G then -- get new parameters 84 | parameters_G:copy(x) 85 | end 86 | 87 | gradParameters_G:zero() -- reset gradients 88 | 89 | -- forward pass 90 | local samples = model_G:forward({noise_inputs, cond_inputs1, cond_inputs2}) 91 | local outputs = model_D:forward({samples, cond_inputs1, cond_inputs2}) 92 | local f = criterion:forward(outputs, targets) 93 | 94 | -- backward pass 95 | local df_samples = criterion:backward(outputs, targets) 96 | model_D:backward({samples, cond_inputs1, cond_inputs2}, df_samples) 97 | local df_do = model_D.gradInput[1] 98 | model_G:backward({noise_inputs, cond_inputs1, cond_inputs2}, df_do) 99 | 100 | -- penalties (L1 and L2): 101 | if opt.coefL1 ~= 0 or opt.coefL2 ~= 0 then 102 | local norm,sign= torch.norm,torch.sign 103 | -- Loss: 104 | f = f + opt.coefL1 * norm(parameters_D,1) 105 | f = f + opt.coefL2 * norm(parameters_D,2)^2/2 106 | -- Gradients: 107 | gradParameters_G:add( sign(parameters_G):mul(opt.coefL1) + parameters_G:clone():mul(opt.coefL2) ) 108 | end 109 | 110 | return f,gradParameters_G 111 | end 112 | 113 | ---------------------------------------------------------------------- 114 | -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 115 | -- Get half a minibatch of real, half fake 116 | for k=1,opt.K do 117 | -- (1.1) Real data 118 | local k = 1 119 | for i = t,math.min(t+dataBatchSize-1,dataset:size()) do 120 | -- load new sample 121 | local idx = math.random(dataset:size()) 122 | local sample = dataset[idx] 123 | inputs[k] = sample[1]:clone() 124 | cond_inputs1[k] = sample[2]:clone() 125 | cond_inputs2[k] = sample[3]:clone() 126 | k = k + 1 127 | end 128 | targets[{{1,dataBatchSize}}]:fill(1) 129 | -- (1.2) Sampled data 130 | noise_inputs:uniform(-1, 1) 131 | for i = dataBatchSize+1,opt.batchSize do 132 | local idx = math.random(dataset:size()) 133 | local sample = dataset[idx] 134 | cond_inputs1[i] = sample[2]:clone() 135 | cond_inputs2[i] = sample[3]:clone() 136 | end 137 | local samples = model_G:forward({noise_inputs[{{dataBatchSize+1,opt.batchSize}}], cond_inputs1[{{dataBatchSize+1,opt.batchSize}}], cond_inputs2[{{dataBatchSize+1,opt.batchSize}}]}) 138 | for i = 1, dataBatchSize do 139 | inputs[k] = samples[i]:clone() 140 | k = k + 1 141 | end 142 | targets[{{dataBatchSize+1,opt.batchSize}}]:fill(0) 143 | 144 | optim.sgd(fevalD, parameters_D, sgdState_D) 145 | end -- end for K 146 | 147 | ---------------------------------------------------------------------- 148 | -- (2) Update G network: maximize log(D(G(z))) 149 | noise_inputs:uniform(-1, 1) 150 | for i = 1,opt.batchSize do 151 | local idx = math.random(dataset:size()) 152 | local sample = dataset[idx] 153 | cond_inputs1[i] = sample[2]:clone() 154 | cond_inputs2[i] = sample[3]:clone() 155 | end 156 | targets:fill(1) 157 | optim.sgd(fevalG, parameters_G, sgdState_G) 158 | 159 | -- disp progress 160 | xlua.progress(t, N) 161 | end -- end for loop over dataset 162 | 163 | -- time taken 164 | time = sys.clock() - time 165 | time = time / dataset:size() 166 | print(" time to learn 1 sample = " .. (time*1000) .. 'ms') 167 | 168 | -- print confusion matrix 169 | print(confusion) 170 | trainLogger:add{['% mean class accuracy (train set)'] = confusion.totalValid * 100} 171 | confusion:zero() 172 | 173 | -- save/log current net 174 | if epoch % opt.saveFreq == 0 then 175 | local filename = paths.concat(opt.save, 'conditional_adversarial.net') 176 | os.execute('mkdir -p ' .. sys.dirname(filename)) 177 | if paths.filep(filename) then 178 | os.execute('mv ' .. filename .. ' ' .. filename .. '.old') 179 | end 180 | print(' saving network to '..filename) 181 | torch.save(filename, {D = model_D, G = model_G, E = model_E, opt = opt}) 182 | end 183 | 184 | -- next epoch 185 | epoch = epoch + 1 186 | end 187 | 188 | -- test function 189 | function adversarial.test(dataset, N) 190 | local time = sys.clock() 191 | local N = N or dataset:size() 192 | 193 | local inputs = torch.Tensor(opt.batchSize, opt.geometry[1], opt.geometry[2], opt.geometry[3]) 194 | local targets = torch.Tensor(opt.batchSize) 195 | local noise_inputs 196 | if type(opt.noiseDim) == 'number' then 197 | noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim) 198 | else 199 | noise_inputs = torch.Tensor(opt.batchSize, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3]) 200 | end 201 | local cond_inputs1 202 | local cond_inputs2 203 | if type(opt.condDim1) == 'number' then 204 | cond_inputs1 = torch.Tensor(opt.batchSize, opt.condDim1) 205 | else 206 | cond_inputs1 = torch.Tensor(opt.batchSize, opt.condDim1[1], opt.condDim1[2], opt.condDim1[3]) 207 | end 208 | if type(opt.condDim2) == 'number' then 209 | cond_inputs2 = torch.Tensor(opt.batchSize, opt.condDim2) 210 | else 211 | cond_inputs2 = torch.Tensor(opt.batchSize, opt.condDim2[1], opt.condDim2[2], opt.condDim2[3]) 212 | end 213 | 214 | 215 | print('\n on testing set:') 216 | for t = 1,N,opt.batchSize do 217 | -- display progress 218 | xlua.progress(t, N) 219 | 220 | ---------------------------------------------------------------------- 221 | -- (1) Real data 222 | local targets = torch.ones(opt.batchSize) 223 | local k = 1 224 | for i = t,math.min(t+opt.batchSize-1,dataset:size()) do 225 | local idx = math.random(dataset:size()) 226 | local sample = dataset[idx] 227 | inputs[k] = sample[1]:clone() 228 | cond_inputs1[k] = sample[2]:clone() 229 | cond_inputs2[k] = sample[3]:clone() 230 | k = k + 1 231 | end 232 | local preds = model_D:forward({inputs, cond_inputs1, cond_inputs2}) -- get predictions from D 233 | -- add to confusion matrix 234 | for i = 1,opt.batchSize do 235 | local c 236 | if preds[i][1] > 0.5 then c = 2 else c = 1 end 237 | confusion:add(c, targets[i] + 1) 238 | end 239 | 240 | ---------------------------------------------------------------------- 241 | -- (2) Generated data (don't need this really, since no 'validation' generations) 242 | noise_inputs:uniform(-1, 1) 243 | local c = 1 244 | for i = 1,opt.batchSize do 245 | sample = dataset[math.random(dataset:size())] 246 | cond_inputs1[i] = sample[2]:clone() 247 | cond_inputs2[i] = sample[3]:clone() 248 | end 249 | local samples = model_G:forward({noise_inputs, cond_inputs1, cond_inputs2}) 250 | local targets = torch.zeros(opt.batchSize) 251 | local preds = model_D:forward({samples, cond_inputs1, cond_inputs2}) -- get predictions from D 252 | -- add to confusion matrix 253 | for i = 1,opt.batchSize do 254 | local c 255 | if preds[i][1] > 0.5 then c = 2 else c = 1 end 256 | confusion:add(c, targets[i] + 1) 257 | end 258 | end -- end loop over dataset 259 | 260 | -- timing 261 | time = sys.clock() - time 262 | time = time / dataset:size() 263 | print(" time to test 1 sample = " .. (time*1000) .. 'ms') 264 | 265 | -- print confusion matrix 266 | print(confusion) 267 | testLogger:add{['% mean class accuracy (test set)'] = confusion.totalValid * 100} 268 | confusion:zero() 269 | 270 | local samples = model_G.output 271 | local to_plot = {} 272 | for i=1,opt.batchSize do 273 | to_plot[i] = samples[i]:float() 274 | end 275 | local fname = paths.concat(opt.save, 'epoch-' .. epoch .. '.png') 276 | torch.setdefaulttensortype('torch.FloatTensor') 277 | image.save(fname, image.toDisplayTensor{input=to_plot, scaleeach=true}) 278 | if opt.gpu then 279 | torch.setdefaulttensortype('torch.CudaTensor') 280 | else 281 | torch.setdefaulttensortype('torch.FloatTensor') 282 | end 283 | return cond_inputs 284 | end 285 | 286 | return adversarial 287 | -------------------------------------------------------------------------------- /cifar/utils/image.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'image' 3 | 4 | local img = {} 5 | 6 | function img.normalize(data, mean_, std_) 7 | local mean = mean_ or data:mean(1) 8 | local std = std_ or data:std(1, true) 9 | local eps = 1e-7 10 | for i=1,data:size(1) do 11 | data[i]:add(-1, mean) 12 | data[i]:cdiv(std + eps) 13 | end 14 | return mean, std 15 | end 16 | 17 | function img.normalizeGlobal(data, mean_, std_) 18 | local std = std_ or data:std() 19 | local mean = mean_ or data:mean() 20 | data:add(-mean) 21 | data:mul(1/std) 22 | return mean, std 23 | end 24 | 25 | function img.contrastNormalize(data, new_min, new_max) 26 | local old_max = data:max(1) 27 | local old_min = data:min(1) 28 | local eps = 1e-7 29 | for i=1,data:size(1) do 30 | data[i]:add(-1, old_min) 31 | data[i]:mul(new_max - new_min) 32 | data[i]:cdiv(old_max - old_min + eps) 33 | data[i]:add(new_min) 34 | end 35 | end 36 | 37 | function img.flip(data, labels) 38 | local n = data:size(1) 39 | local N = n*2 40 | local new_data = torch.Tensor(N, data:size(2), data:size(3), data:size(4)):typeAs(data) 41 | local new_labels = torch.Tensor(N) 42 | new_data[{{1,n}}] = data 43 | new_labels[{{1,n}}] = labels:clone() 44 | new_labels[{{n+1,N}}] = labels:clone() 45 | for i = n+1,N do 46 | new_data[i] = image.hflip(data[i-n]) 47 | end 48 | local rp = torch.LongTensor(N) 49 | rp:randperm(N) 50 | return new_data:index(1, rp), new_labels:index(1, rp) 51 | end 52 | 53 | function img.translate(data, w, labels) 54 | local n = data:size(1) 55 | local N = n*5 56 | local ow = data:size(3) 57 | local new_data = torch.Tensor(N, data:size(2), w, w):typeAs(data) 58 | local new_labels = torch.Tensor(N) 59 | local d = ow - w + 1 60 | local m1 = (ow - w) / 2 + 1 61 | local m2 = ow - ((ow - w) / 2) 62 | local x1 = {1, d, 1, d, m1} 63 | local x2 = {w, ow, w, ow, m2} 64 | local y1 = {1, 1, d, d, m1} 65 | local y2 = {w, w, ow, ow, m2} 66 | local k = 1 67 | for i = 1,n do 68 | for j = 1,5 do 69 | new_data[k] = data[{ i, {}, {y1[j], y2[j]}, {x1[j], x2[j]} }]:clone() 70 | new_labels[k] = labels[i] 71 | k = k + 1 72 | end 73 | end 74 | local rp = torch.LongTensor(N) 75 | rp:randperm(N) 76 | return new_data:index(1, rp), new_labels:index(1, rp) 77 | end 78 | 79 | return img 80 | -------------------------------------------------------------------------------- /lsun/README.md: -------------------------------------------------------------------------------- 1 | 2 | To train your generators using the LSUN database, run the command: 3 | 4 | ``` 5 | DATA_ROOT=[path-to-lsun] th main.lua --dataset lsun 6 | ``` 7 | 8 | 9 | For imagenet, you'll have to [preprocess the dataset as described here.](https://github.com/soumith/imagenet-multiGPU.torch#data-processing) 10 | 11 | To train your generators using the Imagenet database for the first time (for the caches to build up), run the command: 12 | 13 | ``` 14 | DATA_ROOT=[path-to-imagenet] th main.lua --dataset imagenet --nDonkeys 0 15 | ``` 16 | 17 | Subsequent runs can activate the multi-threaded data loader by changing the 18 | command-line option to `--nDonkeys 4` (for 4 threads for example) 19 | 20 | 21 | For more command-line options: 22 | 23 | ``` 24 | th main.lua --help 25 | 26 | 27 | --dataset (default "imagenet") imagenet | lsun 28 | --model (default "large") large | small | autogen 29 | -s,--save (default "imgslogs") subdirectory to save logs 30 | --saveFreq (default 10) save every saveFreq epochs 31 | -n,--network (default "") reload pretrained network 32 | -p,--plot plot while training 33 | -r,--learningRateD (default 0.01) learning rate, for SGD only 34 | --learningRateG (default 0.01) learning rate, for SGD only 35 | -b,--batchSize (default 128) batch size 36 | -m,--momentum (default 0) momentum, for SGD only 37 | -w, --window (default 3) windsow id of sample image 38 | --nDonkeys (default 10) number of data loading threads 39 | --cache (default "cache") folder to cache metadata 40 | --epochSize (default 5000) number of samples per epoch 41 | --nEpochs (default 25) 42 | --coarseSize (default 16) 43 | --scaleUp (default 2) How much to upscale coarseSize 44 | --archgen (default 1) 45 | --scratch (default 0) 46 | --forceDonkeys (default 0) 47 | 48 | ``` 49 | -------------------------------------------------------------------------------- /lsun/data.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local Threads = require 'threads' 11 | 12 | local donkey_file 13 | if opt.dataset == 'imagenet' then 14 | donkey_file = 'donkey_imagenet.lua' 15 | elseif opt.dataset == 'lsun' then 16 | donkey_file = 'donkey_lsun.lua' 17 | -- lmdb complains beyond 6 donkeys. wtf. 18 | if opt.nDonkeys > 6 and opt.forceDonkeys == 0 then opt.nDonkeys = 6 end 19 | else 20 | error('Unknown dataset: ', opt.dataset) 21 | end 22 | 23 | do -- start K datathreads (donkeys) 24 | if opt.nDonkeys > 0 then 25 | local options = opt -- make an upvalue to serialize over to donkey threads 26 | donkeys = Threads( 27 | opt.nDonkeys, 28 | function() 29 | require 'torch' 30 | end, 31 | function(idx) 32 | opt = options -- pass to all donkeys via upvalue 33 | tid = idx 34 | local seed = opt.manualSeed + idx 35 | torch.manualSeed(seed) 36 | print(string.format('Starting donkey with id: %d seed: %d', tid, seed)) 37 | paths.dofile(donkey_file) 38 | end 39 | ); 40 | else -- single threaded data loading. useful for debugging 41 | paths.dofile(donkey_file) 42 | donkeys = {} 43 | function donkeys:addjob(f1, f2) f2(f1()) end 44 | function donkeys:synchronize() end 45 | end 46 | end 47 | 48 | os.execute('mkdir -p '.. opt.save) 49 | 50 | nTest = 0 51 | donkeys:addjob(function() return testLoader:size() end, function(c) nTest = c end) 52 | donkeys:synchronize() 53 | assert(nTest > 0, "Failed to get nTest") 54 | print('nTest: ', nTest) 55 | 56 | 57 | function sanitize(net) 58 | local list = net:listModules() 59 | for _,val in ipairs(list) do 60 | for name,field in pairs(val) do 61 | if torch.type(field) == 'cdata' then val[name] = nil end 62 | if name == 'homeGradBuffers' then val[name] = nil end 63 | if name == 'input_gpu' then val['input_gpu'] = {} end 64 | if name == 'gradOutput_gpu' then val['gradOutput_gpu'] = {} end 65 | if name == 'gradInput_gpu' then val['gradInput_gpu'] = {} end 66 | if (name == 'output' or name == 'gradInput') then 67 | if torch.isTensor(val[name]) then 68 | val[name] = field.new() 69 | end 70 | end 71 | if name == 'buffer' or name == 'buffer2' or name == 'normalized' 72 | or name == 'centered' or name == 'addBuffer' then 73 | val[name] = nil 74 | end 75 | end 76 | end 77 | return net 78 | end 79 | 80 | function merge_table(t1, t2) 81 | local t = {} 82 | for k,v in pairs(t2) do 83 | t[k] = v 84 | end 85 | for k,v in pairs(t1) do 86 | t[k] = v 87 | end 88 | return t 89 | end 90 | -------------------------------------------------------------------------------- /lsun/donkey_imagenet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'image' 11 | paths.dofile('imagenet.lua') 12 | 13 | -- This file contains the data-loading logic and details. 14 | -- It is run by each data-loader thread. 15 | ------------------------------------------ 16 | -------- COMMON CACHES and PATHS 17 | -- a cache file of the training metadata (if doesnt exist, will be created) 18 | local cache = "cache" 19 | os.execute('mkdir -p cache') 20 | local trainCache = paths.concat(cache, 'trainCache.t7') 21 | local testCache = paths.concat(cache, 'testCache.t7') 22 | local meanstdCache = paths.concat(cache, 'meanstdCache.t7') 23 | 24 | -- Check for existence of opt.data 25 | opt.data = os.getenv('DATA_ROOT') or '/data/local/imagenet-fetch/256' 26 | -------------------------------------------------------------------------------------------- 27 | if not os.execute('cd ' .. opt.data) then 28 | error(("could not chdir to '%s'"):format(opt.data)) 29 | end 30 | 31 | local loadSize = {3, opt.loadSize} 32 | local sampleSize = {3, opt.fineSize} 33 | 34 | local function loadImage(path) 35 | local input = image.load(path, 3, 'float') 36 | -- find the smaller dimension, and resize it to loadSize[2] (while keeping aspect ratio) 37 | local iW = input:size(3) 38 | local iH = input:size(2) 39 | if iW < iH then 40 | input = image.scale(input, loadSize[2], loadSize[2] * iH / iW) 41 | else 42 | input = image.scale(input, loadSize[2] * iW / iH, loadSize[2]) 43 | end 44 | return input 45 | end 46 | 47 | function makeData(fine, label) 48 | local diff = fine.new():resizeAs(fine) 49 | local coarse = fine.new():resizeAs(fine) 50 | local mode = 'bilinear' 51 | -- workaround for bug: https://github.com/torch/image/issues/73 52 | if opt.coarseSize == 1 then mode = 'simple' end 53 | for i=1,opt.batchSize do 54 | local tmp = image.scale(fine[i], opt.coarseSize, opt.coarseSize) 55 | image.scale(coarse[i], tmp, mode) 56 | end 57 | torch.add(diff, fine, -1, coarse) 58 | return {diff, label, coarse, fine} 59 | end 60 | 61 | -- channel-wise mean and std. Calculate or load them from disk later in the script. 62 | local mean,std 63 | -------------------------------------------------------------------------------- 64 | -- Hooks that are used for each image that is loaded 65 | 66 | -- function to load the image, jitter it appropriately (random crops etc.) 67 | local trainHook = function(self, path) 68 | collectgarbage() 69 | local input = loadImage(path) 70 | local iW = input:size(3) 71 | local iH = input:size(2) 72 | 73 | -- do random crop 74 | local oW = sampleSize[2]; 75 | local oH = sampleSize[2] 76 | local h1 = math.ceil(torch.uniform(1e-2, iH-oH)) 77 | local w1 = math.ceil(torch.uniform(1e-2, iW-oW)) 78 | local out = image.crop(input, w1, h1, w1 + oW, h1 + oH) 79 | assert(out:size(2) == oW) 80 | assert(out:size(3) == oH) 81 | -- do hflip with probability 0.5 82 | if torch.uniform() > 0.5 then out = image.hflip(out); end 83 | -- mean/std 84 | for i=1,3 do -- channels 85 | if mean then out[{{i},{},{}}]:add(-mean[i]) end 86 | if std then out[{{i},{},{}}]:div(std[i]) end 87 | end 88 | return out 89 | end 90 | 91 | local testHook = function(self, path) 92 | collectgarbage() 93 | local input = loadImage(path) 94 | local oH = sampleSize[2] 95 | local oW = sampleSize[2]; 96 | local iW = input:size(3) 97 | local iH = input:size(2) 98 | local w1 = math.ceil((iW-oW)/2) 99 | local h1 = math.ceil((iH-oH)/2) 100 | local out = image.crop(input, w1, h1, w1+oW, h1+oW) -- center patch 101 | -- mean/std 102 | for i=1,3 do -- channels 103 | if mean then out[{{i},{},{}}]:add(-mean[i]) end 104 | if std then out[{{i},{},{}}]:div(std[i]) end 105 | end 106 | return out 107 | end 108 | 109 | -------------------------------------- 110 | -- trainLoader 111 | if paths.filep(trainCache) then 112 | print('Loading train metadata from cache') 113 | trainLoader = torch.load(trainCache) 114 | trainLoader.sampleHookTrain = trainHook 115 | trainLoader.loadSize = {3, opt.loadSize, opt.loadSize} 116 | trainLoader.sampleSize = {3, sampleSize[2], sampleSize[2]} 117 | else 118 | print('Creating train metadata') 119 | trainLoader = dataLoader{ 120 | paths = {paths.concat(opt.data, 'train')}, 121 | loadSize = {3, loadSize[2], loadSize[2]}, 122 | sampleSize = {3, sampleSize[2], sampleSize[2]}, 123 | split = 100, 124 | verbose = true 125 | } 126 | torch.save(trainCache, trainLoader) 127 | trainLoader.sampleHookTrain = trainHook 128 | end 129 | collectgarbage() 130 | 131 | -- do some sanity checks on trainLoader 132 | do 133 | local class = trainLoader.imageClass 134 | local nClasses = #trainLoader.classes 135 | assert(class:max() <= nClasses, "class logic has error") 136 | assert(class:min() >= 1, "class logic has error") 137 | 138 | end 139 | 140 | -- testLoader 141 | if paths.filep(testCache) then 142 | print('Loading test metadata from cache') 143 | testLoader = torch.load(testCache) 144 | testLoader.sampleHookTest = testHook 145 | testLoader.loadSize = {3, opt.loadSize, opt.loadSize} 146 | testLoader.sampleSize = {3, sampleSize[2], sampleSize[2]} 147 | else 148 | print('Creating test metadata') 149 | testLoader = dataLoader{ 150 | paths = {paths.concat(opt.data, 'val')}, 151 | loadSize = {3, loadSize[2], loadSize[2]}, 152 | sampleSize = {3, sampleSize[2], sampleSize[2]}, 153 | split = 0, 154 | verbose = true, 155 | forceClasses = trainLoader.classes -- force consistent class indices between trainLoader and testLoader 156 | } 157 | torch.save(testCache, testLoader) 158 | testLoader.sampleHookTest = testHook 159 | end 160 | collectgarbage() 161 | ----------------------------------------- 162 | 163 | -- Estimate the per-channel mean/std (so that the loaders can normalize appropriately) 164 | if paths.filep(meanstdCache) then 165 | local meanstd = torch.load(meanstdCache) 166 | mean = meanstd.mean 167 | std = meanstd.std 168 | print('Loaded mean and std from cache.') 169 | else 170 | local tm = torch.Timer() 171 | local nSamples = 10000 172 | print('Estimating the mean (per-channel, shared for all pixels) over ' .. nSamples .. ' randomly sampled training images') 173 | local meanEstimate = {0,0,0} 174 | for i=1,nSamples do 175 | local img = trainLoader:sample(1)[1] 176 | for j=1,3 do 177 | meanEstimate[j] = meanEstimate[j] + img[j]:mean() 178 | end 179 | end 180 | for j=1,3 do 181 | meanEstimate[j] = meanEstimate[j] / nSamples 182 | end 183 | mean = meanEstimate 184 | 185 | print('Estimating the std (per-channel, shared for all pixels) over ' .. nSamples .. ' randomly sampled training images') 186 | local stdEstimate = {0,0,0} 187 | for i=1,nSamples do 188 | local img = trainLoader:sample(1)[1] 189 | for j=1,3 do 190 | stdEstimate[j] = stdEstimate[j] + img[j]:std() 191 | end 192 | end 193 | for j=1,3 do 194 | stdEstimate[j] = stdEstimate[j] / nSamples 195 | end 196 | std = stdEstimate 197 | 198 | local cache = {} 199 | cache.mean = mean 200 | cache.std = std 201 | torch.save(meanstdCache, cache) 202 | print('Time to estimate:', tm:time().real) 203 | end 204 | -------------------------------------------------------------------------------- /lsun/donkey_lsun.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'image' 11 | tds=require 'tds' 12 | require 'lmdb' 13 | 14 | -- This file contains the data-loading logic and details. 15 | -- It is run by each data-loader thread. 16 | ------------------------------------------ 17 | -------- COMMON CACHES and PATHS 18 | classes = {'bedroom', 'bridge', 'church_outdoor', 'classroom', 'conference_room', 19 | 'dining_room', 'kitchen', 'living_room', 'restaurant', 'tower'} 20 | -- classes = {'church_outdoor'} 21 | table.sort(classes) 22 | print('Classes:') 23 | for k,v in pairs(classes) do 24 | print(k, v) 25 | end 26 | 27 | -- Check for existence of opt.data 28 | opt.data = os.getenv('DATA_ROOT') or os.getenv('HOME') .. '/local/lsun' 29 | if not os.execute('cd ' .. opt.data) then 30 | error(("could not chdir to '%s'"):format(opt.data)) 31 | end 32 | 33 | local meanstdCache = 'meanstdCache_' 34 | for i=1,#classes do -- if you change classes to be a different subset, recompute meanstd 35 | meanstdCache = meanstdCache .. classes[i] .. '_' 36 | end 37 | meanstdCache = paths.concat(opt.data, meanstdCache .. '.t7') 38 | trainPath = paths.concat(opt.data, 'train') 39 | valPath = paths.concat(opt.data, 'val') 40 | 41 | ----------------------------------------------------------------------------------------- 42 | if not os.execute('cd ' .. opt.data) then 43 | error(("could not chdir to '%s'"):format(opt.data)) 44 | end 45 | 46 | local loadSize = {3, opt.loadSize} 47 | local sampleSize = {3, opt.fineSize} 48 | 49 | local function loadImage(blob) 50 | local input = image.decompress(blob, 3, 'float') 51 | -- find the smaller dimension, and resize it to loadSize[2] (while keeping aspect ratio) 52 | local iW = input:size(3) 53 | local iH = input:size(2) 54 | if iW < iH then 55 | input = image.scale(input, loadSize[2], loadSize[2] * iH / iW) 56 | else 57 | input = image.scale(input, loadSize[2] * iW / iH, loadSize[2]) 58 | end 59 | return input 60 | end 61 | 62 | function makeData(fine, label) 63 | local diff = fine.new():resizeAs(fine) 64 | local coarse = fine.new():resizeAs(fine) 65 | local mode = 'bilinear' 66 | -- workaround for bug: https://github.com/torch/image/issues/73 67 | if opt.coarseSize == 1 then mode = 'simple' end 68 | for i=1,opt.batchSize do 69 | local tmp = image.scale(fine[i], opt.coarseSize, opt.coarseSize) 70 | image.scale(coarse[i], tmp, mode) 71 | end 72 | torch.add(diff, fine, -1, coarse) 73 | return {diff, label, coarse, fine} 74 | end 75 | 76 | -- channel-wise mean and std. Calculate or load them from disk later in the script. 77 | local mean,std 78 | -------------------------------------------------------------------------------- 79 | -- Hooks that are used for each image that is loaded 80 | 81 | -- function to load the image, jitter it appropriately (random crops etc.) 82 | local trainHook = function(path) 83 | collectgarbage() 84 | local input = loadImage(path) 85 | local iW = input:size(3) 86 | local iH = input:size(2) 87 | 88 | -- do random crop 89 | local oW = sampleSize[2]; 90 | local oH = sampleSize[2] 91 | local h1 = math.ceil(torch.uniform(1e-2, iH-oH)) 92 | local w1 = math.ceil(torch.uniform(1e-2, iW-oW)) 93 | local out = image.crop(input, w1, h1, w1 + oW, h1 + oH) 94 | assert(out:size(2) == oW) 95 | assert(out:size(3) == oH) 96 | -- do hflip with probability 0.5 97 | if torch.uniform() > 0.5 then out = image.hflip(out); end 98 | -- mean/std 99 | for i=1,3 do -- channels 100 | if mean then out[{{i},{},{}}]:add(-mean[i]) end 101 | if std then out[{{i},{},{}}]:div(std[i]) end 102 | end 103 | return out 104 | end 105 | 106 | local testHook = function(path) 107 | collectgarbage() 108 | local input = loadImage(path) 109 | local oH = sampleSize[2] 110 | local oW = sampleSize[2]; 111 | local iW = input:size(3) 112 | local iH = input:size(2) 113 | local w1 = math.ceil((iW-oW)/2) 114 | local h1 = math.ceil((iH-oH)/2) 115 | local out = image.crop(input, w1, h1, w1+oW, h1+oW) -- center patch 116 | -- mean/std 117 | for i=1,3 do -- channels 118 | if mean then out[{{i},{},{}}]:add(-mean[i]) end 119 | if std then out[{{i},{},{}}]:div(std[i]) end 120 | end 121 | return out 122 | end 123 | -------------------------------------- 124 | -- trainLoader 125 | print('initializing train loader') 126 | trainLoader = {} 127 | trainLoader.classes = classes 128 | trainLoader.indices = {} 129 | trainLoader.db = {} 130 | trainLoader.db_reader = {} 131 | for i=1,#classes do 132 | print('initializing: ', classes[i]) 133 | trainLoader.indices[i] = torch.load(paths.concat(trainPath, classes[i] .. '_train_lmdb.t7')) 134 | trainLoader.db[i] = lmdb.env{Path=paths.concat(trainPath, classes[i] .. '_train_lmdb'), 135 | RDONLY=true, NOLOCK=true, NOTLS=true, NOSYNC=true, NOMETASYNC=true, 136 | MaxReaders=20, MaxDBs=20} 137 | trainLoader.db[i]:open() 138 | trainLoader.db_reader[i] = trainLoader.db[i]:txn(true) 139 | end 140 | 141 | function trainLoader:sample(quantity) 142 | local data = torch.Tensor(quantity, sampleSize[1], sampleSize[2], sampleSize[2]) 143 | local label = torch.Tensor(quantity) 144 | for i=1, quantity do 145 | local class = torch.random(1, #self.classes) 146 | local index = torch.random(1, #self.indices[class]) 147 | local imgblob = self.db_reader[class]:getData(trainLoader.indices[class][index]) 148 | local out = trainHook(imgblob) 149 | data[i]:copy(out) 150 | label[i] = class 151 | end 152 | return data, label 153 | end 154 | 155 | -- testLoader 156 | print('initializing test loader') 157 | testLoader = {} 158 | testLoader.classes = classes 159 | testLoader.indices = {} 160 | testLoader.indicesAllClass = tds.hash() 161 | testLoader.indicesAllClassIndex = tds.hash() 162 | testLoader.db = {} 163 | testLoader.db_reader = {} 164 | for i=1,#classes do 165 | testLoader.indices[i] = torch.load(paths.concat(valPath, classes[i] .. '_val_lmdb.t7')) 166 | for j=1,#testLoader.indices[i] do 167 | testLoader.indicesAllClass[#testLoader.indicesAllClass + 1] = i 168 | testLoader.indicesAllClassIndex[#testLoader.indicesAllClassIndex + 1] = j 169 | end 170 | testLoader.db[i] = lmdb.env{Path=paths.concat(valPath, classes[i] .. '_val_lmdb'), 171 | RDONLY=true, NOLOCK=true, NOTLS=true, NOSYNC=true, NOMETASYNC=true, 172 | MaxReaders=20, MaxDBs=20} 173 | testLoader.db[i]:open() 174 | testLoader.db_reader[i] = testLoader.db[i]:txn(true) 175 | end 176 | 177 | function testLoader:size() 178 | return #self.indicesAllClass 179 | end 180 | 181 | function testLoader:get(i1, i2) 182 | local data = torch.Tensor(i2-i1+1, sampleSize[1], sampleSize[2], sampleSize[2]) 183 | local label = torch.Tensor(i2-i1+1) 184 | for i=i1, i2 do 185 | local class = self.indicesAllClass[i] 186 | local classIndex = self.indicesAllClassIndex[i] 187 | local imgblob = self.db_reader[class]:getData(self.indices[class][classIndex]) 188 | local out = testHook(imgblob) 189 | data[i-i1+1]:copy(out) 190 | label[i-i1+1] = class 191 | end 192 | return data, label 193 | end 194 | collectgarbage() 195 | ----------------------------------------- 196 | 197 | -- Estimate the per-channel mean/std (so that the loaders can normalize appropriately) 198 | if paths.filep(meanstdCache) then 199 | local meanstd = torch.load(meanstdCache) 200 | mean = meanstd.mean 201 | std = meanstd.std 202 | print('Loaded mean and std from cache.') 203 | else 204 | local tm = torch.Timer() 205 | local nSamples = 10000 206 | print('Estimating the mean (per-channel, shared for all pixels) over ' .. nSamples .. ' randomly sampled training images') 207 | local meanEstimate = {0,0,0} 208 | for i=1,nSamples do 209 | local img = trainLoader:sample(1)[1] 210 | for j=1,3 do 211 | meanEstimate[j] = meanEstimate[j] + img[j]:mean() 212 | end 213 | end 214 | for j=1,3 do 215 | meanEstimate[j] = meanEstimate[j] / nSamples 216 | end 217 | mean = meanEstimate 218 | 219 | print('Estimating the std (per-channel, shared for all pixels) over ' .. nSamples .. ' randomly sampled training images') 220 | local stdEstimate = {0,0,0} 221 | for i=1,nSamples do 222 | local img = trainLoader:sample(1)[1] 223 | for j=1,3 do 224 | stdEstimate[j] = stdEstimate[j] + img[j]:std() 225 | end 226 | end 227 | for j=1,3 do 228 | stdEstimate[j] = stdEstimate[j] / nSamples 229 | end 230 | std = stdEstimate 231 | 232 | local cache = {} 233 | cache.mean = mean 234 | cache.std = std 235 | torch.save(meanstdCache, cache) 236 | print('Time to estimate:', tm:time().real) 237 | end 238 | -------------------------------------------------------------------------------- /lsun/imagenet.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'torch' 11 | torch.setdefaulttensortype('torch.FloatTensor') 12 | local ffi = require 'ffi' 13 | local class = require('pl.class') 14 | local dir = require 'pl.dir' 15 | local tablex = require 'pl.tablex' 16 | local argcheck = require 'argcheck' 17 | require 'sys' 18 | require 'xlua' 19 | require 'image' 20 | 21 | local dataset = torch.class('dataLoader') 22 | 23 | local initcheck = argcheck{ 24 | pack=true, 25 | help=[[ 26 | A dataset class for images in a flat folder structure (folder-name is class-name). 27 | Optimized for extremely large datasets (upwards of 14 million images). 28 | Tested only on Linux (as it uses command-line linux utilities to scale up) 29 | ]], 30 | {check=function(paths) 31 | local out = true; 32 | for k,v in ipairs(paths) do 33 | if type(v) ~= 'string' then 34 | print('paths can only be of string input'); 35 | out = false 36 | end 37 | end 38 | return out 39 | end, 40 | name="paths", 41 | type="table", 42 | help="Multiple paths of directories with images"}, 43 | 44 | {name="sampleSize", 45 | type="table", 46 | help="a consistent sample size to resize the images"}, 47 | 48 | {name="split", 49 | type="number", 50 | help="Percentage of split to go to Training" 51 | }, 52 | 53 | {name="samplingMode", 54 | type="string", 55 | help="Sampling mode: random | balanced ", 56 | default = "balanced"}, 57 | 58 | {name="verbose", 59 | type="boolean", 60 | help="Verbose mode during initialization", 61 | default = false}, 62 | 63 | {name="loadSize", 64 | type="table", 65 | help="a size to load the images to, initially", 66 | opt = true}, 67 | 68 | {name="forceClasses", 69 | type="table", 70 | help="If you want this loader to map certain classes to certain indices, " 71 | .. "pass a classes table that has {classname : classindex} pairs." 72 | .. " For example: {3 : 'dog', 5 : 'cat'}" 73 | .. "This function is very useful when you want two loaders to have the same " 74 | .. "class indices (trainLoader/testLoader for example)", 75 | opt = true}, 76 | 77 | {name="sampleHookTrain", 78 | type="function", 79 | help="applied to sample during training(ex: for lighting jitter). " 80 | .. "It takes the image path as input", 81 | opt = true}, 82 | 83 | {name="sampleHookTest", 84 | type="function", 85 | help="applied to sample during testing", 86 | opt = true}, 87 | } 88 | 89 | function dataset:__init(...) 90 | 91 | -- argcheck 92 | local args = initcheck(...) 93 | print(args) 94 | for k,v in pairs(args) do self[k] = v end 95 | 96 | if not self.loadSize then self.loadSize = self.sampleSize; end 97 | 98 | if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end 99 | if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end 100 | 101 | -- find class names 102 | self.classes = {} 103 | local classPaths = {} 104 | if self.forceClasses then 105 | for k,v in pairs(self.forceClasses) do 106 | self.classes[k] = v 107 | classPaths[k] = {} 108 | end 109 | end 110 | local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end 111 | -- loop over each paths folder, get list of unique class names, 112 | -- also store the directory paths per class 113 | -- for each class, 114 | for k,path in ipairs(self.paths) do 115 | local dirs = dir.getdirectories(path); 116 | for k,dirpath in ipairs(dirs) do 117 | local class = paths.basename(dirpath) 118 | local idx = tableFind(self.classes, class) 119 | if not idx then 120 | table.insert(self.classes, class) 121 | idx = #self.classes 122 | classPaths[idx] = {} 123 | end 124 | if not tableFind(classPaths[idx], dirpath) then 125 | table.insert(classPaths[idx], dirpath); 126 | end 127 | end 128 | end 129 | 130 | self.classIndices = {} 131 | for k,v in ipairs(self.classes) do 132 | self.classIndices[v] = k 133 | end 134 | 135 | -- define command-line tools, try your best to maintain OSX compatibility 136 | local wc = 'wc' 137 | local cut = 'cut' 138 | local find = 'find' 139 | if jit.os == 'OSX' then 140 | wc = 'gwc' 141 | cut = 'gcut' 142 | find = 'gfind' 143 | end 144 | ---------------------------------------------------------------------- 145 | -- Options for the GNU find command 146 | local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 147 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 148 | for i=2,#extensionList do 149 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 150 | end 151 | 152 | -- find the image path names 153 | self.imagePath = torch.CharTensor() -- path to each image in dataset 154 | self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 155 | self.classList = {} -- index of imageList to each image of a particular class 156 | self.classListSample = self.classList -- the main list used when sampling data 157 | 158 | print('running "find" on each class directory, and concatenate all' 159 | .. ' those filenames into a single file containing all image paths for a given class') 160 | -- so, generates one file per class 161 | local classFindFiles = {} 162 | for i=1,#self.classes do 163 | classFindFiles[i] = os.tmpname() 164 | end 165 | local combinedFindList = os.tmpname(); 166 | 167 | local tmpfile = os.tmpname() 168 | local tmphandle = assert(io.open(tmpfile, 'w')) 169 | -- iterate over classes 170 | for i, class in ipairs(self.classes) do 171 | -- iterate over classPaths 172 | for j,path in ipairs(classPaths[i]) do 173 | local command = find .. ' "' .. path .. '" ' .. findOptions 174 | .. ' >>"' .. classFindFiles[i] .. '" \n' 175 | tmphandle:write(command) 176 | end 177 | end 178 | io.close(tmphandle) 179 | os.execute('bash ' .. tmpfile) 180 | os.execute('rm -f ' .. tmpfile) 181 | 182 | print('now combine all the files to a single large file') 183 | local tmpfile = os.tmpname() 184 | local tmphandle = assert(io.open(tmpfile, 'w')) 185 | -- concat all finds to a single large file in the order of self.classes 186 | for i=1,#self.classes do 187 | local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n' 188 | tmphandle:write(command) 189 | end 190 | io.close(tmphandle) 191 | os.execute('bash ' .. tmpfile) 192 | os.execute('rm -f ' .. tmpfile) 193 | 194 | --========================================================================== 195 | print('load the large concatenated list of sample paths to self.imagePath') 196 | local maxPathLength = tonumber(sys.fexecute(wc .. " -L '" 197 | .. combinedFindList .. "' |" 198 | .. cut .. " -f1 -d' '")) + 1 199 | local length = tonumber(sys.fexecute(wc .. " -l '" 200 | .. combinedFindList .. "' |" 201 | .. cut .. " -f1 -d' '")) 202 | assert(length > 0, "Could not find any image file in the given input paths") 203 | assert(maxPathLength > 0, "paths of files are length 0?") 204 | self.imagePath:resize(length, maxPathLength):fill(0) 205 | local s_data = self.imagePath:data() 206 | local count = 0 207 | for line in io.lines(combinedFindList) do 208 | ffi.copy(s_data, line) 209 | s_data = s_data + maxPathLength 210 | if self.verbose and count % 10000 == 0 then 211 | xlua.progress(count, length) 212 | end; 213 | count = count + 1 214 | end 215 | 216 | self.numSamples = self.imagePath:size(1) 217 | if self.verbose then print(self.numSamples .. ' samples found.') end 218 | --========================================================================== 219 | print('Updating classList and imageClass appropriately') 220 | self.imageClass:resize(self.numSamples) 221 | local runningIndex = 0 222 | for i=1,#self.classes do 223 | if self.verbose then xlua.progress(i, #(self.classes)) end 224 | local length = tonumber(sys.fexecute(wc .. " -l '" 225 | .. classFindFiles[i] .. "' |" 226 | .. cut .. " -f1 -d' '")) 227 | if length == 0 then 228 | error('Class has zero samples') 229 | else 230 | self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long() 231 | self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i) 232 | end 233 | runningIndex = runningIndex + length 234 | end 235 | 236 | --========================================================================== 237 | -- clean up temporary files 238 | print('Cleaning up temporary files') 239 | local tmpfilelistall = '' 240 | for i=1,#(classFindFiles) do 241 | tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"' 242 | if i % 1000 == 0 then 243 | os.execute('rm -f ' .. tmpfilelistall) 244 | tmpfilelistall = '' 245 | end 246 | end 247 | os.execute('rm -f ' .. tmpfilelistall) 248 | os.execute('rm -f "' .. combinedFindList .. '"') 249 | --========================================================================== 250 | 251 | if self.split == 100 then 252 | self.testIndicesSize = 0 253 | else 254 | print('Splitting training and test sets to a ratio of ' 255 | .. self.split .. '/' .. (100-self.split)) 256 | self.classListTrain = {} 257 | self.classListTest = {} 258 | self.classListSample = self.classListTrain 259 | local totalTestSamples = 0 260 | -- split the classList into classListTrain and classListTest 261 | for i=1,#self.classes do 262 | local list = self.classList[i] 263 | local count = self.classList[i]:size(1) 264 | local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round 265 | local perm = torch.randperm(count) 266 | self.classListTrain[i] = torch.LongTensor(splitidx) 267 | for j=1,splitidx do 268 | self.classListTrain[i][j] = list[perm[j]] 269 | end 270 | if splitidx == count then -- all samples were allocated to train set 271 | self.classListTest[i] = torch.LongTensor() 272 | else 273 | self.classListTest[i] = torch.LongTensor(count-splitidx) 274 | totalTestSamples = totalTestSamples + self.classListTest[i]:size(1) 275 | local idx = 1 276 | for j=splitidx+1,count do 277 | self.classListTest[i][idx] = list[perm[j]] 278 | idx = idx + 1 279 | end 280 | end 281 | end 282 | -- Now combine classListTest into a single tensor 283 | self.testIndices = torch.LongTensor(totalTestSamples) 284 | self.testIndicesSize = totalTestSamples 285 | local tdata = self.testIndices:data() 286 | local tidx = 0 287 | for i=1,#self.classes do 288 | local list = self.classListTest[i] 289 | if list:dim() ~= 0 then 290 | local ldata = list:data() 291 | for j=0,list:size(1)-1 do 292 | tdata[tidx] = ldata[j] 293 | tidx = tidx + 1 294 | end 295 | end 296 | end 297 | end 298 | end 299 | 300 | -- size(), size(class) 301 | function dataset:size(class, list) 302 | list = list or self.classList 303 | if not class then 304 | return self.numSamples 305 | elseif type(class) == 'string' then 306 | return list[self.classIndices[class]]:size(1) 307 | elseif type(class) == 'number' then 308 | return list[class]:size(1) 309 | end 310 | end 311 | 312 | -- getByClass 313 | function dataset:getByClass(class) 314 | local index = math.ceil(torch.uniform() * self.classListSample[class]:nElement()) 315 | local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]])) 316 | return self:sampleHookTrain(imgpath) 317 | end 318 | 319 | -- converts a table of samples (and corresponding labels) to a clean tensor 320 | local function tableToOutput(self, dataTable, scalarTable) 321 | local data, scalarLabels, labels 322 | local quantity = #scalarTable 323 | assert(dataTable[1]:dim() == 3) 324 | data = torch.Tensor(quantity, 325 | self.sampleSize[1], self.sampleSize[2], self.sampleSize[3]) 326 | scalarLabels = torch.LongTensor(quantity):fill(-1111) 327 | for i=1,#dataTable do 328 | data[i]:copy(dataTable[i]) 329 | scalarLabels[i] = scalarTable[i] 330 | end 331 | return data, scalarLabels 332 | end 333 | 334 | -- sampler, samples from the training set. 335 | function dataset:sample(quantity) 336 | assert(quantity) 337 | local dataTable = {} 338 | local scalarTable = {} 339 | for i=1,quantity do 340 | local class = torch.random(1, #self.classes) 341 | local out = self:getByClass(class) 342 | table.insert(dataTable, out) 343 | table.insert(scalarTable, class) 344 | end 345 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable) 346 | return data, scalarLabels 347 | end 348 | 349 | function dataset:get(i1, i2) 350 | local indices = torch.range(i1, i2); 351 | local quantity = i2 - i1 + 1; 352 | assert(quantity > 0) 353 | -- now that indices has been initialized, get the samples 354 | local dataTable = {} 355 | local scalarTable = {} 356 | for i=1,quantity do 357 | -- load the sample 358 | local imgpath = ffi.string(torch.data(self.imagePath[indices[i]])) 359 | local out = self:sampleHookTest(imgpath) 360 | table.insert(dataTable, out) 361 | table.insert(scalarTable, self.imageClass[indices[i]]) 362 | end 363 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable) 364 | return data, scalarLabels 365 | end 366 | 367 | return dataset 368 | -------------------------------------------------------------------------------- /lsun/layers/SpatialConvolutionUpsample.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | local SpatialConvolutionUpsample, parent = torch.class('nn.SpatialConvolutionUpsample','nn.SpatialConvolution') 11 | 12 | function SpatialConvolutionUpsample:__init(nInputPlane, nOutputPlane, kW, kH, factor) 13 | factor = factor or 2 14 | assert(kW and kH and nInputPlane and nOutputPlane) 15 | assert(kW % 2 == 1, 'kW has to be odd') 16 | assert(kH % 2 == 1, 'kH has to be odd') 17 | self.factor = factor 18 | self.kW = kW 19 | self.kH = kH 20 | self.nInputPlaneU = nInputPlane 21 | self.nOutputPlaneU = nOutputPlane 22 | parent.__init(self, nInputPlane, nOutputPlane * factor * factor, kW, kH, 1, 1, (kW-1)/2) 23 | end 24 | 25 | function SpatialConvolutionUpsample:updateOutput(input) 26 | self.output = parent.updateOutput(self, input) 27 | if input:dim() == 4 then 28 | self.h = input:size(3) 29 | self.w = input:size(4) 30 | self.output = self.output:view(input:size(1), self.nOutputPlaneU, self.h*self.factor, self.w*self.factor) 31 | else 32 | self.h = input:size(2) 33 | self.w = input:size(3) 34 | self.output = self.output:view(self.nOutputPlaneU, self.h*self.factor, self.w*self.factor) 35 | end 36 | return self.output 37 | end 38 | 39 | function SpatialConvolutionUpsample:updateGradInput(input, gradOutput) 40 | if input:dim() == 4 then 41 | gradOutput = gradOutput:view(input:size(1), self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 42 | else 43 | gradOutput = gradOutput:view(self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 44 | end 45 | self.gradInput = parent.updateGradInput(self, input, gradOutput) 46 | return self.gradInput 47 | end 48 | 49 | function SpatialConvolutionUpsample:accGradParameters(input, gradOutput, scale) 50 | if input:dim() == 4 then 51 | gradOutput = gradOutput:view(input:size(1), self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 52 | else 53 | gradOutput = gradOutput:view(self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 54 | end 55 | parent.accGradParameters(self, input, gradOutput, scale) 56 | end 57 | 58 | function SpatialConvolutionUpsample:accUpdateGradParameters(input, gradOutput, scale) 59 | if input:dim() == 4 then 60 | gradOutput = gradOutput:view(input:size(1), self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 61 | else 62 | gradOutput = gradOutput:view(self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 63 | end 64 | parent.accUpdateGradParameters(self, input, gradOutput, scale) 65 | end 66 | -------------------------------------------------------------------------------- /lsun/layers/cudnnSpatialConvolutionUpsample.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'cudnn' 11 | local SpatialConvolutionUpsample, parent = torch.class('cudnn.SpatialConvolutionUpsample','cudnn.SpatialConvolution') 12 | 13 | function SpatialConvolutionUpsample:__init(nInputPlane, nOutputPlane, kW, kH, factor, groups) 14 | factor = factor or 2 15 | assert(kW and kH and nInputPlane and nOutputPlane) 16 | assert(kW % 2 == 1, 'kW has to be odd') 17 | assert(kH % 2 == 1, 'kH has to be odd') 18 | self.factor = factor 19 | self.kW = kW 20 | self.kH = kH 21 | self.nInputPlaneU = nInputPlane 22 | self.nOutputPlaneU = nOutputPlane 23 | parent.__init(self, nInputPlane, nOutputPlane * factor * factor, 24 | kW, kH, 1, 1, (kW-1)/2, (kH-1)/2, groups) 25 | end 26 | 27 | function SpatialConvolutionUpsample:updateOutput(input) 28 | self.output = parent.updateOutput(self, input) 29 | if input:dim() == 4 then 30 | self.h = input:size(3) 31 | self.w = input:size(4) 32 | self.output = self.output:view(input:size(1), self.nOutputPlaneU, self.h*self.factor, self.w*self.factor) 33 | else 34 | self.h = input:size(2) 35 | self.w = input:size(3) 36 | self.output = self.output:view(self.nOutputPlaneU, self.h*self.factor, self.w*self.factor) 37 | end 38 | return self.output 39 | end 40 | 41 | function SpatialConvolutionUpsample:updateGradInput(input, gradOutput) 42 | if input:dim() == 4 then 43 | gradOutput = gradOutput:view(input:size(1), self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 44 | else 45 | gradOutput = gradOutput:view(self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 46 | end 47 | self.gradInput = parent.updateGradInput(self, input, gradOutput) 48 | return self.gradInput 49 | end 50 | 51 | function SpatialConvolutionUpsample:accGradParameters(input, gradOutput, scale) 52 | if input:dim() == 4 then 53 | gradOutput = gradOutput:view(input:size(1), self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 54 | else 55 | gradOutput = gradOutput:view(self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 56 | end 57 | parent.accGradParameters(self, input, gradOutput, scale) 58 | end 59 | 60 | function SpatialConvolutionUpsample:accUpdateGradParameters(input, gradOutput, scale) 61 | if input:dim() == 4 then 62 | gradOutput = gradOutput:view(input:size(1), self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 63 | else 64 | gradOutput = gradOutput:view(self.nOutputPlaneU*self.factor*self.factor, self.h, self.w) 65 | end 66 | parent.accUpdateGradParameters(self, input, gradOutput, scale) 67 | end 68 | -------------------------------------------------------------------------------- /lsun/layers/test_cudnn.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'cudnn' 11 | require 'cunn' 12 | paths.dofile('cudnnSpatialConvolutionUpsample.lua') 13 | paths.dofile('SpatialConvolutionUpsample.lua') 14 | 15 | local cudnntest = {} 16 | local precision_forward = 1e-4 17 | local precision_backward = 1e-2 18 | local precision_jac = 1e-3 19 | local nloop = 1 20 | local times = {} 21 | local mytester 22 | 23 | function cudnntest.SpatialConvolution_forward_batch() 24 | local bs = math.random(1,32) 25 | local from = math.random(1,32) 26 | local to = math.random(1,64) 27 | local ki = math.random(1,5) * 2 - 1 28 | local kj = ki 29 | local scale = math.random(1,2) 30 | local outi = math.random(1,32) * 2 31 | local outj = outi 32 | local ini = outi / scale 33 | local inj = outj / scale 34 | local input = torch.randn(bs,from,inj,ini):cuda() 35 | local sconv = nn.SpatialConvolutionUpsample(from,to,ki,kj,scale):cuda() 36 | local groundtruth = sconv:forward(input) 37 | cutorch.synchronize() 38 | local gconv = cudnn.SpatialConvolutionUpsample(from,to,ki,kj,scale):cuda() 39 | gconv.weight:copy(sconv.weight) 40 | gconv.bias:copy(sconv.bias) 41 | local rescuda = gconv:forward(input) 42 | cutorch.synchronize() 43 | local error = rescuda:float() - groundtruth:float() 44 | mytester:assertlt(error:abs():max(), precision_forward, 'error on state (forward) ') 45 | end 46 | 47 | function cudnntest.SpatialConvolution_backward_batch() 48 | local bs = math.random(1,32) 49 | local from = math.random(1,32) * 2 50 | local to = math.random(1,64) * 2 51 | local ki = math.random(1,5) * 2 - 1 52 | local kj = ki 53 | local scale = math.random(1,2) 54 | local outi = math.random(ki,32) * 2 55 | local outj = outi 56 | local ini = outi / scale 57 | local inj = outj / scale 58 | print(bs, from, to, inj, ini, outj, outi, scale, ki, kj) 59 | 60 | local input = torch.randn(bs,from,inj,ini):cuda() 61 | local gradOutput = torch.randn(bs,to,outj,outi):cuda() 62 | local sconv = nn.SpatialConvolutionUpsample(from,to,ki,kj,scale):cuda() 63 | sconv:forward(input) 64 | sconv:zeroGradParameters() 65 | local groundgrad = sconv:backward(input, gradOutput) 66 | cutorch.synchronize() 67 | local groundweight = sconv.gradWeight 68 | local groundbias = sconv.gradBias 69 | 70 | local gconv = cudnn.SpatialConvolutionUpsample(from,to,ki,kj,scale):cuda() 71 | gconv.weight:copy(sconv.weight) 72 | gconv.bias:copy(sconv.bias) 73 | gconv:forward(input) 74 | 75 | -- serialize and deserialize 76 | torch.save('modelTemp.t7', gconv) 77 | gconv = torch.load('modelTemp.t7') 78 | 79 | gconv:forward(input) 80 | gconv:zeroGradParameters() 81 | local rescuda = gconv:backward(input, gradOutput) 82 | cutorch.synchronize() 83 | local weightcuda = gconv.gradWeight 84 | local biascuda = gconv.gradBias 85 | 86 | local error = rescuda:float() - groundgrad:float() 87 | local werror = weightcuda:float() - groundweight:float() 88 | local berror = biascuda:float() - groundbias:float() 89 | 90 | mytester:assertlt(error:abs():max(), precision_backward, 'error on state (backward) ') 91 | mytester:assertlt(werror:abs():max(), precision_backward, 'error on weight (backward) ') 92 | mytester:assertlt(berror:abs():max(), precision_backward, 'error on bias (backward) ') 93 | end 94 | 95 | torch.setdefaulttensortype('torch.FloatTensor') 96 | math.randomseed(os.time()) 97 | mytester = torch.Tester() 98 | mytester:add(cudnntest) 99 | 100 | mytester:run() 101 | 102 | os.execute('rm -f modelTemp.t7') 103 | -------------------------------------------------------------------------------- /lsun/layers/test_upsampler.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'nn' 11 | paths.dofile('../layers/SpatialConvolutionUpsample.lua') 12 | local mytester = torch.Tester() 13 | local jac 14 | local precision = 1e-5 15 | 16 | nntest = {} 17 | function nntest.SpatialConvolutionUpsample() 18 | local from = math.random(1,5) 19 | local to = math.random(1,5) 20 | local ki = 3 21 | local kj = 3 22 | local outi = 16 23 | local outj = 16 24 | local factor = 2 25 | local ini = outi / factor 26 | local inj = outj / factor 27 | local module = nn.SpatialConvolutionUpsample(from, to, ki, kj, factor) 28 | local input = torch.Tensor(from, inj, ini):zero() 29 | 30 | -- stochastic 31 | 32 | local err = jac.testJacobian(module, input) 33 | mytester:assertlt(err, precision, 'error on state ') 34 | 35 | local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight) 36 | mytester:assertlt(err , precision, 'error on weight ') 37 | 38 | local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) 39 | mytester:assertlt(err , precision, 'error on bias ') 40 | 41 | local err = jac.testJacobianUpdateParameters(module, input, module.weight) 42 | mytester:assertlt(err , precision, 'error on weight [direct update] ') 43 | 44 | local err = jac.testJacobianUpdateParameters(module, input, module.bias) 45 | mytester:assertlt(err , precision, 'error on bias [direct update] ') 46 | 47 | for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do 48 | mytester:assertlt(err, precision, string.format( 49 | 'error on weight [%s]', t)) 50 | end 51 | 52 | for t,err in pairs(jac.testAllUpdate(module, input, 'bias', 'gradBias')) do 53 | mytester:assertlt(err, precision, string.format( 54 | 'error on bias [%s]', t)) 55 | end 56 | 57 | -- batch 58 | 59 | local batch = math.random(2,5) 60 | module = nn.SpatialConvolutionUpsample(from, to, ki, kj, factor) 61 | input = torch.Tensor(batch,from,inj,ini):zero() 62 | 63 | local err = jac.testJacobian(module, input) 64 | mytester:assertlt(err, precision, 'batch error on state ') 65 | 66 | local err = jac.testJacobianParameters(module, input, module.weight, module.gradWeight) 67 | mytester:assertlt(err , precision, 'batch error on weight ') 68 | 69 | local err = jac.testJacobianParameters(module, input, module.bias, module.gradBias) 70 | mytester:assertlt(err , precision, 'batch error on bias ') 71 | 72 | local err = jac.testJacobianUpdateParameters(module, input, module.weight) 73 | mytester:assertlt(err , precision, 'batch error on weight [direct update] ') 74 | 75 | local err = jac.testJacobianUpdateParameters(module, input, module.bias) 76 | mytester:assertlt(err , precision, 'batch error on bias [direct update] ') 77 | 78 | for t,err in pairs(jac.testAllUpdate(module, input, 'weight', 'gradWeight')) do 79 | mytester:assertlt(err, precision, string.format( 80 | 'error on weight [%s]', t)) 81 | end 82 | 83 | for t,err in pairs(jac.testAllUpdate(module, input, 'bias', 'gradBias')) do 84 | mytester:assertlt(err, precision, string.format( 85 | 'batch error on bias [%s]', t)) 86 | end 87 | 88 | local ferr, berr = jac.testIO(module, input) 89 | mytester:asserteq(0, ferr, torch.typename(module) .. ' - i/o forward err ') 90 | mytester:asserteq(0, berr, torch.typename(module) .. ' - i/o backward err ') 91 | end 92 | 93 | mytester:add(nntest) 94 | 95 | require 'nn' 96 | jac = nn.Jacobian 97 | mytester:run() 98 | -------------------------------------------------------------------------------- /lsun/main.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'torch' 11 | require 'cunn' 12 | require 'optim' 13 | require 'image' 14 | require 'paths' 15 | local pl=require 'pl' 16 | 17 | ---------------------------------------------------------------------- 18 | -- parse command-line options 19 | opt = lapp[[ 20 | --dataset (default "imagenet") imagenet | lsun 21 | --model (default "large") large | small | autogen 22 | -s,--save (default "imgslogs") subdirectory to save logs 23 | --saveFreq (default 10) save every saveFreq epochs 24 | -n,--network (default "") reload pretrained network 25 | -p,--plot plot while training 26 | -r,--learningRateD (default 0.01) learning rate, for SGD only 27 | --learningRateG (default 0.01) learning rate, for SGD only 28 | -b,--batchSize (default 128) batch size 29 | -m,--momentum (default 0) momentum, for SGD only 30 | -w, --window (default 3) windsow id of sample image 31 | --nDonkeys (default 10) number of data loading threads 32 | --cache (default "cache") folder to cache metadata 33 | --epochSize (default 5000) number of samples per epoch 34 | --nEpochs (default 25) 35 | --coarseSize (default 16) 36 | --scaleUp (default 2) How much to upscale coarseSize 37 | --archgen (default 1) 38 | --scratch (default 0) 39 | --forceDonkeys (default 0) 40 | ]] 41 | 42 | print(opt) 43 | 44 | opt.manualSeed = torch.random(1,10000) -- fix seed 45 | print("Seed: " .. opt.manualSeed) 46 | torch.manualSeed(opt.manualSeed) 47 | torch.setnumthreads(1) 48 | torch.setdefaulttensortype('torch.FloatTensor') 49 | 50 | opt.fineSize = opt.coarseSize * opt.scaleUp 51 | opt.loadSize = math.ceil(opt.coarseSize * (3 * opt.scaleUp / 2)) 52 | opt.noiseDim = {1, opt.fineSize, opt.fineSize} 53 | classes = {'0','1'} 54 | opt.geometry = {3, opt.fineSize, opt.fineSize} 55 | opt.condDim = {3, opt.fineSize, opt.fineSize} 56 | opt.run_id = math.random(1,10000000) 57 | 58 | paths.dofile('model.lua') 59 | paths.dofile('data.lua') 60 | adversarial = paths.dofile('train.lua') 61 | 62 | -- this matrix records the current confusion among real/fake classes 63 | confusion = optim.ConfusionMatrix(2) 64 | 65 | -- Training parameters 66 | sgdState_D = { 67 | learningRate = opt.learningRateD, 68 | momentum = opt.momentum 69 | } 70 | 71 | sgdState_G = { 72 | learningRate = opt.learningRateG, 73 | momentum = opt.momentum 74 | } 75 | 76 | local function train() 77 | print("Training epoch: " .. epoch) 78 | confusion:zero() 79 | model_D:training() 80 | model_G:training() 81 | batchNumber = 0 82 | for i=1,opt.epochSize do 83 | donkeys:addjob( 84 | function() 85 | return makeData(trainLoader:sample(opt.batchSize)), 86 | makeData(trainLoader:sample(opt.batchSize)) 87 | end, 88 | adversarial.train) 89 | end 90 | donkeys:synchronize() 91 | cutorch.synchronize() 92 | print(confusion) 93 | tr_acc0 = confusion.valids[1] * 100 94 | tr_acc1 = confusion.valids[2] * 100 95 | if tr_acc0 ~= tr_acc0 then tr_acc0 = 0 end 96 | if tr_acc1 ~= tr_acc1 then tr_acc1 = 0 end 97 | end 98 | 99 | local function test() 100 | print("Testing epoch: " .. epoch) 101 | confusion:zero() 102 | model_D:evaluate() 103 | model_G:evaluate() 104 | for i=1,nTest/opt.batchSize do -- nTest is set in data.lua 105 | -- xlua.progress(i, math.floor(nTest/opt.batchSize)) 106 | local indexStart = (i-1) * opt.batchSize + 1 107 | local indexEnd = (indexStart + opt.batchSize - 1) 108 | donkeys:addjob(function() return makeData(testLoader:get(indexStart, indexEnd)) end, 109 | adversarial.test) 110 | end 111 | donkeys:synchronize() 112 | cutorch.synchronize() 113 | print(confusion) 114 | ts_acc0 = confusion.valids[1] * 100 115 | ts_acc1 = confusion.valids[2] * 100 116 | if ts_acc0 ~= ts_acc0 then ts_acc0 = 0 end 117 | if ts_acc1 ~= ts_acc1 then ts_acc1 = 0 end 118 | end 119 | 120 | local function plot(N) 121 | local N = N or 16 122 | N = math.min(N, opt.batchSize) 123 | local offset = 1000 124 | if opt.dataset == 'lsun' then 125 | offset = math.floor((nTest - opt.batchSize - 1)/16) 126 | end 127 | assert((N * offset) < (nTest - opt.batchSize - 1)) 128 | local noise_inputs = torch.CudaTensor(N, opt.noiseDim[1], opt.noiseDim[2], opt.noiseDim[3]) 129 | local cond_inputs = torch.CudaTensor(N, opt.condDim[1], opt.condDim[2], opt.condDim[3]) 130 | local gt = torch.CudaTensor(N, 3, opt.condDim[2], opt.condDim[3]) 131 | 132 | -- Generate samples 133 | noise_inputs:uniform(-1, 1) 134 | for i=0,N-1 do 135 | local indexStart = i * offset + 1 136 | local indexEnd = (indexStart + opt.batchSize - 1) 137 | donkeys:addjob( 138 | function() return makeData(testLoader:get(indexStart, indexEnd)) end, 139 | function(d) cond_inputs[i+1]:copy(d[3][1]); gt[i+1]:copy(d[4][1]); end 140 | ) 141 | donkeys:synchronize() 142 | end 143 | local finputs = {noise_inputs, cond_inputs} 144 | if opt.scratch == 1 then 145 | finputs = noise_inputs 146 | end 147 | local samples = model_G:forward(finputs) 148 | 149 | local to_plot = {} 150 | for i=1,N do 151 | local pred = torch.add(cond_inputs[i]:float(), samples[i]:float()) 152 | to_plot[#to_plot+1] = gt[i]:float() 153 | to_plot[#to_plot+1] = pred 154 | to_plot[#to_plot+1] = cond_inputs[i]:float() 155 | to_plot[#to_plot+1] = samples[i]:float() 156 | end 157 | to_plot = image.toDisplayTensor{input=to_plot, scaleeach=true, nrow=8} 158 | if opt.coarseSize < 32 then 159 | to_plot = image.scale(to_plot, to_plot:size(2) * 32 / opt.coarseSize, 160 | to_plot:size(3) * 32 / opt.coarseSize) 161 | end 162 | image.save(opt.save .. '/' .. 'gen_' .. epoch .. '.png', to_plot) 163 | if opt.plot then 164 | local disp = require 'display' 165 | disp.image(to_plot, {win=opt.window, width=600}) 166 | end 167 | 168 | end 169 | 170 | os.execute('mkdir -p ' .. opt.save) 171 | 172 | epoch = 1 173 | while epoch < opt.nEpochs do 174 | train() 175 | test() 176 | torch.save(opt.save .. '/' .. 'model_' .. epoch .. '.t7', 177 | {D = sanitize(model_D), G = sanitize(model_G)}) 178 | print(merge_table({epoch = opt.epoch, 179 | tr_acc0 = tr_acc0, 180 | tr_acc1 = tr_acc1, 181 | ts_acc0 = ts_acc0, 182 | ts_acc1 = ts_acc1, 183 | desc_D = desc_D, 184 | desc_G = desc_G, 185 | }, opt)) 186 | sgdState_D.momentum = math.min(sgdState_D.momentum + 0.0008, 0.7) 187 | sgdState_D.learningRate = math.max(sgdState_D.learningRate / 1.000004, 0.000001) 188 | sgdState_G.momentum = math.min(sgdState_G.momentum + 0.0008, 0.7) 189 | sgdState_G.learningRate = math.max(sgdState_G.learningRate / 1.000004, 0.000001) 190 | 191 | plot(16) 192 | epoch = epoch + 1 193 | end 194 | -------------------------------------------------------------------------------- /lsun/model.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | paths.dofile('modelGenerator/modelGen.lua') 11 | ---------------------------------------------------------------------- 12 | local freeParams = function(m) 13 | local list = m:listModules() 14 | local p = 0 15 | for k,v in pairs(list) do 16 | p = p + (v.weight and v.weight:nElement() or 0) 17 | p = p + (v.bias and v.bias:nElement() or 0) 18 | end 19 | return p 20 | end 21 | ---------------------------------------------------------------------- 22 | if opt.network ~= '' then 23 | print(' reloading previously trained network: ' .. opt.network) 24 | local tmp = torch.load(opt.network) 25 | model_D = tmp.D 26 | model_G = tmp.G 27 | print('Discriminator network:') 28 | print(model_D) 29 | print('Generator network:') 30 | print(model_G) 31 | elseif opt.model == 'small' then 32 | local nplanes = 64 33 | model_D = nn.Sequential() 34 | model_D:add(nn.CAddTable()) 35 | model_D:add(nn.SpatialConvolution(3, nplanes, 5, 5)) --28 x 28 36 | model_D:add(nn.ReLU()) 37 | model_D:add(nn.SpatialConvolution(nplanes, nplanes, 5, 5, 2, 2)) 38 | local sz = math.floor( ( (opt.fineSize - 5 + 1) - 5) / 2 + 1) 39 | model_D:add(nn.View(nplanes*sz*sz)) 40 | model_D:add(nn.ReLU()) 41 | model_D:add(nn.Linear(nplanes*sz*sz, 1)) 42 | model_D:add(nn.Sigmoid()) 43 | local nplanes = 128 44 | model_G = nn.Sequential() 45 | model_G:add(nn.JoinTable(2, 2)) 46 | model_G:add(cudnn.SpatialConvolutionUpsample(3+1, nplanes, 7, 7, 1)) 47 | model_G:add(nn.ReLU()) 48 | model_G:add(cudnn.SpatialConvolutionUpsample(nplanes, nplanes, 7, 7, 1)) 49 | model_G:add(nn.ReLU()) 50 | model_G:add(cudnn.SpatialConvolutionUpsample(nplanes, 3, 5, 5, 1)) 51 | model_G:add(nn.View(opt.geometry[1], opt.geometry[2], opt.geometry[3])) 52 | elseif opt.model == 'large' then 53 | require 'fbcunn' 54 | print('Generator network (good):') 55 | desc_G = '___JT22___C_4_64_g1_7x7___R__BN___C_64_368_g4_7x7___R__BN___SDrop 0.5___C_368_128_g4_7x7___R__BN___P_LPOut_2___C_64_224_g2_5x5___R__BN___SDrop 0.5___C_224_3_g1_7x7__BNA' 56 | model_G = nn.Sequential() 57 | model_G:add(nn.JoinTable(2, 2)) 58 | model_G:add(cudnn.SpatialConvolutionUpsample(3+1, 64, 7, 7, 1, 1)):add(cudnn.ReLU(true)) 59 | model_G:add(nn.SpatialBatchNormalization(64, nil, nil, false)) 60 | model_G:add(cudnn.SpatialConvolutionUpsample(64, 368, 7, 7, 1, 4)):add(cudnn.ReLU(true)) 61 | model_G:add(nn.SpatialBatchNormalization(368, nil, nil, false)) 62 | model_G:add(nn.SpatialDropout(0.5)) 63 | model_G:add(cudnn.SpatialConvolutionUpsample(368, 128, 7, 7, 1, 4)):add(cudnn.ReLU(true)) 64 | model_G:add(nn.SpatialBatchNormalization(128, nil, nil, false)) 65 | model_G:add(nn.FeatureLPPooling(2,2,2,true)) 66 | model_G:add(cudnn.SpatialConvolutionUpsample(64, 224, 5, 5, 1, 2)):add(cudnn.ReLU(true)) 67 | model_G:add(nn.SpatialBatchNormalization(224, nil, nil, false)) 68 | model_G:add(nn.SpatialDropout(0.5)) 69 | model_G:add(cudnn.SpatialConvolutionUpsample(224, 3, 7, 7, 1, 1)) 70 | model_G:add(nn.SpatialBatchNormalization(3, nil, nil, false)) 71 | model_G:add(nn.View(opt.geometry[1], opt.geometry[2], opt.geometry[3])) 72 | print(desc_G) 73 | 74 | desc_D = '___CAdd___C_3_48_g1_3x3___R___C_48_448_g4_5x5___R___C_448_416_g16_7x7___R___V_166400___L 166400_1___Sig' 75 | model_D = nn.Sequential() 76 | model_D:add(nn.CAddTable()) 77 | model_D:add(cudnn.SpatialConvolution(3, 48, 3, 3)) 78 | model_D:add(cudnn.ReLU(true)) 79 | model_D:add(cudnn.SpatialConvolution(48, 448, 5, 5, 1, 1, 0, 0, 4)) 80 | model_D:add(cudnn.ReLU(true)) 81 | model_D:add(cudnn.SpatialConvolution(448, 416, 7, 7, 1, 1, 0, 0, 16)) 82 | model_D:add(cudnn.ReLU()) 83 | model_D:cuda() 84 | local dummy_input = torch.zeros(opt.batchSize, 3, opt.fineSize, opt.fineSize):cuda() 85 | local out = model_D:forward({dummy_input, dummy_input}) 86 | local nElem = out:nElement() / opt.batchSize 87 | model_D:add(nn.View(nElem):setNumInputDims(3)) 88 | model_D:add(nn.Linear(nElem, 1)) 89 | model_D:add(nn.Sigmoid()) 90 | model_D:cuda() 91 | print(desc_D) 92 | elseif opt.model == 'autogen' then 93 | -- define G network to train 94 | print('Generator network:') 95 | model_G,desc_G = generateModelG(3,5,128,512,3,7, 'mixed', 0, 4, 2, true) 96 | model_G:add(nn.View(opt.geometry[1], opt.geometry[2], opt.geometry[3])) 97 | print(desc_G) 98 | print(model_G) 99 | local trygen = 1 100 | local poolType = 'none' 101 | -- if torch.random(1,2) == 1 then poolType = 'none' end 102 | repeat 103 | trygen = trygen + 1 104 | if trygen == 500 then error('Could not find a good D model') end 105 | -- define D network to train 106 | print('Discriminator network:') 107 | model_D,desc_D = generateModelD(2,6,64,512,3,7, poolType, 0, 4, 2) 108 | print(desc_D) 109 | print(model_D) 110 | until (freeParams(model_D) < freeParams(model_G)) 111 | and (freeParams(model_D) > freeParams(model_G) / 10) 112 | elseif opt.model == 'full' or opt.model == 'fullgen' then 113 | local nhid = 1024 114 | local nhidlayers = 2 115 | local batchnorm = 1 -- disabled 116 | if opt.model == 'fullgen' then 117 | nhidlayers = torch.random(1,5) 118 | nhid = torch.random(8, 128) * 16 119 | batchnorm = torch.random(1,2) 120 | end 121 | desc_G = '' 122 | model_G = nn.Sequential() 123 | model_G:add(nn.JoinTable(2, 2)) 124 | desc_G = desc_G .. '___JT22' 125 | model_G:add(nn.View(4 * opt.fineSize * opt.fineSize):setNumInputDims(3)) 126 | desc_G = desc_G .. '___V_' .. 4 * opt.fineSize * opt.fineSize 127 | model_G:add(nn.Linear(4 * opt.fineSize * opt.fineSize, nhid)):add(nn.ReLU()) 128 | desc_G = desc_G .. '___L ' .. 4 * opt.fineSize * opt.fineSize .. '_' .. nhid 129 | desc_G = desc_G .. '__R' 130 | if batchnorm == 2 then 131 | model_G:add(nn.BatchNormalization(nhid), nil, nil, true) 132 | desc_G = desc_G .. '__BNA' 133 | end 134 | model_G:add(nn.Dropout(0.5)) 135 | desc_G = desc_G .. '__Drop' .. 0.5 136 | for i=1,nhidlayers do 137 | model_G:add(nn.Linear(nhid, nhid)):add(nn.ReLU()) 138 | desc_G = desc_G .. '___L ' .. nhid .. '_' .. nhid 139 | desc_G = desc_G .. '__R' 140 | if batchnorm == 2 then 141 | model_G:add(nn.BatchNormalization(nhid), nil, nil, true) 142 | desc_G = desc_G .. '__BNA' 143 | end 144 | model_G:add(nn.Dropout(0.5)) 145 | desc_G = desc_G .. '__Drop' .. 0.5 146 | end 147 | model_G:add(nn.Linear(nhid, opt.geometry[1]*opt.geometry[2]*opt.geometry[3])) 148 | desc_G = desc_G .. '___L ' .. nhid .. '_' .. opt.geometry[1]*opt.geometry[2]*opt.geometry[3] 149 | if batchnorm == 2 then 150 | model_G:add(nn.BatchNormalization(opt.geometry[1]*opt.geometry[2]*opt.geometry[3])) 151 | desc_G = desc_G .. '__BNA' 152 | end 153 | model_G:add(nn.View(opt.geometry[1], opt.geometry[2], opt.geometry[3])) 154 | desc_G = desc_G .. '___V_' .. opt.geometry[1] .. '_' .. opt.geometry[2] .. '_' .. opt.geometry[3] 155 | print(desc_G) 156 | print(model_G) 157 | 158 | nhid = nhid / 2 159 | desc_D = '' 160 | model_D = nn.Sequential() 161 | model_D:add(nn.CAddTable()) 162 | desc_D = desc_D .. '___CAdd' 163 | model_D:add(nn.View(opt.geometry[1]* opt.geometry[2]* opt.geometry[3])) 164 | desc_D = desc_D .. '___V_' .. opt.geometry[1]* opt.geometry[2]* opt.geometry[3] 165 | model_D:add(nn.Linear(opt.geometry[1]* opt.geometry[2]* opt.geometry[3], nhid)):add(nn.ReLU()) 166 | desc_D = desc_D .. '___L ' .. opt.geometry[1]* opt.geometry[2]* opt.geometry[3] .. '_' .. nhid 167 | desc_D = desc_D .. '__R' 168 | for i=1,nhidlayers do 169 | model_D:add(nn.Linear(nhid, nhid)):add(nn.ReLU()) 170 | desc_D = desc_D .. '___L ' .. nhid .. '_' .. nhid 171 | desc_D = desc_D .. '__R' 172 | model_D:add(nn.Dropout(0.5)) 173 | desc_D = desc_D .. '__Drop' .. 0.5 174 | end 175 | model_D:add(nn.Linear(nhid, 1)) 176 | desc_D = desc_D .. '___L ' .. nhid .. '_' .. 1 177 | model_D:add(nn.Sigmoid()) 178 | desc_D = desc_D .. '__Sig' 179 | model_D:cuda() 180 | print(desc_D) 181 | print(model_D) 182 | elseif opt.model == 'small_18' then 183 | assert(opt.scratch == 1) -- check that this is not conditional on a previous scale 184 | ---------------------------------------------------------------------- 185 | local input_sz = opt.geometry[1] * opt.geometry[2] * opt.geometry[3] 186 | -- define D network to train 187 | local numhid = 600 188 | model_D = nn.Sequential() 189 | model_D:add(nn.View(input_sz):setNumInputDims(3)) 190 | model_D:add(nn.Linear(input_sz, numhid)) 191 | model_D:add(nn.ReLU()) 192 | model_D:add(nn.Dropout()) 193 | model_D:add(nn.Linear(numhid, numhid)) 194 | model_D:add(nn.ReLU()) 195 | model_D:add(nn.Dropout()) 196 | model_D:add(nn.Linear(numhid, 1)) 197 | model_D:add(nn.Sigmoid()) 198 | ---------------------------------------------------------------------- 199 | local noiseDim = opt.noiseDim[1] * opt.noiseDim[2] * opt.noiseDim[3] 200 | -- define G network to train 201 | local numhid = 600 202 | model_G = nn.Sequential() 203 | model_G:add(nn.View(noiseDim):setNumInputDims(3)) 204 | model_G:add(nn.Linear(noiseDim, numhid)) 205 | model_G:add(nn.ReLU()) 206 | model_G:add(nn.Linear(numhid, numhid)) 207 | model_G:add(nn.ReLU()) 208 | model_G:add(nn.Linear(numhid, input_sz)) 209 | model_G:add(nn.View(opt.geometry[1], opt.geometry[2], opt.geometry[3])) 210 | end 211 | 212 | -- loss function: negative log-likelihood 213 | criterion = nn.BCECriterion() 214 | 215 | model_D:cuda() 216 | model_G:cuda() 217 | criterion:cuda() 218 | 219 | 220 | -- retrieve parameters and gradients 221 | parameters_D,gradParameters_D = model_D:getParameters() 222 | parameters_G,gradParameters_G = model_G:getParameters() 223 | 224 | print('\nNumber of free parameters in D: ' .. freeParams(model_D)) 225 | print('Number of free parameters in G: ' .. freeParams(model_G) .. '\n') 226 | -------------------------------------------------------------------------------- /lsun/modelGenerator/gentest.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | paths.dofile('modelGen.lua') 11 | 12 | opt = opt or {} 13 | opt.geometry = opt.geometry or {3,16,16} 14 | print('\nGenerator') 15 | model_G, desc_G = generateModelG(2,5,64,1024,3,11, 'mixed', 0, 4, 2) 16 | print(desc_G) 17 | print('\nDiscriminator') 18 | model_D, desc_D = generateModelD(2,6,64,1024,3,11, 'mixed', 0, 4, 2) 19 | print(desc_D) 20 | -------------------------------------------------------------------------------- /lsun/modelGenerator/modelGen.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'cudnn' 11 | paths.dofile('../layers/cudnnSpatialConvolutionUpsample.lua') 12 | 13 | -- square kernels. 14 | -- pool size == stride 15 | -- poolType = max,l2,avg,maxout,poolout,mixed 16 | -- poolOrder = 0/1, 0 = pool + upsampling. 1 = upsampling + pooling 17 | -- minHidden/maxHidden - linear layer hidden units 18 | -- dropoutType - 0,1,2. 0 = no dropout, 1 = Dropout, 2 = SpatialDropout + Dropout 19 | -- L1Penalty = 0,1 (enabled or disabled) 20 | -- batchNorm = 0,1,2 (0 = disabled, 1 = before relu, 2 = after relu, 3 = only at the end) 21 | 22 | -- ngroups, 0,4 23 | function generateModelG(minLayers, maxLayers, 24 | minPlanes, maxPlanes, 25 | minKH, maxKH, 26 | poolType, 27 | nGroupsMin, nGroupsMax, 28 | dropoutType, batchNorm) 29 | assert(minPlanes % 16 == 0, 'minPlanes has to be a multiple of 16') 30 | assert(maxPlanes % 16 == 0, 'maxPlanes has to be a multiple of 16') 31 | local bn = torch.random(0, 3) 32 | minPlanes = minPlanes / 16 33 | maxPlanes = maxPlanes / 16 34 | maxKH = math.min(maxKH, opt.geometry[2] - 1) 35 | local desc = '' 36 | if dropoutType > 0 then if torch.random(1,2) == 1 then dropoutType = 0 end end 37 | local nLayers = torch.random(minLayers, maxLayers) 38 | local function nPlanes() 39 | return torch.random(minPlanes, maxPlanes) * 16 -- planes are always a multiple of 16 40 | end 41 | local function kH() 42 | return math.floor((torch.random(minKH, maxKH)/2)) * 2 + 1 -- odd kernel size 43 | end 44 | local function getFactor() 45 | if poolType == 'mixed' or poolType == 'maxout' or poolType == 'poolout' then 46 | return torch.random(1,2) 47 | else 48 | return 1 49 | end 50 | end 51 | 52 | local function pool(model, factor) 53 | if factor > 1 then 54 | local t 55 | if poolType == 'maxout' then t = 1 56 | elseif poolType == 'poolout' then t = 2 57 | elseif poolType == 'mixed' then t = torch.random(1,2) end 58 | if t == 1 then 59 | desc = desc .. '___P_' .. 'MOut_' .. factor 60 | model:add(nn.VolumetricMaxPooling(factor, 1, 1)) 61 | elseif t == 2 then 62 | desc = desc .. '___P_' .. 'LPOut_' .. factor 63 | require 'fbcunn' 64 | model:add(nn.FeatureLPPooling(2,2,2,true)) 65 | end 66 | end 67 | end 68 | 69 | local model = nn.Sequential() 70 | desc = desc .. '___JT22' 71 | model:add(nn.JoinTable(2,2)) 72 | local factor = getFactor() 73 | local planesOut = torch.random(1,5) * 16 74 | local k = kH() 75 | desc = desc .. '___C_' .. 4 .. '_' .. planesOut .. '_g' .. 1 .. '_' .. k .. 'x' .. k 76 | model:add(cudnn.SpatialConvolutionUpsample(3+1, planesOut, k, k, 1, 1)) 77 | if bn == 1 then 78 | desc = desc .. '__BNA' 79 | model:add(nn.SpatialBatchNormalization(planesOut, nil, nil, true)) 80 | end 81 | desc = desc .. '___R' 82 | model:add(cudnn.ReLU(true)) 83 | if bn == 2 then 84 | desc = desc .. '__BN' 85 | model:add(nn.SpatialBatchNormalization(planesOut, nil, nil, false)) 86 | end 87 | 88 | pool(model, factor) 89 | planesOut = planesOut / factor 90 | local planesIn = planesOut 91 | 92 | for i=1,nLayers-2 do 93 | local factor = getFactor() 94 | local planesOut = nPlanes() 95 | local k = kH() 96 | local groups = 13 97 | while planesIn % groups ~= 0 or planesOut % groups ~= 0 do 98 | local pow 99 | if planesOut > 256 or planesIn > 256 then 100 | pow = torch.random(2, nGroupsMax) 101 | else 102 | pow = torch.random(nGroupsMin, nGroupsMax) 103 | end 104 | groups = math.pow(2, pow) 105 | end 106 | desc = desc .. '___C_' .. planesIn .. '_' .. planesOut 107 | .. '_g' .. groups .. '_' .. k .. 'x' .. k 108 | model:add(cudnn.SpatialConvolutionUpsample(planesIn, planesOut, k, k, 1, groups)) 109 | if bn == 1 then 110 | desc = desc .. '__BNA' 111 | model:add(nn.SpatialBatchNormalization(planesOut, nil, nil, true)) 112 | end 113 | desc = desc .. '___R' 114 | model:add(cudnn.ReLU(true)) 115 | if bn == 2 then 116 | desc = desc .. '__BN' 117 | model:add(nn.SpatialBatchNormalization(planesOut, nil, nil, false)) 118 | end 119 | 120 | if dropoutType == 2 then 121 | local dp = torch.random(1,2) 122 | if dp == 2 then 123 | desc = desc .. '___SDrop ' .. 0.5 124 | model:add(nn.SpatialDropout(0.5)) 125 | end 126 | end 127 | pool(model, factor) 128 | planesOut = planesOut / factor 129 | planesIn = planesOut 130 | end 131 | desc = desc .. '___C_' .. planesIn .. '_' .. 3 .. '_g' .. 1 .. '_' .. k .. 'x' .. k 132 | model:add(cudnn.SpatialConvolutionUpsample(planesIn, 3, k, k, 1, 1)) 133 | if bn > 0 then 134 | desc = desc .. '__BNA' 135 | model:add(nn.SpatialBatchNormalization(3, nil, nil, true)) 136 | end 137 | 138 | 139 | model=model:cuda() 140 | local input1 = torch.zeros(10,3,opt.geometry[2],opt.geometry[2]):cuda() 141 | local input2 = torch.zeros(10,1,opt.geometry[2],opt.geometry[2]):cuda() 142 | model:forward({input1, input2}) 143 | return model, desc 144 | end 145 | 146 | function generateModelD(minLayers, maxLayers, 147 | minPlanes, maxPlanes, 148 | minKH, maxKH, 149 | poolType, 150 | nGroupsMin, nGroupsMax, 151 | dropoutType) 152 | assert(minPlanes % 16 == 0, 'minPlanes has to be a multiple of 16') 153 | assert(maxPlanes % 16 == 0, 'maxPlanes has to be a multiple of 16') 154 | minPlanes = minPlanes / 16 155 | maxPlanes = maxPlanes / 16 156 | local iH = opt.geometry[2] 157 | maxKH = math.min(maxKH, opt.geometry[2] - 1) 158 | if torch.random(1,2) == 1 then poolType = 'none' end 159 | local desc = '' 160 | if dropoutType > 0 then if torch.random(1,2) == 1 then dropoutType = 0 end end 161 | local nLayers = torch.random(minLayers, maxLayers) 162 | local function nPlanes() 163 | return torch.random(minPlanes, maxPlanes) * 16 -- planes are always a multiple of 16 164 | end 165 | local function kH() 166 | return math.floor(torch.random(minKH, maxKH)/2) * 2 + 1 -- odd kernel size 167 | end 168 | local function pool(model,planes) 169 | if poolType == 'none' then return end 170 | if torch.random(1,2) == 1 then return end 171 | if math.floor(iH / 2) < 1 then return true end 172 | local t 173 | if poolType == 'max' then t = 3 174 | elseif poolType == 'lp' then t = 4 175 | elseif poolType == 'avg' then t = 5 176 | elseif poolType == 'mixed' then t = torch.random(3,5) end 177 | if t == 3 then 178 | desc = desc .. '___P_' .. 'Max_' .. 2 179 | model:add(cudnn.SpatialMaxPooling(2,2,2,2)) 180 | elseif t == 4 then 181 | desc = desc .. '___P_' .. 'LP_' .. 2 182 | model:add(nn.SpatialLPPooling(planes,2,2,2,2,2)) 183 | elseif t == 5 then 184 | desc = desc .. '___P_' .. 'Avg_' .. 2 185 | model:add(cudnn.SpatialAveragePooling(2,2,2,2)) 186 | end 187 | iH = math.floor(iH / 2) 188 | end 189 | 190 | local model = nn.Sequential() 191 | desc = desc .. '___CAdd' 192 | model:add(nn.CAddTable()) 193 | 194 | local planesOut = torch.random(1,5) * 16 195 | local k = kH() 196 | desc = desc .. '___C_' .. 3 .. '_' .. planesOut .. '_g' .. 1 .. '_' .. k .. 'x' .. k 197 | model:add(cudnn.SpatialConvolution(3, planesOut, k, k, 1, 1)) 198 | desc = desc .. '___R' 199 | model:add(cudnn.ReLU(true)) 200 | iH = iH - k + 1 201 | pool(model, planesOut) 202 | assert(iH >= 1) 203 | 204 | local planesIn = planesOut 205 | 206 | for i=1,nLayers-2 do 207 | local planesOut = nPlanes() 208 | local k = kH() 209 | local groups = 13 210 | while planesIn % groups ~= 0 or planesOut % groups ~= 0 do 211 | local pow 212 | if planesOut > 256 or planesIn > 256 then 213 | pow = torch.random(2, nGroupsMax) 214 | else 215 | pow = torch.random(nGroupsMin, nGroupsMax) 216 | end 217 | groups = math.pow(2, pow) 218 | end 219 | if (iH - k + 1) < 1 then break end 220 | iH = iH - k + 1 221 | assert(iH >= 1) 222 | desc = desc .. '___C_' .. planesIn .. '_' .. planesOut 223 | .. '_g' .. groups .. '_' .. k .. 'x' .. k 224 | model:add(cudnn.SpatialConvolution(planesIn, planesOut, k, k, 1, 1, 0, 0, groups)) 225 | desc = desc .. '___R' 226 | model:add(cudnn.ReLU(true)) 227 | planesIn = planesOut 228 | if pool(model, planesOut) then break end 229 | assert(iH >= 1) 230 | end 231 | desc = desc .. '___V_' .. planesIn * iH * iH 232 | model:add(nn.View(planesIn * iH * iH):setNumInputDims(3)) 233 | desc = desc .. '___L ' .. planesIn * iH * iH .. '_' .. 1 234 | model:add(nn.Linear(planesIn * iH * iH, 1)) 235 | desc = desc .. '___Sig' 236 | model:add(nn.Sigmoid()) 237 | 238 | -- print(model) 239 | model=model:cuda() 240 | local input1 = torch.zeros(13,3,opt.geometry[2],opt.geometry[2]):cuda() 241 | local input2 = torch.zeros(13,3,opt.geometry[2],opt.geometry[2]):cuda() 242 | model:forward({input1, input2}) 243 | return model, desc 244 | end 245 | -------------------------------------------------------------------------------- /lsun/train.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'torch' 11 | require 'optim' 12 | require 'paths' 13 | 14 | local adversarial = {} 15 | 16 | -- reusable buffers 17 | local targets = torch.CudaTensor(opt.batchSize) 18 | local inputs = torch.CudaTensor(opt.batchSize, unpack(opt.geometry)) -- original full-res image - low res image 19 | local cond_inputs = torch.CudaTensor(opt.batchSize, unpack(opt.condDim)) -- low res image blown up and differenced from original 20 | local noise_inputs = torch.CudaTensor(opt.batchSize, unpack(opt.noiseDim)) -- pure noise 21 | local sampleTimer = torch.Timer() 22 | local dataTimer = torch.Timer() 23 | 24 | -- training function 25 | function adversarial.train(inputs_all, inputs_all2) 26 | local dataLoadingTime = dataTimer:time().real; sampleTimer:reset(); -- timers 27 | local err_G, err_D = -1, -1 28 | 29 | -- inputs_all = {diff, label, coarse, fine} 30 | inputs:copy(inputs_all[1]) 31 | cond_inputs:copy(inputs_all[3]) 32 | 33 | -- create closure to evaluate f(X) and df/dX of discriminator 34 | local fevalD = function(x) 35 | collectgarbage() 36 | gradParameters_D:zero() -- reset gradients 37 | 38 | local finputs = {inputs, cond_inputs} 39 | if opt.scratch == 1 then 40 | finputs = inputs 41 | end 42 | 43 | -- forward pass 44 | local outputs = model_D:forward(finputs) 45 | err_D = criterion:forward(outputs, targets) 46 | 47 | -- backward pass 48 | local df_do = criterion:backward(outputs, targets) 49 | model_D:backward(finputs, df_do) 50 | 51 | -- update confusion (add 1 since classes are binary) 52 | outputs[outputs:gt(0.5)] = 2 53 | outputs[outputs:le(0.5)] = 1 54 | confusion:batchAdd(outputs, targets:clone():add(1)) 55 | 56 | return err_D,gradParameters_D 57 | end 58 | ---------------------------------------------------------------------- 59 | -- create closure to evaluate f(X) and df/dX of generator 60 | local fevalG = function(x) 61 | collectgarbage() 62 | gradParameters_G:zero() -- reset gradients 63 | local finputsG = {noise_inputs, cond_inputs} 64 | if opt.scratch == 1 then 65 | finputsG = noise_inputs 66 | end 67 | -- forward pass 68 | local hallucinations = model_G:forward(finputsG) 69 | local finputsD = {hallucinations, cond_inputs} 70 | if opt.scratch == 1 then 71 | finputsD = hallucinations 72 | end 73 | local outputs = model_D:forward(finputsD) 74 | err_G = criterion:forward(outputs, targets) 75 | 76 | -- backward pass 77 | local df_hallucinations = criterion:backward(outputs, targets) 78 | model_D:backward(finputsD, df_hallucinations) 79 | local df_do = model_D.modules[1].gradInput[1] 80 | if opt.scratch == 1 then 81 | df_do = model_D.gradInput 82 | end 83 | model_G:backward(finputsG, df_do) 84 | 85 | return err_G,gradParameters_G 86 | end 87 | ---------------------------------------------------------------------- 88 | -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) 89 | assert (opt.batchSize % 2 == 0) 90 | -- (1.1) Real data is in {inputs, cond_inputs} 91 | targets:fill(1) 92 | -- (1.2) Sampled data 93 | noise_inputs:uniform(-1, 1) 94 | local inps = {noise_inputs, cond_inputs} 95 | if opt.scratch == 1 then -- no scale conditioning if training from scratch 96 | inps = noise_inputs 97 | end 98 | local hallucinations = model_G:forward(inps) 99 | assert(hallucinations:size(1) == opt.batchSize) 100 | assert(hallucinations:size(2) == 3) 101 | assert(hallucinations:nElement() == inputs:nElement()) 102 | -- print(#hallucinations) 103 | -- print(#inputs) 104 | inputs:narrow(1, 1, opt.batchSize / 2):copy(hallucinations:narrow(1, 1, opt.batchSize / 2)) 105 | targets:narrow(1, 1, opt.batchSize / 2):fill(0) 106 | -- evaluate inputs and get the err for G and D separately 107 | local optimizeG = false 108 | local optimizeD = false 109 | local err_R, err_F 110 | do 111 | local margin = 0.3 112 | 113 | local finputs = {inputs, cond_inputs} 114 | if opt.scratch == 1 then 115 | finputs = inputs 116 | end 117 | 118 | local outputs = model_D:forward(finputs) 119 | assert(targets:narrow(1, (opt.batchSize / 2) + 1, opt.batchSize / 2):min() == 1) 120 | err_F = criterion:forward(outputs:narrow(1, 1, opt.batchSize / 2), targets:narrow(1, 1, opt.batchSize / 2)) 121 | err_R = criterion:forward(outputs:narrow(1, (opt.batchSize / 2) + 1, opt.batchSize / 2), targets:narrow(1, (opt.batchSize / 2) + 1, opt.batchSize / 2)) 122 | if err_F > err_R + margin then 123 | optimizeG = false; optimizeD = true; 124 | elseif err_F > err_R and err_F <= err_R + margin then optimizeG = true; optimizeD = false; 125 | elseif err_F <= err_R then 126 | optimizeG = true; optimizeD = false; 127 | end 128 | if err_R > 0.7 then optimizeD = true; end 129 | end 130 | if optimizeD then 131 | optim.sgd(fevalD, parameters_D, sgdState_D) 132 | end 133 | ---------------------------------------------------------------------- 134 | -- (2) Update G network: maximize log(D(G(z))) 135 | noise_inputs:uniform(-1, 1) 136 | targets:fill(1) 137 | cond_inputs:copy(inputs_all2[3]) 138 | if optimizeG then 139 | optim.sgd(fevalG, parameters_G, sgdState_G) 140 | end 141 | batchNumber = batchNumber + 1 142 | cutorch.synchronize(); collectgarbage(); 143 | -- xlua.progress(batchNumber, opt.epochSize) 144 | print(('Epoch: [%d][%d/%d]\tTime %.3f DataTime %.3f Err_G %.4f Err_D %.4f Err_R %.4f Err_F %.4f'):format(epoch, batchNumber, opt.epochSize, sampleTimer:time().real, dataLoadingTime, err_G, err_D, err_R, err_F)) 145 | dataTimer:reset() 146 | end 147 | 148 | -- test function 149 | function adversarial.test(inputs_all) 150 | -- (1) Real data 151 | targets:fill(1) 152 | inputs:copy(inputs_all[1]) 153 | cond_inputs:copy(inputs_all[3]) 154 | local finputs = {inputs, cond_inputs} 155 | if opt.scratch == 1 then 156 | finputs = inputs 157 | end 158 | 159 | local outputs = model_D:forward(finputs) -- get predictions from D 160 | 161 | -- add to confusion matrix 162 | outputs[outputs:gt(0.5)] = 2 163 | outputs[outputs:le(0.5)] = 1 164 | confusion:batchAdd(outputs, targets:clone():add(1)) 165 | ---------------------------------------------------------------------- 166 | -- (2) Generated data 167 | noise_inputs:uniform(-1, 1) 168 | local finputsG = {noise_inputs, cond_inputs} 169 | if opt.scratch == 1 then 170 | finputsG = noise_inputs 171 | end 172 | local samples = model_G:forward(finputsG) 173 | targets:fill(0) 174 | local finputsD = {samples, cond_inputs} 175 | if opt.scratch == 1 then 176 | finputsD = samples 177 | end 178 | local outputs = model_D:forward(finputsD) 179 | end 180 | 181 | return adversarial 182 | --------------------------------------------------------------------------------