├── logs
│   └── README.md
├── images
│   ├── G.png
│   ├── skies.jpg
│   ├── baubles.jpg
│   ├── cat-faces.jpg
│   ├── human-faces.jpg
│   ├── christmas-trees.jpg
│   ├── snowy-landscapes.jpg
│   └── G.xml
├── samples
│   └── README.md
├── .gitignore
├── show_model_content.lua
├── LICENSE
├── weight-init.lua
├── models_rgb.lua
├── sample.lua
├── dataset_rgb.lua
├── README.md
├── train_rgb.lua
├── adversarial_rgb.lua
└── utils
    └── nn_utils.lua

--------------------------------------------------------------------------------
/logs/README.md:
--------------------------------------------------------------------------------
1 | This directory will contain all generated models.
2 | 
--------------------------------------------------------------------------------
/images/G.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/G.png
--------------------------------------------------------------------------------
/images/skies.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/skies.jpg
--------------------------------------------------------------------------------
/images/baubles.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/baubles.jpg
--------------------------------------------------------------------------------
/images/cat-faces.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/cat-faces.jpg
--------------------------------------------------------------------------------
/samples/README.md:
--------------------------------------------------------------------------------
1 | This directory will be used to save samples generated via `sample.lua`.
2 | 
--------------------------------------------------------------------------------
/images/human-faces.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/human-faces.jpg
--------------------------------------------------------------------------------
/images/christmas-trees.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/christmas-trees.jpg
--------------------------------------------------------------------------------
/images/snowy-landscapes.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/snowy-landscapes.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/*.net
2 | logs/*.net.old
3 | logs/images/*.jpg
4 | logs/images/*.png
5 | logs/images_good/*.jpg
6 | logs/images_good/*.png
7 | logs/images_bad/*.jpg
8 | logs/images_bad/*.png
9 | samples/*.jpg
10 | samples/*.png
11 | *_vgg.lua
12 | *_coco.lua
13 | 
14 | # Compiled Lua sources
15 | luac.out
16 | 
17 | # luarocks build files
18 | *.src.rock
19 | *.zip
20 | *.tar.gz
21 | 
22 | # Object files
23 | *.o
24 | *.os
25 | *.ko
26 | *.obj
27 | *.elf
28 | 
29 | # Precompiled Headers
30 | *.gch
31 | *.pch
32 | 
33 | # Libraries
34 | *.lib
35 | *.a
36 | *.la
37 | *.lo
38 | *.def
39 | *.exp
40 | 
41 | # Shared objects (inc. 
Windows DLLs) 42 | *.dll 43 | *.so 44 | *.so.* 45 | *.dylib 46 | 47 | # Executables 48 | *.exe 49 | *.out 50 | *.app 51 | *.i*86 52 | *.x86_64 53 | *.hex 54 | 55 | -------------------------------------------------------------------------------- /show_model_content.lua: -------------------------------------------------------------------------------- 1 | require 'paths' 2 | require 'nn' 3 | require 'cutorch' 4 | require 'cunn' 5 | require 'cudnn' 6 | require 'dpnn' 7 | 8 | OPT = lapp[[ 9 | --save (default "logs") subdirectory in which the model is saved 10 | --network (default "adversarial.net") name of the model file 11 | ]] 12 | 13 | local filepath = paths.concat(OPT.save, OPT.network) 14 | local tmp = torch.load(filepath) 15 | if tmp.epoch then print("") print("Epoch:") print(tmp.epoch) end 16 | if tmp.opt then print("") print("OPT:") print(tmp.opt) end 17 | if tmp.G then print("") print("G:") print(tmp.G) end 18 | if tmp.G1 then print("") print("G1:") print(tmp.G1) end 19 | if tmp.G2 then print("") print("G2:") print(tmp.G2) end 20 | if tmp.G3 then print("") print("G3:") print(tmp.G3) end 21 | if tmp.D then print("") print("D:") print(tmp.D) end 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Alexander Jung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /images/G.xml: -------------------------------------------------------------------------------- 1 | 7Vvfc5s4EP5r/HpjIDjpY51rc3Mz7XQuc3N3jyooWFOBPBgnzv31XaFdAxZusQsEjP1iadHP7/tWSAvMvPt495Cy9eqTCrmcufNwN/N+n7mu4/ge/GnLq7Hc3s2NIUpFiIUKw6P4n6ORim1FyDeVgplSMhPrqjFQScKDrGJjaapeqsWelKz2umYR9VgYHgMmbes/IsxWxnrnLgr7H1xEK+rZWbwzV76y4FuUqm2C/c1c7yn/mcsxo7byiXofAMRUKWhGp+LdPZcaSMLIoPHxyNX9IFOe4EB+XAEu6QrPTG5xnp+V2HBj3WSvNPuXlcj445oFOv8CDM+85SqLJeQcSEaSbTS+c0hvslR922Okrz4JKe+VVCnkE5VAE0smRZRANoBxcrAvcRw8zThJxp5LbsKJPHAV8yx9hSJYwb3zTRVU2AKV81LQdYumVYkpD20MBRLtGy5wgwRCVw8jarsE41IC7bpVoB2GoOHT5WOtMnchoe/l1xRSkU6NEm7nBoE7Dje5eNtw31hw36vkGSwOeCPUnns7YEQnUi63JrX8fCm4E6qIu0OiLgG/56Zt5NHDbOQ9ME8NeVpce0Ee7zMl5D+xHRi+wE1QJNGlQuw3hJio+BWIb4+JewELzuTE/a5Hcd9NUtxe0/WjDXHjbrTmnqnRn5q6vZrdSmfqpiOMjb3rT3DDckN49II9Dveqe4M95XvB3j5d/r3esBgABOnD+n6ZEDddWtpY1R375DnhPYvv4/mkF3Xbp9AJqNv3GkLcirqvx81ygIVU24u67fMmQX+A/EZEsYJqFwp6rzdM+wRqgciT8L0OZxcwlCA0sBFMEGme5z+4wnci+1eD+5uPuf+wDjT4EbBF4E1/PLSC4QcAwpjUNs2ZpPinjWkJMzrHlyEjG/guy8Rztcc6HLGHL0rAWArK6uOQ1IAZJ9Yph8IPmtkPh+Jqh4xmLI14ZjWUk7qfdDOe7fPuOHjOh3nluTHP9qH7RzwHegEUQWOqExhIiWud1WSfRq+ZbSmu/lb0/pSXpgSTMKidzuilLU4XbnyWs46euhuCtHvfpFlOnryfYn4ueYsOPc+OLLS6sBJ/utzY+LNgP/fG2CV/dthiHBsgE28ZyA7IGQPTdpRk0FsgE3IYyo30bF8+2AMd7pVb5NcO0fTE73FHP8+1UahlMbypFtpy9v6kcFr0YsBSwNhXWQr53K5SaCqF0wIcA5YCTqQshfxQf5VCUym8WQykbSngRCq7heKd3asWGrwme1rAZMBaIKgrYhhSdHQEYjgtADNkMdSdI940ljo+MXQc0OlRDHggqogh31BexdBUDHZ06E9o1hIEPOXVz5mPSgDDRk2eHOtnxgI+9nmPF2IRhrqbZd0zawWln2QuxBWU41Ch8vC6jW9ZQLcV0mpeFe3qrQuv9vOKgBlaJ0KAQ68/79VuE1Dns60Q0G1Mpe0Fb1CLW1tPhW/JuzpY3uxXa/56APAbfx12yW5HHzMRDTXrXt1HfGe4HWSL7ywNjcWXq96H7w== -------------------------------------------------------------------------------- /weight-init.lua: -------------------------------------------------------------------------------- 1 | -- Source: 2 | -- https://github.com/e-lab/torch-toolbox/blob/master/Weight-init/weight-init.lua 3 | 4 | -- 5 | -- Different weight initialization methods 6 | -- 7 | -- > model = require('weight-init')(model, 'heuristic') 8 | -- 9 | require("nn") 10 | 11 | 12 | -- "Efficient backprop" 13 | -- Yann Lecun, 1998 14 | local function w_init_heuristic(fan_in, fan_out) 15 | return math.sqrt(1/(3*fan_in)) 16 | end 17 | 18 | 19 | -- "Understanding the difficulty of training deep feedforward neural networks" 20 | -- Xavier Glorot, 2010 21 | local function w_init_xavier(fan_in, fan_out) 22 | return math.sqrt(2/(fan_in + fan_out)) 23 | end 24 | 25 | 26 | -- "Understanding the difficulty of training deep feedforward neural networks" 27 | -- Xavier Glorot, 2010 28 | local function w_init_xavier_caffe(fan_in, fan_out) 29 | return math.sqrt(1/fan_in) 30 | end 31 | 32 | 33 | -- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" 34 | -- Kaiming He, 2015 35 | local function w_init_kaiming(fan_in, fan_out) 36 | return math.sqrt(4/(fan_in + fan_out)) 37 | end 38 | 39 | 40 | local function w_init(net, arg) 41 | -- choose initialization method 42 | local method = nil 43 | if arg == 'heuristic' then method = w_init_heuristic 44 | elseif arg == 'xavier' then method = w_init_xavier 45 | elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe 46 | elseif arg == 'kaiming' then method = w_init_kaiming 47 | else 48 | assert(false) 49 | end 50 | 51 | -- loop over all convolutional modules 52 | for i = 1, #net.modules do 53 | local m = net.modules[i] 54 | if m.__typename == 'nn.SpatialConvolution' then 55 | 
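-- For convolution layers, fan-in = nInputPlane*kH*kW and fan-out = nOutputPlane*kH*kW;
-- the selected method maps these counts to a scale, and reset() re-initializes
-- the layer's weights from a uniform distribution using that scale.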
m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW)) 56 | elseif m.__typename == 'nn.SpatialConvolutionMM' then 57 | m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW)) 58 | elseif m.__typename == 'nn.LateralConvolution' then 59 | m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1)) 60 | elseif m.__typename == 'nn.VerticalConvolution' then 61 | m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW)) 62 | elseif m.__typename == 'nn.HorizontalConvolution' then 63 | m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW)) 64 | elseif m.__typename == 'nn.Linear' then 65 | m:reset(method(m.weight:size(2), m.weight:size(1))) 66 | elseif m.__typename == 'nn.TemporalConvolution' then 67 | m:reset(method(m.weight:size(2), m.weight:size(1))) 68 | end 69 | 70 | if m.bias then 71 | m.bias:zero() 72 | end 73 | end 74 | return net 75 | end 76 | 77 | 78 | return w_init 79 | -------------------------------------------------------------------------------- /models_rgb.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'nn' 3 | require 'dpnn' 4 | require 'cudnn' 5 | 6 | local models = {} 7 | 8 | -- Creates the generator model (G). 9 | -- @param dimensions The dimensions of each image as {channels, height, width}. 10 | -- @param cuda Whether to activate GPU mode for the model. 11 | -- @returns nn.Sequential 12 | function models.create_G(dimensions, cuda) 13 | local model = nn.Sequential() 14 | 15 | model:add(nn.JoinTable(2, 2)) 16 | 17 | if cuda then 18 | model:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor', true, true)) 19 | end 20 | 21 | local inner = nn.Sequential() 22 | local conc = nn.Concat(2) 23 | local left = nn.Sequential() 24 | local right = nn.Sequential() 25 | 26 | left:add(nn.Identity()) 27 | 28 | right:add(cudnn.SpatialConvolution(1+1, 16, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 29 | right:add(nn.SpatialBatchNormalization(16)) 30 | right:add(cudnn.ReLU(true)) 31 | 32 | right:add(cudnn.SpatialConvolution(16, 32, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 33 | right:add(nn.SpatialBatchNormalization(32)) 34 | right:add(cudnn.ReLU(true)) 35 | right:add(nn.SpatialMaxPooling(2, 2)) 36 | 37 | right:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 38 | right:add(nn.SpatialBatchNormalization(64)) 39 | right:add(cudnn.ReLU(true)) 40 | right:add(nn.SpatialMaxPooling(2, 2)) 41 | 42 | right:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 43 | right:add(nn.SpatialBatchNormalization(128)) 44 | right:add(cudnn.ReLU(true)) 45 | 46 | right:add(cudnn.SpatialConvolution(128, 256, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 47 | right:add(nn.SpatialBatchNormalization(256)) 48 | right:add(cudnn.ReLU(true)) 49 | 50 | right:add(nn.SpatialUpSamplingNearest(2)) 51 | right:add(cudnn.SpatialConvolution(256, 128, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 52 | right:add(nn.SpatialBatchNormalization(128)) 53 | right:add(cudnn.ReLU(true)) 54 | 55 | right:add(nn.SpatialUpSamplingNearest(2)) 56 | right:add(cudnn.SpatialConvolution(128, 64, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 57 | right:add(nn.SpatialBatchNormalization(64)) 58 | right:add(cudnn.ReLU(true)) 59 | 60 | conc:add(left) 61 | conc:add(right) 62 | inner:add(conc) 63 | 64 | inner:add(cudnn.SpatialConvolution(2+64, 32, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 65 | inner:add(nn.SpatialBatchNormalization(32)) 66 | inner:add(cudnn.ReLU(true)) 67 | 68 | inner:add(cudnn.SpatialConvolution(32, 3, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 69 | inner:add(nn.Sigmoid()) 70 | 71 | model:add(inner) 72 | 73 | if cuda then 74 | 
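-- Only the inner network runs on the GPU: inputs were copied to CudaTensors
-- further up, and this Copy converts the results back to CPU FloatTensors.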
model:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true)) 75 | inner:cuda() 76 | end 77 | 78 | model = require('weight-init')(model, 'heuristic') 79 | 80 | return model 81 | end 82 | 83 | -- Creates the discriminator model (D). 84 | -- @param dimensions The dimensions of each image as {channels, height, width}. 85 | -- @param cuda Whether to activate GPU mode for the model. 86 | -- @returns nn.Sequential 87 | function models.create_D(dimensions, cuda) 88 | local model = nn.Sequential() 89 | 90 | --model:add(nn.CAddTable()) 91 | model:add(nn.JoinTable(2, 2)) 92 | 93 | if cuda then 94 | model:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor', true, true)) 95 | end 96 | 97 | local inner = nn.Sequential() 98 | 99 | -- 64x64 100 | inner:add(nn.SpatialConvolution(3+1, 64, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 101 | inner:add(cudnn.ReLU(true)) 102 | inner:add(nn.SpatialDropout(0.25)) 103 | inner:add(nn.SpatialAveragePooling(2, 2, 2, 2)) 104 | 105 | -- 32x32 106 | inner:add(nn.SpatialConvolution(64, 128, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 107 | inner:add(cudnn.ReLU(true)) 108 | inner:add(nn.SpatialDropout(0.25)) 109 | 110 | -- 32x32 111 | inner:add(nn.SpatialConvolution(128, 256, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 112 | inner:add(cudnn.ReLU(true)) 113 | inner:add(nn.SpatialDropout(0.25)) 114 | inner:add(nn.SpatialMaxPooling(2, 2)) 115 | 116 | -- 16x16 117 | inner:add(nn.SpatialConvolution(256, 256, 3, 3, 1, 1, (3-1)/2, (3-1)/2)) 118 | inner:add(cudnn.ReLU(true)) 119 | inner:add(nn.SpatialDropout(0.5)) 120 | inner:add(nn.SpatialMaxPooling(2, 2)) 121 | 122 | local height = dimensions[2] * 0.5 * 0.5 * 0.5 123 | local width = dimensions[3] * 0.5 * 0.5 * 0.5 124 | 125 | -- 8x8 126 | inner:add(nn.View(256*height*width)) 127 | inner:add(nn.Linear(256*height*width, 128)) 128 | inner:add(nn.PReLU()) 129 | inner:add(nn.Dropout(0.5)) 130 | inner:add(nn.Linear(128, 1)) 131 | inner:add(nn.Sigmoid()) 132 | 133 | model:add(inner) 134 | 135 | if cuda then 136 | model:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true)) 137 | inner:cuda() 138 | end 139 | 140 | model = require('weight-init')(model, 'heuristic') 141 | 142 | return model 143 | end 144 | 145 | return models 146 | -------------------------------------------------------------------------------- /sample.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'image' 3 | require 'paths' 4 | require 'pl' 5 | require 'cudnn' 6 | NN_UTILS = require 'utils.nn_utils' 7 | DATASET = require 'dataset_rgb' 8 | 9 | OPT = lapp[[ 10 | --save (default "logs") Directory in which the networks are stored. 11 | --network (default "adversarial.net") Filename of the network to use. 12 | --neighbours Whether to search for nearest neighbours of generated images in the dataset (takes long) 13 | --writeto (default "samples") Directory to save the images to 14 | --seed (default 1) Random number seed to use. 15 | --gpu (default 0) GPU to run on 16 | --runs (default 1) How often to sample and save images 17 | --noiseDim (default 100) Noise vector size. 18 | --batchSize (default 16) Sizes of batches. 
19 | --dataset (default "NONE") Directory that contains *.jpg images
20 | ]]
21 | 
22 | if OPT.gpu < 0 then
23 |     print("[ERROR] Sample script currently only runs on GPU, set --gpu=x where x is between 0 and 3.")
24 |     os.exit()
25 | end
26 | 
27 | -- Start GPU mode
28 | print("Starting gpu support...")
29 | require 'cutorch'
30 | require 'cunn'
31 | torch.setdefaulttensortype('torch.FloatTensor')
32 | cutorch.setDevice(OPT.gpu + 1)
33 | 
34 | -- initialize seeds
35 | math.randomseed(OPT.seed)
36 | torch.manualSeed(OPT.seed)
37 | cutorch.manualSeed(OPT.seed)
38 | 
39 | -- Initialize dataset
40 | DATASET.setFileExtension("jpg")
41 | 
42 | -- Main function that runs the sampling
43 | function main()
44 |     -- Load all models
45 |     local G, D, height, width, dataset = loadModels()
46 | 
47 |     -- Image dimensions
48 |     IMG_DIMENSIONS = {3, height, width}
49 |     NOISE_DIM = {1, height, width}
50 | 
51 |     DATASET.setHeight(height)
52 |     DATASET.setWidth(width)
53 |     if OPT.dataset ~= "NONE" or dataset == nil then
54 |         DATASET.setDirs({OPT.dataset})
55 |     else
56 |         DATASET.setDirs({dataset})
57 |     end
58 | 
59 |     print("Sampling...")
60 |     for run=1,OPT.runs do
61 |         -- save 64 randomly selected images from the training set
62 |         local imagesTrainList = DATASET.loadRandomImages(64)
63 |         -- don't use nn_utils.toImageTensor here, because the metatable of imagesTrainList was changed
64 |         local imagesTrain = torch.Tensor(#imagesTrainList, imagesTrainList[1].grayscale:size(1), imagesTrainList[1].grayscale:size(2), imagesTrainList[1].grayscale:size(3))
65 |         for i=1,#imagesTrainList do
66 |             imagesTrain[i] = imagesTrainList[i].grayscale
67 |         end
68 |         image.save(paths.concat(OPT.writeto, string.format('trainset_s1_%04d_base.jpg', run)), toGrid(imagesTrainList, nil, 8))
69 | 
70 |         -- sample 64 colorizations from G
71 |         local noise = torch.Tensor(64, NOISE_DIM[1], NOISE_DIM[2], NOISE_DIM[3])
72 |         noise:uniform(0, 1)
73 |         local imagesGenerated = G:forward({noise, imagesTrain})
74 | 
75 |         -- validate image dimensions
76 |         if imagesGenerated[1]:size(1) ~= IMG_DIMENSIONS[1] or imagesGenerated[1]:size(2) ~= IMG_DIMENSIONS[2] or imagesGenerated[1]:size(3) ~= IMG_DIMENSIONS[3] then
77 |             print("[WARNING] dimension mismatch between images generated by base G and command line parameters")
78 |             print("Dimension G:", imagesGenerated[1]:size())
79 |             print("Settings:", IMG_DIMENSIONS)
80 |         end
81 | 
82 |         -- save one big image of those 64 examples and their colorizations
83 |         image.save(paths.concat(OPT.writeto, string.format('random64_%04d.jpg', run)), toGrid(imagesTrainList, imagesGenerated, 12))
84 | 
85 |         xlua.progress(run, OPT.runs)
86 |     end
87 | 
88 |     print("Finished.")
89 | end
90 | 
91 | -- Converts images to one image grid with a set number of rows.
92 | -- @param imagesOriginal List of images (as returned by DATASET.loadRandomImages()).
93 | -- @param imagesGenerated Tensor of images colorized by G (may be nil). @param nrow Number of rows.
94 | -- @return Tensor
95 | function toGrid(imagesOriginal, imagesGenerated, nrow)
96 |     local N = 2 * imagesOriginal:size()
97 |     if imagesGenerated ~= nil then N = N + imagesGenerated:size(1) end
98 |     local images = torch.Tensor(N, 3, imagesOriginal[1].color:size(2), imagesOriginal[1].color:size(3))
99 |     local idx = 1
100 |     for i=1,#imagesOriginal do
101 |         images[{{idx}, {1}, {}, {}}] = imagesOriginal[i].grayscale
102 |         images[{{idx}, {2}, {}, {}}] = imagesOriginal[i].grayscale
103 |         images[{{idx}, {3}, {}, {}}] = imagesOriginal[i].grayscale
104 |         images[idx+1] = imagesOriginal[i].color
105 |         if imagesGenerated ~= nil then
106 |             images[idx+2] = imagesGenerated[i]
107 |             idx = idx + 1
108 |         end
109 |         idx = idx + 2
110 |     end
111 | 
112 |     return image.toDisplayTensor{input=images, nrow=nrow}
113 | end
114 | 
115 | -- Loads all necessary models/networks and returns them.
116 | -- @returns G, D, height, width, dataset directory
117 | function loadModels()
118 |     local file = torch.load(paths.concat(OPT.save, OPT.network))
119 | 
120 |     local G = file.G
121 |     local D = file.D
122 |     local opt_loaded = file.opt
123 |     G:evaluate()
124 |     D:evaluate()
125 | 
126 |     return G, D, opt_loaded.height, opt_loaded.width, opt_loaded.dataset
127 | end
128 | 
129 | main()
130 | 
--------------------------------------------------------------------------------
/dataset_rgb.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'image'
3 | require 'paths'
4 | 
5 | local dataset = {}
6 | 
7 | -- load data from these directories
8 | dataset.dirs = {}
9 | 
10 | -- load only images with this file extension
11 | dataset.fileExtension = ""
12 | 
13 | -- expected original height/width of images
14 | dataset.originalHeight = 64
15 | dataset.originalWidth = 64
16 | 
17 | -- desired height/width of images
18 | dataset.height = 32
19 | dataset.width = 32
20 | 
21 | --dataset.colorSpace = "rgb"
22 | 
23 | -- cache for filepaths to all images
24 | dataset.paths = nil
25 | 
26 | -- Set directories to load images from
27 | -- @param dirs List of paths to directories
28 | function dataset.setDirs(dirs)
29 |     dataset.dirs = dirs
30 | end
31 | 
32 | -- Set file extension that images to load must have
33 | -- @param fileExtension the file extension of the images
34 | function dataset.setFileExtension(fileExtension)
35 |     dataset.fileExtension = fileExtension
36 | end
37 | 
38 | -- Set the desired height of the images (they will be resized if necessary)
39 | -- @param height The height of the images
40 | function dataset.setHeight(height)
41 |     dataset.height = height
42 | end
43 | 
44 | -- Set the desired width of the images (they will be resized if necessary)
45 | -- @param width The width of the images
46 | function dataset.setWidth(width)
47 |     dataset.width = width
48 | end
49 | 
50 | -- Set desired number of channels for the images (1=grayscale, 3=color)
51 | -- @param nbChannels The number of channels
52 | function dataset.setNbChannels(nbChannels)
53 |     dataset.nbChannels = nbChannels
54 | end
55 | 
56 | -- Loads the paths of all images in the defined directories
57 | -- (with the defined file extension)
58 | function dataset.loadPaths()
59 |     local files = {}
60 |     local dirs = dataset.dirs
61 |     local ext = dataset.fileExtension
62 | 
63 |     for i=1, #dirs do
64 |         local dir = dirs[i]
65 |         -- Go over all files in directory. We use an iterator, paths.files().
66 |         for file in paths.files(dir) do
67 |             -- We only load files that match the extension
68 |             if file:find(ext .. '$') then
69 |                 -- and insert the ones we care about in our table
70 |                 table.insert(files, paths.concat(dir,file))
71 |             end
72 |         end
73 | 
74 |         -- sort for reproducibility
75 |         table.sort(files, function (a,b) return a < b end)
76 | 
77 |         -- Check files
78 |         if #files == 0 then
79 |             error("given directory doesn't contain any files of type: " .. ext)
80 |         end
81 |     end
82 | 
83 |     dataset.paths = files
84 | end
85 | 
86 | -- Load images from the dataset.
87 | -- @param startAt Number of the first image.
88 | -- @param count Count of the images to load.
89 | -- @return Table of images. You can call :size() on that table to get the number of loaded images.
90 | function dataset.loadImages(startAt, count)
91 |     --local endBefore = startAt + count
92 |     if dataset.paths == nil then
93 |         dataset.loadPaths()
94 |     end
95 | 
96 |     local N = math.min(count, #dataset.paths)
97 |     local images = torch.FloatTensor(N, 3, dataset.height, dataset.width)
98 |     for i=0,(N-1) do
99 |         local img = image.load(dataset.paths[startAt + i], dataset.nbChannels, "float")
100 |         img = image.scale(img, dataset.width, dataset.height)
101 |         --print(img, startAt, startAt+i, count, endBefore)
102 |         images[i+1] = img
103 |     end
104 |     images = NN_UTILS.rgbToColorSpace(images, dataset.colorSpace) -- note: dataset.colorSpace is commented out above and therefore nil here
105 | 
106 |     local result = {}
107 |     result.data = images
108 | 
109 |     function result:size()
110 |         return N
111 |     end
112 | 
113 |     setmetatable(result, {
114 |         __index = function(self, index) return self.data[index] end,
115 |         __len = function(self) return self.data:size(1) end
116 |     })
117 | 
118 |     return result
119 | end
120 | 
121 | -- Loads a defined number of randomly selected images from
122 | -- the cached paths (cached in loadPaths()).
123 | -- @param count Number of random images.
124 | -- @return List of Tensors
125 | function dataset.loadRandomImages(count)
126 |     local images = dataset.loadRandomImagesFromPaths(count)
127 |     local data = torch.FloatTensor(#images, 3, dataset.height, dataset.width)
128 |     for i=1, #images do
129 |         --data[i] = image.scale(images[i], dataset.width, dataset.height)
130 |         data[i] = images[i]
131 |     end
132 |     local data_yuv = NN_UTILS.rgbToColorSpace(data, "yuv")
133 | 
134 |     local N = data:size(1)
135 |     local result = {}
136 |     result.color = data
137 |     result.uv = data_yuv[{{}, {2,3}, {}, {}}] -- remove y channel from yuv
138 |     result.grayscale = data_yuv[{{}, {1}, {}, {}}] -- only y channel from yuv
139 | 
140 |     --[[
141 |     image.display(NN_UTILS.switchColorSpace(result.color, "yuv", "rgb")[1])
142 |     image.display(result.uv[1])
143 |     image.display(result.grayscale[1])
144 |     io.read()
145 |     --]]
146 | 
147 |     function result:size()
148 |         return N
149 |     end
150 | 
151 |     setmetatable(result, {
152 |         __index = function(self, index)
153 |             return {
154 |                 color = result.color[index],
155 |                 grayscale = result.grayscale[index],
156 |                 uv = result.uv[index]
157 |             }
158 |         end,
159 |         __len = function(self) return self.color:size(1) end
160 |     })
161 | 
162 |     return result
163 | end
164 | 
165 | -- Loads randomly selected images from the cached paths.
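--
-- Usage sketch for this module (illustrative only; assumes the module was
-- required as DATASET and that utils/nn_utils.lua is available as NN_UTILS,
-- as in train_rgb.lua and sample.lua):
--   DATASET.setDirs({"/path/to/images"})
--   DATASET.setFileExtension("jpg")
--   local batch = DATASET.loadRandomImages(16)
--   -- batch[i] yields {color=..., grayscale=..., uv=...} via the metatable
--   print(batch:size(), batch[1].grayscale:size())
--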
166 | -- TODO: merge with loadRandomImages()
167 | -- @param count Number of images to load
168 | -- @returns List of Tensors
169 | function dataset.loadRandomImagesFromPaths(count)
170 |     if dataset.paths == nil then
171 |         dataset.loadPaths()
172 |     end
173 | 
174 |     local shuffle = torch.randperm(#dataset.paths)
175 | 
176 |     local images = {}
177 |     for i=1,math.min(shuffle:size(1), count) do
178 |         -- load each image
179 |         --table.insert(images, image.load(dataset.paths[shuffle[i]], 3, "float"))
180 |         local fp = dataset.paths[shuffle[i]]
181 |         local img = image.load(fp, 3, "float")
182 |         img = image.scale(img, dataset.width, dataset.height)
183 |         table.insert(images, img)
184 |     end
185 | 
186 |     return images
187 | end
188 | 
189 | return dataset
190 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # About
2 | 
3 | This project uses GANs ([generative adversarial networks](http://papers.nips.cc/paper/5423-generative-adversarial-nets)) to add color to black and white images.
4 | For each such image, the generator network (G) receives its black and white version and outputs a full RGB version of the image (i.e. the black and white image with color added to it).
5 | That RGB version is then rated (with regard to its quality) by the discriminator (D).
6 | The quality measure is backpropagated through D and then through G.
7 | Thereby G can learn to colorize images correctly.
8 | The architectures used are modifications of the [DCGAN](http://arxiv.org/abs/1511.06434) schema.
9 | See [this blog post](http://tinyclouds.org/colorize/) for an alternative version, which uses standard convnets (i.e. no GANs) with VGG features.
10 | 
11 | Key results:
12 | * If a dataset of images can be generated by a GAN, then a GAN can also learn to add colors to it.
13 | * The task of adding colors seems to be a bit easier than the full generation of images.
14 | * G did not learn to add colors to rather rare and small elements (e.g. when coloring images of Christmas trees it didn't add color to presents below the trees, small baubles or clothes of people in the image). This might partly be a limitation of the architecture, which uses pooling layers in G (hence small elements might get lost).
15 | * G did not learn to correctly add colors to datasets with high variance (heterogeneous collections of images). It would resort to mostly just adding one or two colors everywhere.
16 | * I experimented with using VGG features but didn't have much success with those. G didn't seem to learn more than without VGG features. My tests were limited though due to hardware constraints (VGG + G + D = three big networks in memory). I did not try the hypercolumns approach that was used in the [previously mentioned blog post](http://tinyclouds.org/colorize/).
17 | * Producing UV values in G and combining them with Y into a YUV image (which is then fed into D) failed. G just wouldn't learn anything. G had to output full RGB images to learn successfully. Not sure if there was a bug somewhere or if there's a good reason for that effect.
18 | 
19 | # Images
20 | 
21 | Colorizers were trained on multiple image datasets which were reused from previous projects. (I.e. *multiple* GANs were trained, not just one for all images. That's due to GANs not being very good at handling heterogeneous datasets.)
22 | Besides the datasets shown below, the MSCOCO 2014 validation dataset was also used, but G failed to learn much on that one (it mostly just added 1-3 uniform colors per image), hence the results of that run are not shown.
23 | 
24 | Notes:
25 | * There were no splits into training and validation sets (partly due to laziness, partly because GANs in my experience basically never just memorize the training set). Note how the coloring in the images below is often different from the original coloring.
26 | * Training times were usually quite fast (<=2 hours per dataset).
27 | * All generated color images were a little bit blurry, probably because G generated full RGB images instead of just adding color (UV in YUV). As such, it has to learn to copy the Y channel information correctly while still adding colors.
28 | 
29 | ## Human faces
30 | 
31 | This dataset worked fairly well. Notice the image in the 10th row at the far right: G assigns a skin color to the microphone. Also notice how G usually doesn't add red color to the lips. Maybe they get lost during the pooling...?
32 | 
33 | ![Human faces, training set and images colored by G](images/human-faces.jpg?raw=true "Human faces, training set and images colored by G")
34 | 
35 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
36 | 
37 | ## Cat faces
38 | 
39 | This dataset worked fairly well.
40 | 
41 | ![Cat faces, training set and images colored by G](images/cat-faces.jpg?raw=true "Cat faces, training set and images colored by G")
42 | 
43 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
44 | 
45 | ## Skies
46 | 
47 | Here G sometimes created weird mixtures of blue and orange. Those were not visible in earlier epochs, but the earlier epochs instead had weird vertical stripes around the borders of the images.
48 | 
49 | ![Skies, training set and images colored by G](images/skies.jpg?raw=true "Skies, training set and images colored by G")
50 | 
51 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
52 | 
53 | ## Baubles
54 | 
55 | This dataset already caused problems when I tried to generate it (i.e. full image generation, not just colorization). It didn't work too well here either. Baubles often remained colorless. I had to carefully select the optimal epoch to generate half-decent images. There are blue blobs in some of the images; these blobs become bigger if the experiment is run longer.
56 | 
57 | ![Baubles, training set and images colored by G](images/baubles.jpg?raw=true "Baubles, training set and images colored by G")
58 | 
59 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
60 | 
61 | ## Snowy landscapes
62 | 
63 | Here G only had to either keep the black and white image or add some blue color, a fairly easy task. It mostly learned that and sometimes exaggerated the blue (e.g. by adding it to trees).
64 | 
65 | ![Snowy landscapes, training set and images colored by G](images/snowy-landscapes.jpg?raw=true "Snowy landscapes, training set and images colored by G")
66 | 
67 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
68 | 
69 | ## Christmas trees
70 | 
71 | This dataset worked fairly well. When zooming in you can see that G doesn't add color to presents, baubles and people's clothing. E.g. for baubles look at row=9, col=3, for clothes row=8, col=1 and for presents row=1, col=2 (indices starting at 1).
72 | 
73 | ![Christmas trees, training set and images colored by G](images/christmas-trees.jpg?raw=true "Christmas trees, training set and images colored by G")
74 | 
75 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
76 | 
77 | 
78 | # Architecture
79 | 
80 | The architecture of D was a standard convolutional neural net with one small fully connected layer at the end, mostly ReLU activations and some spatial dropout.
81 | 
82 | G is an upsampling generator, similar to what is described in the DCGAN paper. Before the upsampling part it has an image analysis part, similar to a standard convolutional network. The analysis part takes the black and white image and tries to make sense of it with some convolutional and pooling layers. The black and white image is fed into the network a second time towards the end, so that the analysis and upsampling parts can focus on the color and don't have to transfer the Y channel information through the layers. The noise layer is fed into the network both at the start and the end for technical simplicity (it just gets joined with the black and white image).
83 | 
84 | ![Architecture of G](images/G.png?raw=true "Architecture of G")
85 | 
86 | *Architecture of G. The formulation "Conv K, 3x3, relu, BN" denotes a convolutional layer with K kernels/planes, each with filter size 3x3, ReLU activation and batch normalization. Max pooling is over a 2x2 area. Upsampling layers increase height and width each by a factor of 2. The noise layer and the black and white image usually both have size 1x64x64.*
87 | 
88 | D had about 3 million parameters (the exact value depends on the input image size), G about 0.8 million.
89 | 
90 | # Usage
91 | 
92 | Requirements are:
93 | * Torch
94 | * Required packages (most of them should be part of the default torch install, install missing ones with `luarocks install packageName`): `cudnn`, `nn`, `pl`, `paths`, `image`, `optim`, `cutorch`, `cunn`, `dpnn`, `display`
95 | * Image datasets have to be downloaded from previous projects and will likely require Python 2.7. You can however use your own dataset, provided that its images are square or have an aspect ratio of 2.0 (e.g. height=64, width=32). Other ratios might work but haven't been tested.
96 |   * [Human faces](https://github.com/aleju/face-generator)
97 |   * [Cat faces](https://github.com/aleju/cat-generator)
98 |   * [Skies](https://github.com/aleju/sky-generator)
99 |   * [Christmas trees, baubles, snowy landscapes](https://github.com/aleju/christmas-generator)
100 | * NVIDIA GPU with cudnn3 and 4GB or more memory
101 | 
102 | To train a network:
103 | * `~/.display/run.js &` - This will start `display`, which is used to plot results in the browser
104 | * Open http://localhost:8000/ in your browser (`display` interface)
105 | * Open a console in the repository directory and then `th train_rgb.lua --dataset="DATASET_PATH" --height=64 --width=64`, where `DATASET_PATH` is the filepath to the directory containing all your images (must be jpg). `height` and `width` specify the size of the *generated* images; your source images in that directory may be larger (e.g. 256x256). Only 32x32 (height x width), 64x64 and 32x64 were tested; other values might result in errors. Note: Training keeps running until stopped manually with ctrl+c.
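
A trained network can also be loaded and applied to a single image manually. The following is a minimal sketch, not a script from this repository: it assumes a checkpoint saved by `train_rgb.lua` at `logs/adversarial.net`, 64x64 training dimensions and a hypothetical input file `example.jpg`; tensor shapes follow `models_rgb.lua`.

```lua
require 'torch'
require 'image'
require 'nn'
require 'cutorch'
require 'cunn'
require 'cudnn'
require 'dpnn'

torch.setdefaulttensortype('torch.FloatTensor')

-- load the checkpoint written by train_rgb.lua and switch G to evaluation mode
local checkpoint = torch.load('logs/adversarial.net')
local G = checkpoint.G
G:evaluate()

-- grayscale input: 1 image, 1 channel, 64x64, values in [0, 1]
-- (image.load's grayscale conversion approximates the Y channel used in training)
local bw = image.scale(image.load('example.jpg', 1, 'float'), 64, 64):view(1, 1, 64, 64)

-- noise has the same spatial shape as the image (see NOISE_DIM in train_rgb.lua)
local noise = torch.Tensor(1, 1, 64, 64):uniform(0, 1)

-- G takes {noise, grayscale} and outputs RGB images with values in [0, 1]
local rgb = G:forward({noise, bw})
image.save('example_colorized.jpg', rgb[1])
```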
106 | 
107 | To continue a training session use `th train_rgb.lua --dataset="DATASET_PATH" --height=64 --width=64 --network="logs/adversarial.net"`.
108 | To sample images (i.e. colorize images from the training set) use `th sample.lua` (should automatically reuse dataset directory, height and width).
109 | 
--------------------------------------------------------------------------------
/train_rgb.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'image'
3 | require 'pl' -- this is somehow responsible for lapp working in qlua mode
4 | require 'paths'
5 | ok, DISP = pcall(require, 'display')
6 | if not ok then print('display not found. unable to plot') end
7 | ADVERSARIAL = require 'adversarial_rgb'
8 | DATASET = require 'dataset_rgb'
9 | NN_UTILS = require 'utils.nn_utils'
10 | MODELS = require 'models_rgb'
11 | 
12 | ----------------------------------------------------------------------
13 | -- parse command-line options
14 | OPT = lapp[[
15 |     --save              (default "logs")      subdirectory to save logs
16 |     --saveFreq          (default 30)          save every saveFreq epochs
17 |     --network           (default "")          reload pretrained network
18 |     --G_pretrained_dir  (default "logs")      directory that may contain a pretrained G
19 |     --noplot                                  don't plot while training
20 |     --D_sgd_lr          (default 0.02)        D SGD learning rate
21 |     --G_sgd_lr          (default 0.02)        G SGD learning rate
22 |     --D_sgd_momentum    (default 0)           D SGD momentum
23 |     --G_sgd_momentum    (default 0)           G SGD momentum
24 |     --batchSize         (default 32)          batch size
25 |     --N_epoch           (default 30)          Number of batches per epoch
26 |     --G_L1              (default 0)           L1 penalty on the weights of G
27 |     --G_L2              (default 0e-6)        L2 penalty on the weights of G
28 |     --D_L1              (default 0e-7)        L1 penalty on the weights of D
29 |     --D_L2              (default 1e-4)        L2 penalty on the weights of D
30 |     --D_iterations      (default 1)           number of iterations to optimize D for
31 |     --G_iterations      (default 1)           number of iterations to optimize G for
32 |     --D_maxAcc          (default 1.01)        Deactivate learning of D while above this threshold
33 |     --D_clamp           (default 1)           Clamp threshold for D's gradient (+/- N)
34 |     --G_clamp           (default 5)           Clamp threshold for G's gradient (+/- N)
35 |     --D_optmethod       (default "adam")      sgd|adagrad|adadelta|adamax|adam|rmsprop
36 |     --G_optmethod       (default "adam")      sgd|adagrad|adadelta|adamax|adam|rmsprop
37 |     --threads           (default 4)           number of threads
38 |     --gpu               (default 0)           gpu to run on (default cpu)
39 |     --noiseDim          (default 100)         dimensionality of noise vector
40 |     --window            (default 3)           window id of sample image
41 |     --seed              (default 1)           seed for the RNG
42 |     --nopretraining                           Whether to deactivate loading of pretrained networks
43 |     --height            (default 64)          Height of the training images
44 |     --width             (default 64)          Width of the training images
45 |     --dataset           (default "NONE")      Directory that contains *.jpg images
46 | ]]
47 | 
48 | NORMALIZE = false
49 | START_TIME = os.time()
50 | 
51 | if OPT.gpu < 0 or OPT.gpu > 3 then OPT.gpu = false end
52 | print(OPT)
53 | 
54 | -- fix seed
55 | math.randomseed(OPT.seed)
56 | torch.manualSeed(OPT.seed)
57 | 
58 | -- threads
59 | torch.setnumthreads(OPT.threads)
60 | print(' set nb of threads to ' .. torch.getnumthreads())
61 | 
62 | -- possible output of discriminator
63 | CLASSES = {"0", "1"}
64 | Y_GENERATOR = 0
65 | Y_NOT_GENERATOR = 1
66 | 
67 | -- axes of images: 3 channels, height, width
68 | IMG_DIMENSIONS = {3, OPT.height, OPT.width}
69 | COND_DIM = {1, OPT.height, OPT.width}
70 | NOISE_DIM = {1, OPT.height, OPT.width}
71 | 
72 | ----------------------------------------------------------------------
73 | -- get/create dataset
74 | ----------------------------------------------------------------------
75 | assert(OPT.dataset ~= "NONE")
76 | DATASET.setFileExtension("jpg")
77 | DATASET.setHeight(IMG_DIMENSIONS[2])
78 | DATASET.setWidth(IMG_DIMENSIONS[3])
79 | DATASET.setDirs({OPT.dataset})
80 | ----------------------------------------------------------------------
81 | 
82 | -- run on gpu if chosen
83 | -- We have to load all kinds of libraries here, otherwise we risk crashes when loading
84 | -- saved networks afterwards
85 | print(" starting gpu support...")
86 | require 'nn'
87 | require 'cutorch'
88 | require 'cunn'
89 | require 'dpnn'
90 | if OPT.gpu then
91 |     cutorch.setDevice(OPT.gpu + 1)
92 |     cutorch.manualSeed(OPT.seed)
93 |     print(string.format(" using gpu device %d", OPT.gpu))
94 | end
95 | torch.setdefaulttensortype('torch.FloatTensor')
96 | 
97 | function main()
98 |     ----------------------------------------------------------------------
99 |     -- Load / Define network
100 |     ----------------------------------------------------------------------
101 | 
102 |     -- load previous networks (D and G)
103 |     -- or initialize them new
104 |     if OPT.network ~= "" then
105 |         print(string.format(" reloading previously trained network: %s", OPT.network))
106 |         local tmp = torch.load(OPT.network)
107 |         MODEL_D = tmp.D
108 |         MODEL_G = tmp.G
109 |         EPOCH = tmp.epoch + 1
110 |         VIS_NOISE_INPUTS = tmp.vis_noise_inputs
111 |         if NORMALIZE then
112 |             NORMALIZE_MEAN = tmp.normalize_mean
113 |             NORMALIZE_STD = tmp.normalize_std
114 |         end
115 | 
116 |         if OPT.gpu == false then
117 |             MODEL_D:float()
118 |             MODEL_G:float()
119 |         end
120 |     else
121 |         local pt_filename = paths.concat(OPT.save, string.format('pretrained_%dx%dx%d_nd%d.net', IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3], OPT.noiseDim))
122 |         -- pretrained via pretrain_with_previous_net.lua ?
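-- (that script is not part of this repository; if no such file exists,
-- fresh D and G models are created in the else-branch below)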
123 | if not OPT.nopretraining and paths.filep(pt_filename) then 124 | local tmp = torch.load(pt_filename) 125 | MODEL_D = tmp.D 126 | MODEL_G = tmp.G 127 | MODEL_D:training() 128 | MODEL_G:training() 129 | if OPT.gpu == false then 130 | MODEL_D:float() 131 | MODEL_G:float() 132 | end 133 | else 134 | -------------- 135 | -- D 136 | -------------- 137 | MODEL_D = MODELS.create_D(IMG_DIMENSIONS, OPT.gpu ~= false) 138 | 139 | -------------- 140 | -- G 141 | -------------- 142 | local g_pt_filename = paths.concat(OPT.G_pretrained_dir, string.format('g_pretrained_%dx%dx%d_nd%d.net', IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3], OPT.noiseDim)) 143 | if not OPT.nopretraining and paths.filep(g_pt_filename) then 144 | -- Load a pretrained version of G 145 | print(" loading pretrained G...") 146 | local tmp = torch.load(g_pt_filename) 147 | MODEL_G = tmp.G 148 | MODEL_G:training() 149 | if OPT.gpu == false then 150 | MODEL_G:float() 151 | end 152 | else 153 | print(" Note: Did not find pretrained G") 154 | MODEL_G = MODELS.create_G(IMG_DIMENSIONS, OPT.gpu ~= false) 155 | end 156 | end 157 | end 158 | 159 | print(MODEL_G) 160 | print(MODEL_D) 161 | 162 | -- count free parameters in D/G 163 | print(string.format('Number of free parameters in D: %d', NN_UTILS.getNumberOfParameters(MODEL_D))) 164 | print(string.format('Number of free parameters in G: %d', NN_UTILS.getNumberOfParameters(MODEL_G))) 165 | 166 | -- loss function: negative log-likelihood 167 | CRITERION = nn.BCECriterion() 168 | 169 | -- retrieve parameters and gradients 170 | PARAMETERS_D, GRAD_PARAMETERS_D = MODEL_D:getParameters() 171 | PARAMETERS_G, GRAD_PARAMETERS_G = MODEL_G:getParameters() 172 | 173 | -- this matrix records the current confusion across classes 174 | CONFUSION = optim.ConfusionMatrix(CLASSES) 175 | 176 | -- Set optimizer states 177 | OPTSTATE = { 178 | adagrad = { D = {}, G = {} }, 179 | adadelta = { D = {}, G = {} }, 180 | adamax = { D = {}, G = {} }, 181 | adam = { D = {}, G = {} }, 182 | rmsprop = {D = {}, G = {}}, 183 | sgd = { 184 | D = {learningRate = OPT.D_sgd_lr, momentum = OPT.D_sgd_momentum}, 185 | G = {learningRate = OPT.G_sgd_lr, momentum = OPT.G_sgd_momentum} 186 | } 187 | } 188 | 189 | -- Whether to normalize the images. Not used for this project. 190 | if NORMALIZE then 191 | if NORMALIZE_MEAN == nil then 192 | TRAIN_DATA = DATASET.loadRandomImages(10000) 193 | NORMALIZE_MEAN, NORMALIZE_STD = TRAIN_DATA.normalize() 194 | end 195 | end 196 | 197 | if EPOCH == nil then 198 | EPOCH = 1 199 | end 200 | 201 | PLOT_DATA = {} 202 | 203 | -- Noise vectors. Not used for this project. 204 | if VIS_NOISE_INPUTS == nil then 205 | VIS_NOISE_INPUTS = NN_UTILS.createNoiseInputs(100) 206 | end 207 | 208 | -- Example images to use for plotting during training. 209 | EXAMPLE_IMAGES = DATASET.loadRandomImages(48) 210 | 211 | -- training loop 212 | while true do 213 | print('Loading new training data...') 214 | TRAIN_DATA = DATASET.loadRandomImages(OPT.N_epoch * OPT.batchSize) 215 | if NORMALIZE then 216 | TRAIN_DATA.normalize(NORMALIZE_MEAN, NORMALIZE_STD) 217 | end 218 | 219 | -- Show images and plots if requested 220 | if not OPT.noplot then 221 | --visualizeProgress(MODEL_G, MODEL_D, VIS_NOISE_INPUTS, TRAIN_DATA) 222 | visualizeProgressConditional() 223 | end 224 | 225 | -- Train D and G 226 | -- ... 
but train D only while having an accuracy below OPT.D_maxAcc 227 | -- over the last math.max(20, math.min(1000/OPT.batchSize, 250)) batches 228 | ADVERSARIAL.train(TRAIN_DATA, OPT.D_maxAcc, math.max(20, math.min(1000/OPT.batchSize, 250))) 229 | 230 | -- Save current net 231 | if EPOCH % OPT.saveFreq == 0 then 232 | local filename = paths.concat(OPT.save, 'adversarial.net') 233 | saveAs(filename) 234 | end 235 | 236 | EPOCH = EPOCH + 1 237 | end 238 | end 239 | 240 | -- Save the current models G and D to a file. 241 | -- @param filename The path to the file 242 | function saveAs(filename) 243 | os.execute(string.format("mkdir -p %s", sys.dirname(filename))) 244 | if paths.filep(filename) then 245 | os.execute(string.format("mv %s %s.old", filename, filename)) 246 | end 247 | print(string.format(" saving network to %s", filename)) 248 | NN_UTILS.prepareNetworkForSave(MODEL_G) 249 | NN_UTILS.prepareNetworkForSave(MODEL_D) 250 | torch.save(filename, {D = MODEL_D, G = MODEL_G, opt = OPT, plot_data = PLOT_DATA, epoch = EPOCH, vis_noise_inputs = VIS_NOISE_INPUTS, normalize_mean=NORMALIZE_MEAN, normalize_std=NORMALIZE_STD}) 251 | end 252 | 253 | -- Get examples to plot. 254 | -- Returns a list of the pattern 255 | -- [i] Image, black and white. 256 | -- [i+1] Image, color. 257 | -- [i+2] Image, color added by G. 258 | function getSamples() 259 | local N = EXAMPLE_IMAGES:size() 260 | local ds = EXAMPLE_IMAGES 261 | local noiseInputs = torch.Tensor(N, NOISE_DIM[1], NOISE_DIM[2], NOISE_DIM[3]) 262 | local condInputs = torch.Tensor(N, COND_DIM[1], COND_DIM[2], COND_DIM[3]) 263 | local gt = torch.Tensor(N, IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3]) 264 | 265 | -- Generate samples 266 | noiseInputs:uniform(0, 1) 267 | for i=1,N do 268 | --local idx = math.random(ds:size()) 269 | local idx = i 270 | local example = ds[idx] 271 | condInputs[i] = example.grayscale:clone() 272 | gt[i] = example.color:clone() 273 | end 274 | local samples = MODEL_G:forward({noiseInputs, condInputs}) 275 | 276 | local to_plot = {} 277 | for i=1,N do 278 | --local withColor = torch.cat(condInputs[i]:float(), samples[i]:float(), 1) 279 | local withColor = samples[i]:clone() 280 | to_plot[#to_plot+1] = NN_UTILS.switchColorSpaceSingle(condInputs[i]:float(), "y", "rgb") 281 | to_plot[#to_plot+1] = gt[i]:float() 282 | to_plot[#to_plot+1] = withColor 283 | end 284 | return to_plot 285 | end 286 | 287 | -- Updates the display plot. 
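-- Each example from getSamples() contributes three images to the window:
-- the grayscale input, the ground truth color image and G's colorization.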
288 | function visualizeProgressConditional() 289 | -- Show images and their refinements for the validation and training set 290 | --local toPlotVal = getSamples(VAL_DATA, 20) 291 | local toPlotTrain = getSamples() 292 | --DISP.image(toPlotVal, {win=OPT.window, width=2*10*IMG_DIMENSIONS[3], title=string.format("[VAL] Coarse, GT, G img, GT diff, G diff (%s epoch %d)", OPT.save, EPOCH)}) 293 | DISP.image(toPlotTrain, {win=OPT.window+1, width=14*IMG_DIMENSIONS[3], title=string.format("[TRAIN] original grayscale, original color, auto-colorized (%s epoch %d)", OPT.save, EPOCH)}) 294 | end 295 | 296 | main() 297 | -------------------------------------------------------------------------------- /adversarial_rgb.lua: -------------------------------------------------------------------------------- 1 | require 'torch' 2 | require 'optim' 3 | require 'pl' 4 | require 'image' 5 | 6 | local adversarial = {} 7 | 8 | -- this variable will save the accuracy values of D 9 | adversarial.accs = {} 10 | 11 | -- function to calculate the mean of a list of numbers 12 | function adversarial.mean(t) 13 | local sum = 0 14 | local count = 0 15 | 16 | for k,v in pairs(t) do 17 | if type(v) == 'number' then 18 | sum = sum + v 19 | count = count + 1 20 | end 21 | end 22 | 23 | return (sum / count) 24 | end 25 | 26 | -- main training function 27 | function adversarial.train(trainData, maxAccuracyD, accsInterval) 28 | EPOCH = EPOCH or 1 29 | local N_epoch = OPT.N_epoch 30 | if N_epoch <= 0 then 31 | N_epoch = 100 32 | end 33 | local dataBatchSize = OPT.batchSize / 2 -- size of a half-batch for D or G 34 | local time = sys.clock() 35 | 36 | -- variables to track D's accuracy and adjust learning rates 37 | local lastAccuracyD = 0.0 38 | local doTrainD = true 39 | local countTrainedD = 0 40 | local countNotTrainedD = 0 41 | 42 | samples = nil 43 | local batchIdx = 0 44 | 45 | -- do one epoch 46 | -- While this function is structured like one that picks example batches in consecutive order, 47 | -- in reality the examples (per batch) will be picked randomly 48 | print(string.format(" Epoch #%d [batchSize = %d]", EPOCH, OPT.batchSize)) 49 | for batchIdx=1,N_epoch do 50 | -- size of this batch, will usually be dataBatchSize but can be lower at the end 51 | --local thisBatchSize = math.min(OPT.batchSize, N_epoch - t + 1) 52 | 53 | -- Inputs for D, either original or generated images 54 | local inputs = torch.Tensor(OPT.batchSize, 3, IMG_DIMENSIONS[2], IMG_DIMENSIONS[3]) 55 | 56 | -- target y-values 57 | local targets = torch.Tensor(OPT.batchSize) 58 | 59 | -- tensor to use for noise for G 60 | --local noiseInputs = torch.Tensor(thisBatchSize, OPT.noiseDim) 61 | local noiseInputs = torch.Tensor(OPT.batchSize, NOISE_DIM[1], NOISE_DIM[2], NOISE_DIM[3]) 62 | local condInputs = torch.Tensor(OPT.batchSize, COND_DIM[1], COND_DIM[2], COND_DIM[3]) 63 | 64 | ---------------------------------------------------------------------- 65 | -- create closure to evaluate f(X) and df/dX of D 66 | local fevalD = function(x) 67 | collectgarbage() 68 | local confusion_batch_D = optim.ConfusionMatrix(CLASSES) 69 | confusion_batch_D:zero() 70 | 71 | if x ~= PARAMETERS_D then -- get new parameters 72 | PARAMETERS_D:copy(x) 73 | end 74 | 75 | GRAD_PARAMETERS_D:zero() -- reset gradients 76 | 77 | -- forward pass 78 | -- condInputs = y, inputs = uv 79 | local outputs = MODEL_D:forward({condInputs, inputs}) 80 | local f = CRITERION:forward(outputs, targets) 81 | 82 | -- backward pass 83 | local df_do = CRITERION:backward(outputs, targets) 84 | 
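-- accumulate gradients w.r.t. D's parameters (the gradient w.r.t. D's inputs
-- is not needed while only D is updated); note that in this RGB variant
-- "inputs" holds full RGB images, not UV channels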
MODEL_D:backward({condInputs, inputs}, df_do)
85 | 
86 |             -- penalties (L1 and L2):
87 |             if OPT.D_L1 ~= 0 or OPT.D_L2 ~= 0 then
88 |                 -- Loss:
89 |                 f = f + OPT.D_L1 * torch.norm(PARAMETERS_D, 1)
90 |                 f = f + OPT.D_L2 * torch.norm(PARAMETERS_D, 2)^2/2
91 |                 -- Gradients:
92 |                 GRAD_PARAMETERS_D:add(torch.sign(PARAMETERS_D):mul(OPT.D_L1) + PARAMETERS_D:clone():mul(OPT.D_L2))
93 |             end
94 | 
95 |             -- update confusion (add 1 since targets are binary)
96 |             for i=1,OPT.batchSize do
97 |                 local c
98 |                 if outputs[i][1] > 0.5 then c = 2 else c = 1 end
99 |                 CONFUSION:add(c, targets[i]+1)
100 |                 confusion_batch_D:add(c, targets[i]+1)
101 |             end
102 | 
103 |             -- Clamp D's gradients
104 |             -- This helps a bit against D suddenly giving up (only outputting y=1 or y=0)
105 |             if OPT.D_clamp ~= 0 then
106 |                 GRAD_PARAMETERS_D:clamp((-1)*OPT.D_clamp, OPT.D_clamp)
107 |             end
108 | 
109 |             -- Calculate accuracy of D on this batch
110 |             confusion_batch_D:updateValids()
111 |             local tV = confusion_batch_D.totalValid
112 | 
113 |             -- Add this batch's accuracy to the history of D's accuracies.
114 |             -- Also, keep that history at a fixed size.
115 |             adversarial.accs[#adversarial.accs+1] = tV
116 |             if #adversarial.accs > accsInterval then
117 |                 table.remove(adversarial.accs, 1)
118 |             end
119 | 
120 |             -- Mean accuracy of D over the last couple of batches
121 |             local accAvg = adversarial.mean(adversarial.accs)
122 | 
123 |             -- We will only train D if its mean accuracy over the last couple of batches
124 |             -- was below the defined maximum (maxAccuracyD). This protects a bit against
125 |             -- G generating garbage.
126 |             doTrainD = (accAvg < maxAccuracyD)
127 |             lastAccuracyD = tV
128 |             if doTrainD then
129 |                 countTrainedD = countTrainedD + 1
130 |                 return f,GRAD_PARAMETERS_D
131 |             else
132 |                 countNotTrainedD = countNotTrainedD + 1
133 | 
134 |                 -- The interruptable* Optimizers don't train when false is returned
135 |                 -- Maybe that would be equivalent to just returning 0 for all gradients?
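-- (note: the stock optim.* routines dispatched below expect a gradient
-- tensor from this closure, so the (false, false) return relies on those
-- patched optimizer variants)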
136 |                 return false,false
137 |             end
138 |         end
139 | 
140 |         ----------------------------------------------------------------------
141 |         -- create closure to evaluate f(X) and df/dX of generator
142 |         local fevalG_on_D = function(x)
143 |             collectgarbage()
144 |             if x ~= PARAMETERS_G then -- get new parameters
145 |                 PARAMETERS_G:copy(x)
146 |             end
147 | 
148 |             GRAD_PARAMETERS_G:zero() -- reset gradients
149 | 
150 |             -- forward pass
151 |             --local samples = NN_UTILS.createImagesFromNoise(noiseInputs, false, true)
152 |             local samples = MODEL_G:forward({noiseInputs, condInputs})
153 |             -- condInputs = grayscale (y), samples = RGB colorizations
154 |             local outputs = MODEL_D:forward({condInputs, samples})
155 |             local f = CRITERION:forward(outputs, targets)
156 | 
157 |             -- backward pass
158 |             local df_samples = CRITERION:backward(outputs, targets)
159 |             MODEL_D:backward({condInputs, samples}, df_samples)
160 |             local df_do = MODEL_D.modules[1].gradInput[2] -- 1=grad of condInputs (grayscale), 2=grad of samples (RGB)
161 |             MODEL_G:backward({noiseInputs, condInputs}, df_do)
162 | 
163 |             -- penalties (L1 and L2):
164 |             if OPT.G_L1 ~= 0 or OPT.G_L2 ~= 0 then
165 |                 -- Loss:
166 |                 f = f + OPT.G_L1 * torch.norm(PARAMETERS_G, 1)
167 |                 f = f + OPT.G_L2 * torch.norm(PARAMETERS_G, 2)^2/2
168 |                 -- Gradients:
169 |                 GRAD_PARAMETERS_G:add(torch.sign(PARAMETERS_G):mul(OPT.G_L1) + PARAMETERS_G:clone():mul(OPT.G_L2))
170 |             end
171 | 
172 |             -- clamp G's gradient to the range of -OPT.G_clamp to +OPT.G_clamp
173 |             if OPT.G_clamp ~= 0 then
174 |                 GRAD_PARAMETERS_G:clamp((-1)*OPT.G_clamp, OPT.G_clamp)
175 |             end
176 | 
177 |             return f,GRAD_PARAMETERS_G
178 |         end
179 |         ------------------- end of eval functions ---------------------------
180 | 
181 |         ----------------------------------------------------------------------
182 |         -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
183 |         -- Get half a minibatch of real data, half generated data
184 |         for k=1, OPT.D_iterations do
185 |             -- (1.1) Real data
186 |             local inputIdx = 1
187 |             local realDataSize = OPT.batchSize / 2
188 |             for i = 1, realDataSize do
189 |                 local randomIdx = math.random(trainData:size())
190 |                 local trainingExample = trainData[randomIdx]
191 |                 --inputs[inputIdx] = trainingExample.uv:clone()
192 |                 inputs[inputIdx] = trainingExample.color:clone()
193 |                 condInputs[inputIdx] = trainingExample.grayscale:clone()
194 |                 targets[inputIdx] = Y_NOT_GENERATOR
195 |                 inputIdx = inputIdx + 1
196 |             end
197 | 
198 |             -- (1.2) Sampled data
199 |             noiseInputs:uniform(0, 1)
200 |             for i = 1, realDataSize do
201 |                 local randomIdx = math.random(trainData:size())
202 |                 local trainingExample = trainData[randomIdx]
203 |                 condInputs[inputIdx] = trainingExample.grayscale:clone()
204 |                 inputIdx = inputIdx + 1
205 |             end
206 |             inputIdx = inputIdx - realDataSize
207 | 
208 |             local generatedRGB = MODEL_G:forward({
209 |                 noiseInputs[{{realDataSize+1,2*realDataSize}}],
210 |                 condInputs[{{realDataSize+1,2*realDataSize}}]
211 |             })
212 |             for i=1, realDataSize do
213 |                 inputs[inputIdx] = generatedRGB[i]:clone()
214 |                 targets[inputIdx] = Y_GENERATOR
215 |                 inputIdx = inputIdx + 1
216 |             end
217 | 
218 |             if OPT.D_optmethod == "sgd" then
219 |                 optim.sgd(fevalD, PARAMETERS_D, OPTSTATE.sgd.D)
220 |             elseif OPT.D_optmethod == "adagrad" then
221 |                 optim.adagrad(fevalD, PARAMETERS_D, OPTSTATE.adagrad.D)
222 |             elseif OPT.D_optmethod == "adadelta" then
223 |                 optim.adadelta(fevalD, PARAMETERS_D, OPTSTATE.adadelta.D)
224 |             elseif OPT.D_optmethod == "adamax" then
225 |                 optim.adamax(fevalD, PARAMETERS_D, OPTSTATE.adamax.D)
226 |             elseif OPT.D_optmethod == "adam" then
227 |                 optim.adam(fevalD, PARAMETERS_D, OPTSTATE.adam.D)
228 |             elseif OPT.D_optmethod == "rmsprop" then
229 |                 optim.rmsprop(fevalD, PARAMETERS_D, OPTSTATE.rmsprop.D)
230 |             else
231 |                 print("[Warning] Unknown optimizer method chosen for D.")
232 |             end
233 |         end
234 | 
235 |         ----------------------------------------------------------------------
236 |         -- (2) Update G network: maximize log(D(G(z)))
237 |         for k=1, OPT.G_iterations do
238 |             noiseInputs:uniform(0, 1)
239 |             targets:fill(Y_NOT_GENERATOR)
240 |             for i=1,OPT.batchSize do
241 |                 local randomIdx = math.random(trainData:size())
242 |                 local trainingExample = trainData[randomIdx]
243 |                 condInputs[i] = trainingExample.grayscale:clone()
244 |             end
245 | 
246 |             if OPT.G_optmethod == "sgd" then
247 |                 optim.sgd(fevalG_on_D, PARAMETERS_G, OPTSTATE.sgd.G)
248 |             elseif OPT.G_optmethod == "adagrad" then
249 |                 optim.adagrad(fevalG_on_D, PARAMETERS_G, OPTSTATE.adagrad.G)
250 |             elseif OPT.G_optmethod == "adadelta" then
251 |                 optim.adadelta(fevalG_on_D, PARAMETERS_G, OPTSTATE.adadelta.G)
252 |             elseif OPT.G_optmethod == "adamax" then
253 |                 optim.adamax(fevalG_on_D, PARAMETERS_G, OPTSTATE.adamax.G)
254 |             elseif OPT.G_optmethod == "adam" then
255 |                 optim.adam(fevalG_on_D, PARAMETERS_G, OPTSTATE.adam.G)
256 |             elseif OPT.G_optmethod == "rmsprop" then
257 |                 optim.rmsprop(fevalG_on_D, PARAMETERS_G, OPTSTATE.rmsprop.G)
258 |             else
259 |                 print("[Warning] Unknown optimizer method chosen for G.")
260 |             end
261 |         end
262 | 
263 |         batchIdx = batchIdx + 1
264 |         -- display progress
265 |         xlua.progress(batchIdx * OPT.batchSize, N_epoch * OPT.batchSize)
266 |     end
267 | 
268 |     -- time taken
269 |     time = sys.clock() - time
270 |     if maxAccuracyD < 1.0 then
271 |         print(string.format(" trained D %d of %d times.", countTrainedD, countTrainedD + countNotTrainedD))
272 |     end
273 | 
274 |     -- print confusion matrix
275 |     print("Confusion of D:")
276 |     print(CONFUSION)
277 |     local tV = CONFUSION.totalValid
278 |     CONFUSION:zero()
279 | 
280 |     return tV
281 | end
282 | 
283 | -- Show the activity of a network in windows (i.e. windows full of blinking dots).
284 | -- The windows will automatically be reused.
285 | -- Only the activity of the layer types nn.SpatialConvolution and nn.Linear will be shown.
286 | -- Linear layers must have a minimum size to be shown (i.e. to not show the tiny output layers).
287 | --
288 | -- NOTE: This function can only visualize one network properly while the program runs.
289 | -- I.e. you can't call this function to show network A and then another time to show network B,
290 | -- because the function tries to reuse windows and that will not work correctly in such a case.
291 | --
292 | -- NOTE: Old function, probably doesn't work anymore.
293 | --
294 | -- @param net The network to visualize.
295 | -- @param minOutputs Minimum (output) size of a linear layer to be shown.
296 | function adversarial.visualizeNetwork(net, minOutputs)
297 |     if minOutputs == nil then
298 |         minOutputs = 150
299 |     end
300 | 
301 |     -- (Global) Table to save the window ids in, so that we can reuse them between calls.
302 |     netvis_windows = netvis_windows or {}
303 | 
304 |     local modules = net:listModules()
305 |     local winIdx = 1
306 |     -- last module seems to have no output?
345 | 
346 | return adversarial
347 | 
--------------------------------------------------------------------------------
/utils/nn_utils.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | 
3 | local nn_utils = {}
4 | 
5 | -- Sets the weights of a layer to random gaussian values, scaled by `range`.
6 | -- @param weights The weight tensor to change, e.g. mlp.modules[1].weight.
7 | -- @param range Scale (roughly a standard deviation) for the new values (single number, e.g. 0.005)
8 | function nn_utils.setWeights(weights, range)
9 |     weights:randn(weights:size())
10 |     weights:mul(range)
11 | end
12 | 
13 | -- Initializes all weights of a multi-layer network.
14 | -- @param model The nn.Sequential() model with one or more layers
15 | -- @param rangeWeights A range for the new weights values (single number, e.g. 0.005)
16 | -- @param rangeBias A range for the new bias values (single number, e.g. 0.005)
17 | function nn_utils.initializeWeights(model, rangeWeights, rangeBias)
18 |     rangeWeights = rangeWeights or 0.005
19 |     rangeBias = rangeBias or 0.001
20 | 
21 |     for m = 1, #model.modules do
22 |         if model.modules[m].weight then
23 |             nn_utils.setWeights(model.modules[m].weight, rangeWeights)
24 |         end
25 |         if model.modules[m].bias then
26 |             nn_utils.setWeights(model.modules[m].bias, rangeBias)
27 |         end
28 |     end
29 | end
30 | 
31 | function nn_utils.forwardBatched(model, input, batchSize)
32 |     local N
33 |     if input.size then
34 |         N = input:size(1)
35 |     else
36 |         N = #input
37 |     end
38 | 
39 |     local output
40 |     local nBatches = math.ceil(N/batchSize)
41 |     for i=1,nBatches do
42 |         local batchStart = 1 + (i-1) * batchSize
43 |         local batchEnd = math.min(i*batchSize, N)
44 |         local forwarded = model:forward(input[{{batchStart, batchEnd}}]):clone()
45 |         if output == nil then
46 |             local sizes = forwarded:size()
47 |             sizes[1] = N
48 |             output = torch.Tensor():resize(sizes)
49 |         end
50 |         output[{{batchStart, batchEnd}, {}, {}, {}}] = forwarded
51 |     end
52 | 
53 |     return output
54 | end
55 | 
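-- In short, forwardBatched (above) pushes `input` through `model` in chunks of
-- `batchSize` and reassembles the outputs into one tensor. A minimal usage sketch
-- (`model` and `images` are placeholder names for a network and a 4D input tensor):
--[[
local preds = nn_utils.forwardBatched(model, images, 32)
print(preds:size())
--]]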
56 | -- Creates a tensor of N vectors, each of dimension OPT.noiseDim with random values
57 | -- between -1 and +1.
58 | -- @param N Number of vectors to generate
59 | -- @returns Tensor of shape (N, OPT.noiseDim)
60 | function nn_utils.createNoiseInputs(N)
61 |     local noiseInputs = torch.Tensor(N, OPT.noiseDim)
62 |     noiseInputs:uniform(-1.0, 1.0)
63 |     return noiseInputs
64 | end
65 | 
66 | -- Feeds noise vectors into G or AE+G and returns the result.
67 | -- @param noiseInputs Tensor from createNoiseInputs()
68 | -- @param outputAsList Whether to return the images as one list or as a tensor.
69 | -- @returns Either a list of images (as returned by G/AE) or a tensor of images
70 | function nn_utils.createImagesFromNoise(noiseInputs, outputAsList)
71 |     local images
72 |     local N = noiseInputs:size(1)
73 |     local nBatches = math.ceil(N/OPT.batchSize)
74 |     for i=1,nBatches do
75 |         local batchStart = 1 + (i-1)*OPT.batchSize
76 |         local batchEnd = math.min(i*OPT.batchSize, N)
77 |         local generated = MODEL_G:forward(noiseInputs[{{batchStart, batchEnd}}]):clone()
78 |         if images == nil then
79 |             local img = generated[1]
80 |             images = torch.Tensor(N, img:size(1), img:size(2), img:size(3))
81 |         end
82 |         images[{{batchStart, batchEnd}, {}, {}, {}}] = generated
83 |     end
84 | 
85 |     if outputAsList then
86 |         local imagesList = {}
87 |         for i=1, images:size(1) do
88 |             imagesList[#imagesList+1] = images[i]:float()
89 |         end
90 |         return imagesList
91 |     else
92 |         return images
93 |     end
94 | end
95 | 
96 | -- Creates new random images with G or AE+G.
97 | -- @param N Number of images to create.
98 | -- @param outputAsList Whether to return the images as one list or as a tensor.
99 | -- @returns Either a list of images (as returned by G/AE) or a tensor of images
100 | function nn_utils.createImages(N, outputAsList)
101 |     return nn_utils.createImagesFromNoise(nn_utils.createNoiseInputs(N), outputAsList)
102 | end
103 | 
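-- A minimal sketch tying the generators above to the sorting helper below (a sketch
-- only; it assumes trained MODEL_G/MODEL_D plus the OPT globals, and 64/16 are
-- arbitrary counts):
--[[
local images = nn_utils.createImages(64, false)                         -- (64, C, H, W)
local best, preds = nn_utils.sortImagesByPrediction(images, false, 16)  -- 16 most "real"
--]]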
104 | -- Sorts images based on D's certainty that they are fake/real.
105 | -- Descending order starts at y=1 (Y_NOT_GENERATOR) and ends with y=0 (Y_GENERATOR).
106 | -- Therefore, in the case of descending order, images for which D is very certain that they are real
107 | -- come first, and images that seem to be fake (according to D) come last.
108 | -- @param images Tensor of the images to sort.
109 | -- @param ascending If true, then images that seem most fake to D are placed at the start of the list.
110 | -- Otherwise the list starts with probably real images.
111 | -- @param nbMaxOut Sets how many images may be returned at most (can't be more images than provided).
112 | -- @return Tuple (list of images, list of predictions between 0.0 and 1.0)
113 | -- where 1.0 means "probably real"
114 | function nn_utils.sortImagesByPrediction(images, ascending, nbMaxOut)
115 |     local predictions = torch.Tensor(images:size(1), 1)
116 |     local nBatches = math.ceil(images:size(1)/OPT.batchSize)
117 |     for i=1,nBatches do
118 |         local batchStart = 1 + (i-1)*OPT.batchSize
119 |         local batchEnd = math.min(i*OPT.batchSize, images:size(1))
120 |         predictions[{{batchStart, batchEnd}, {1}}] = MODEL_D:forward(images[{{batchStart, batchEnd}, {}, {}, {}}]):clone()
121 |     end
122 | 
123 |     local imagesWithPreds = {}
124 |     for i=1,images:size(1) do
125 |         table.insert(imagesWithPreds, {images[i], predictions[i][1]})
126 |     end
127 | 
128 |     if ascending then
129 |         table.sort(imagesWithPreds, function (a,b) return a[2] < b[2] end)
130 |     else
131 |         table.sort(imagesWithPreds, function (a,b) return a[2] > b[2] end)
132 |     end
133 | 
134 |     local resultImages = {}
135 |     local resultPredictions = {}
136 |     for i=1,math.min(nbMaxOut,#imagesWithPreds) do
137 |         resultImages[i] = imagesWithPreds[i][1]
138 |         resultPredictions[i] = imagesWithPreds[i][2]
139 |     end
140 | 
141 |     return resultImages, resultPredictions
142 | end
143 | 
144 | function nn_utils.switchColorSpace(images, from, to)
145 |     images = nn_utils.toRgb(images, from)
146 |     images = nn_utils.rgbToColorSpace(images, to)
147 |     return images
148 | end
149 | 
150 | function nn_utils.switchColorSpaceSingle(image, from, to)
151 |     local images = nn_utils.toBatch(image)
152 |     images = nn_utils.toRgb(images, from)
153 |     images = nn_utils.rgbToColorSpace(images, to)
154 |     return images[1]
155 | end
156 | 
157 | function nn_utils.toRgb(images, from)
158 |     local images = nn_utils.toImageTensor(images)
159 |     if from == "rgb" then
160 |         return images
161 |     elseif from == "y" then
162 |         --[[
163 |         local imagesTmp
164 |         if images:size(4) == nil then
165 |             imagesTmp = images:clone()
166 |         else
167 |             imagesTmp = images:clone():squeeze(2)
168 |         end
169 | 
170 |         local N = imagesTmp:size(1)
171 |         local height = imagesTmp:size(2)
172 |         local width = imagesTmp:size(3)
173 |         --]]
174 |         return torch.repeatTensor(images, 1, 3, 1, 1)
175 |     elseif from == "hsl" then
176 |         local out = torch.Tensor(images:size(1), 3, images:size(3), images:size(4))
177 |         for i=1,images:size(1) do
178 |             out[i] = image.hsl2rgb(images[i])
179 |         end
180 |         return out
181 |     elseif from == "yuv" then
182 |         local out = torch.Tensor(images:size(1), 3, images:size(3), images:size(4))
183 |         for i=1,images:size(1) do
184 |             out[i] = image.yuv2rgb(images[i])
185 |         end
186 |         return out
187 |     else
188 |         print("[WARNING] unknown color space: '" .. from .. "'")
189 |     end
190 | end
191 | 
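-- A minimal round-trip sketch for the color space helpers (a sketch only; it assumes
-- the torch `image` package is loaded and `batch` is a placeholder (N, 3, H, W) RGB
-- tensor with values in [0, 1]):
--[[
local yuv = nn_utils.switchColorSpace(batch, "rgb", "yuv")
local rgbAgain = nn_utils.toRgb(yuv, "yuv")
--]]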
"'") 189 | end 190 | end 191 | 192 | function nn_utils.rgbToColorSpace(images, colorSpace) 193 | if colorSpace == "rgb" then 194 | return images 195 | else 196 | if colorSpace == "y" then 197 | local out = torch.Tensor(images:size(1), 1, images:size(3), images:size(4)) 198 | for i=1,images:size(1) do 199 | out[i] = nn_utils.rgb2y(images[i]) 200 | end 201 | return out 202 | elseif colorSpace == "hsl" then 203 | local out = torch.Tensor(images:size(1), 3, images:size(3), images:size(4)) 204 | for i=1,images:size(1) do 205 | out[i] = image.rgb2hsl(images[i]) 206 | end 207 | return out 208 | elseif colorSpace == "yuv" then 209 | local out = torch.Tensor(images:size(1), 3, images:size(3), images:size(4)) 210 | for i=1,images:size(1) do 211 | out[i] = image.rgb2yuv(images[i]) 212 | end 213 | return out 214 | else 215 | print("[WARNING] unknown color space in rgbToColorSpace: '" .. colorSpace .. "'") 216 | end 217 | end 218 | end 219 | 220 | -- convert rgb to grayscale by averaging channel intensities 221 | -- https://gist.github.com/jkrish/29ca7302e98554dd0fcb 222 | function nn_utils.rgb2y(im, threeChannels) 223 | -- Image.rgb2y uses a different weight mixture 224 | local dim, w, h = im:size()[1], im:size()[2], im:size()[3] 225 | if dim ~= 3 then 226 | print(' expected 3 channels') 227 | return im 228 | end 229 | 230 | -- a cool application of tensor:select 231 | local r = im:select(1, 1) 232 | local g = im:select(1, 2) 233 | local b = im:select(1, 3) 234 | 235 | local z = torch.Tensor(1, w, h):zero() 236 | 237 | -- z = z + 0.21r 238 | z = z:add(0.21, r) 239 | z = z:add(0.72, g) 240 | z = z:add(0.07, b) 241 | 242 | if threeChannels == true then 243 | z = torch.repeatTensor(z, 3, 1, 1) 244 | end 245 | 246 | return z 247 | end 248 | 249 | function nn_utils.toBatch(image) 250 | if image:size() == 2 then 251 | local tnsr = torch.Tensor():resize(1, image:size(1), image:size(2)) 252 | tnsr[1] = image 253 | return tnsr 254 | else 255 | local tnsr = torch.Tensor():resize(1, image:size(1), image:size(2), image:size(3)) 256 | tnsr[1] = image 257 | return tnsr 258 | end 259 | end 260 | 261 | -- Convert a list (table) of images to a Tensor. 262 | -- If the parameter is already a tensor, it will be returned unchanged. 263 | -- @param imageList A non-empty list/table or tensor of images (each being a tensor). 
264 | -- @returns A tensor of shape (N, channels, height, width)
265 | function nn_utils.toImageTensor(imageList, forceChannel)
266 |     if imageList.size ~= nil then
267 |         if not forceChannel or (#imageList:size() == 4) then
268 |             return imageList
269 |         else
270 |             -- forceChannel activated and images lack channel dimension
271 |             -- add it
272 |             local tens = torch.Tensor(imageList:size(1), 1, imageList:size(2), imageList:size(3))
273 |             for i=1,imageList:size(1) do
274 |                 tens[i][1] = imageList[i]
275 |             end
276 |             return tens
277 |         end
278 |     else
279 |         if forceChannel == nil then
280 |             forceChannel = false
281 |         end
282 | 
283 |         local hasChannel = (#imageList[1]:size() == 3)
284 | 
285 |         local tens
286 |         if hasChannel then
287 |             tens = torch.Tensor(#imageList, imageList[1]:size(1), imageList[1]:size(2), imageList[1]:size(3))
288 |         elseif not hasChannel and forceChannel then
289 |             tens = torch.Tensor(#imageList, 1, imageList[1]:size(1), imageList[1]:size(2))
290 |         else
291 |             tens = torch.Tensor(#imageList, imageList[1]:size(1), imageList[1]:size(2))
292 |         end
293 | 
294 |         for i=1,#imageList do
295 |             if (not hasChannel and forceChannel) then
296 |                 tens[i][1] = imageList[i]
297 |             else
298 |                 tens[i] = imageList[i]
299 |             end
300 |         end
301 |         return tens
302 |     end
303 | end
304 | 
305 | function nn_utils.toImageList(imageTensor, forceChannel)
306 |     local tens = nn_utils.toImageTensor(imageTensor, forceChannel)
307 |     local lst = {}
308 |     for i=1,tens:size(1) do
309 |         table.insert(lst, tens[i])
310 |     end
311 |     return lst
312 | end
313 | 
314 | -- Switch networks to training mode (activate Dropout)
315 | function nn_utils.switchToTrainingMode()
316 |     if MODEL_AE then
317 |         MODEL_AE:training()
318 |     end
319 |     MODEL_G:training()
320 |     MODEL_D:training()
321 | end
322 | 
323 | -- Switch networks to evaluation mode (deactivate Dropout)
324 | function nn_utils.switchToEvaluationMode()
325 |     if MODEL_AE then
326 |         MODEL_AE:evaluate()
327 |     end
328 |     MODEL_G:evaluate()
329 |     MODEL_D:evaluate()
330 | end
331 | 
332 | -- Normalize given images, currently to range -1.0 (black) to +1.0 (white), assuming that
333 | -- the input images are normalized to range 0.0 (black) to +1.0 (white).
334 | -- @param data Tensor of images
335 | -- @param mean_ Currently ignored.
336 | -- @param std_ Currently ignored.
337 | -- @return (mean, std), both currently always 0.5 dummy values
338 | function nn_utils.normalize(data, mean_, std_)
339 |     -- Code to normalize to zero-mean and unit-variance.
340 |     --[[
341 |     local mean = mean_ or data:mean(1)
342 |     local std = std_ or data:std(1, true)
343 |     local eps = 1e-7
344 |     local N
345 |     if data.size ~= nil then
346 |         N = data:size(1)
347 |     else
348 |         N = #data
349 |     end
350 | 
351 |     for i=1,N do
352 |         data[i]:add(-1, mean)
353 |         data[i]:cdiv(std + eps)
354 |     end
355 | 
356 |     return mean, std
357 |     --]]
358 | 
359 |     -- Code to normalize to range -1.0 to +1.0, where -1.0 is black and 1.0 is the maximum
360 |     -- value in this image.
361 |     --[[
362 |     local N
363 |     if data.size ~= nil then
364 |         N = data:size(1)
365 |     else
366 |         N = #data
367 |     end
368 | 
369 |     for i=1,N do
370 |         local m = torch.max(data[i])
371 |         data[i]:div(m * 0.5)
372 |         data[i]:add(-1.0)
373 |         data[i] = torch.clamp(data[i], -1.0, 1.0)
374 |     end
375 |     --]]
376 | 
377 |     -- Normalize to range -1.0 to +1.0, where -1.0 is black and +1.0 is white.
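    -- (i.e. x -> 2*x - 1 per image, followed by a clamp to [-1.0, +1.0])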
378 |     local N
379 |     if data.size ~= nil then
380 |         N = data:size(1)
381 |     else
382 |         N = #data
383 |     end
384 | 
385 |     for i=1,N do
386 |         data[i]:mul(2)
387 |         data[i]:add(-1.0)
388 |         data[i] = torch.clamp(data[i], -1.0, 1.0)
389 |     end
390 | 
391 |     -- Dummy return values
392 |     return 0.5, 0.5
393 | end
394 | 
395 | -- from https://github.com/torch/DEPRECEATED-torch7-distro/issues/47
396 | function nn_utils.zeroDataSize(data)
397 |     if type(data) == 'table' then
398 |         for i = 1, #data do
399 |             data[i] = nn_utils.zeroDataSize(data[i])
400 |         end
401 |     elseif type(data) == 'userdata' then
402 |         data = torch.Tensor():typeAs(data)
403 |     end
404 |     return data
405 | end
406 | 
407 | -- from https://github.com/torch/DEPRECEATED-torch7-distro/issues/47
408 | -- Resize the output, gradInput, etc. temporary tensors to zero (so that the on-disk size is smaller)
409 | function nn_utils.prepareNetworkForSave(node)
410 |     if node.output ~= nil then
411 |         node.output = nn_utils.zeroDataSize(node.output)
412 |     end
413 |     if node.gradInput ~= nil then
414 |         node.gradInput = nn_utils.zeroDataSize(node.gradInput)
415 |     end
416 |     if node.finput ~= nil then
417 |         node.finput = nn_utils.zeroDataSize(node.finput)
418 |     end
419 |     -- Recurse on nodes with 'modules'
420 |     if (node.modules ~= nil) then
421 |         if (type(node.modules) == 'table') then
422 |             for i = 1, #node.modules do
423 |                 local child = node.modules[i]
424 |                 nn_utils.prepareNetworkForSave(child)
425 |             end
426 |         end
427 |     end
428 |     collectgarbage()
429 | end
430 | 
431 | function nn_utils.getNumberOfParameters(net)
432 |     local nparams = 0
433 |     local dModules = net:listModules()
434 |     for i=1,#dModules do
435 |         if dModules[i].weight ~= nil then
436 |             nparams = nparams + dModules[i].weight:nElement()
437 |         end
438 |     end
439 |     return nparams
440 | end
441 | 
442 | -- Contains the pixels necessary to draw digits 0 to 9
443 | CHAR_TENSORS = {}
444 | CHAR_TENSORS[0] = torch.Tensor({{1, 1, 1},
445 |                                 {1, 0, 1},
446 |                                 {1, 0, 1},
447 |                                 {1, 0, 1},
448 |                                 {1, 1, 1}})
449 | CHAR_TENSORS[1] = torch.Tensor({{0, 0, 1},
450 |                                 {0, 0, 1},
451 |                                 {0, 0, 1},
452 |                                 {0, 0, 1},
453 |                                 {0, 0, 1}})
454 | CHAR_TENSORS[2] = torch.Tensor({{1, 1, 1},
455 |                                 {0, 0, 1},
456 |                                 {1, 1, 1},
457 |                                 {1, 0, 0},
458 |                                 {1, 1, 1}})
459 | CHAR_TENSORS[3] = torch.Tensor({{1, 1, 1},
460 |                                 {0, 0, 1},
461 |                                 {0, 1, 1},
462 |                                 {0, 0, 1},
463 |                                 {1, 1, 1}})
464 | CHAR_TENSORS[4] = torch.Tensor({{1, 0, 1},
465 |                                 {1, 0, 1},
466 |                                 {1, 1, 1},
467 |                                 {0, 0, 1},
468 |                                 {0, 0, 1}})
469 | CHAR_TENSORS[5] = torch.Tensor({{1, 1, 1},
470 |                                 {1, 0, 0},
471 |                                 {1, 1, 1},
472 |                                 {0, 0, 1},
473 |                                 {1, 1, 1}})
474 | CHAR_TENSORS[6] = torch.Tensor({{1, 1, 1},
475 |                                 {1, 0, 0},
476 |                                 {1, 1, 1},
477 |                                 {1, 0, 1},
478 |                                 {1, 1, 1}})
479 | CHAR_TENSORS[7] = torch.Tensor({{1, 1, 1},
480 |                                 {0, 0, 1},
481 |                                 {0, 0, 1},
482 |                                 {0, 0, 1},
483 |                                 {0, 0, 1}})
484 | CHAR_TENSORS[8] = torch.Tensor({{1, 1, 1},
485 |                                 {1, 0, 1},
486 |                                 {1, 1, 1},
487 |                                 {1, 0, 1},
488 |                                 {1, 1, 1}})
489 | CHAR_TENSORS[9] = torch.Tensor({{1, 1, 1},
490 |                                 {1, 0, 1},
491 |                                 {1, 1, 1},
492 |                                 {0, 0, 1},
493 |                                 {1, 1, 1}})
494 | 
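-- A minimal sketch for previewing one of the 5x3 digit bitmaps above (a sketch only;
-- image.display requires running under qlua):
--[[
image.display{image=CHAR_TENSORS[8], zoom=20, legend='digit 8'}
--]]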
495 | -- Converts a list of images to a grid of images that can be saved easily.
496 | -- It will also place the epoch number at the bottom of the image.
497 | -- At least parts of this function probably should have been a simple call
498 | -- to image.toDisplayTensor().
499 | -- @param images Tensor of image tensors
500 | -- @param height Height of the grid (number of image rows)
501 | -- @param width Width of the grid (number of image columns)
502 | -- @param epoch The epoch number to draw at the bottom of the grid
503 | -- @returns Tensor containing the grid image
504 | function nn_utils.imagesToGridTensor(images, height, width, epoch)
505 |     local imgChannels = images:size(2)
506 |     local imgHeightPx = IMG_DIMENSIONS[2]
507 |     local imgWidthPx = IMG_DIMENSIONS[3]
508 |     local heightPx = height * imgHeightPx + (1 + 5 + 1)
509 |     local widthPx = width * imgWidthPx
510 |     local grid = torch.Tensor(imgChannels, heightPx, widthPx)
511 |     grid:zero()
512 | 
513 |     -- add images to grid, one by one
514 |     local yGridPos = 1
515 |     local xGridPos = 1
516 |     for i=1,math.min(images:size(1), height*width) do
517 |         -- set pixels of image
518 |         local yStart = 1 + ((yGridPos-1) * imgHeightPx)
519 |         local yEnd = yStart + imgHeightPx - 1
520 |         local xStart = 1 + ((xGridPos-1) * imgWidthPx)
521 |         local xEnd = xStart + imgWidthPx - 1
522 |         grid[{{1,imgChannels}, {yStart,yEnd}, {xStart,xEnd}}] = images[i]:float()
523 | 
524 |         -- move to next position in grid
525 |         xGridPos = xGridPos + 1
526 |         if xGridPos > width then
527 |             xGridPos = 1
528 |             yGridPos = yGridPos + 1
529 |         end
530 |     end
531 | 
532 |     -- add the epoch at the bottom of the image
533 |     local epochStr = tostring(epoch)
534 |     local pos = 1
535 |     for i=epochStr:len(),1,-1 do
536 |         local c = tonumber(epochStr:sub(i,i))
537 |         for channel=1,imgChannels do
538 |             local yStart = heightPx - 1 - 5 -- constant for all
539 |             local yEnd = yStart + 5 - 1 -- constant for all
540 |             local xStart = widthPx - 1 - pos*5 - pos
541 |             local xEnd = xStart + 3 - 1
542 | 
543 |             grid[{{channel}, {yStart, yEnd}, {xStart, xEnd}}] = CHAR_TENSORS[c]
544 |         end
545 |         pos = pos + 1
546 |     end
547 | 
548 |     return grid
549 | end
550 | 
551 | -- Saves the list of images to the provided filepath (as a grid image).
552 | -- @param filepath Filepath to save the grid image to
553 | -- @param images List of image tensors
554 | -- @param height Height of the grid (number of image rows)
555 | -- @param width Width of the grid (number of image columns)
556 | -- @param epoch The epoch number to draw at the bottom of the grid
557 | -- @returns nothing
558 | function nn_utils.saveImagesAsGrid(filepath, images, height, width, epoch)
559 |     local grid = nn_utils.imagesToGridTensor(images, height, width, epoch)
560 |     os.execute(string.format("mkdir -p %s", sys.dirname(filepath)))
561 |     image.save(filepath, grid)
562 | end
563 | 
564 | -- Deactivates CUDA mode on a network and returns the result.
565 | -- Expects networks in CUDA mode to be a Sequential of the form
566 | -- [1] Copy layer [2] Sequential [3] Copy layer
567 | -- as created by activateCuda().
568 | -- @param net The network to deactivate CUDA mode on.
569 | -- @returns The CPU network
570 | function nn_utils.deactivateCuda(net)
571 |     local newNet = net:clone()
572 |     newNet:float()
573 |     if torch.type(newNet:get(1)) == 'nn.Copy' then
574 |         return newNet:get(2)
575 |     else
576 |         return newNet
577 |     end
578 | end
579 | 
580 | -- Returns whether a Sequential contains any copy layers.
581 | -- @param net The network to analyze.
582 | -- @return boolean
583 | function nn_utils.containsCopyLayers(net)
584 |     local modules = net:listModules()
585 |     for i=1,#modules do
586 |         local t = torch.type(modules[i])
587 |         if string.find(t, "Copy") ~= nil then
588 |             return true
589 |         end
590 |     end
591 |     return false
592 | end
593 | 
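-- A minimal sketch of the CUDA round trip (deactivateCuda above, activateCuda below;
-- a sketch only: it assumes cutorch/cunn are loaded, the default tensor type is
-- FloatTensor, and `net` is a placeholder nn.Sequential):
--[[
local gpuNet = nn_utils.activateCuda(net)    -- wraps net in Float<->Cuda Copy layers
local cpuNet = nn_utils.deactivateCuda(gpuNet)
--]]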
594 | -- Activates CUDA mode on a network and returns the result.
595 | -- This adds Copy layers at the start and end of the network.
596 | -- Expects the default tensor to be FloatTensor.
597 | -- @param net The network to activate CUDA mode on.
598 | -- @returns The CUDA network
599 | function nn_utils.activateCuda(net)
600 |     --[[
601 |     local newNet = net:clone()
602 |     newNet:cuda()
603 |     local tmp = nn.Sequential()
604 |     tmp:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor'))
605 |     tmp:add(newNet)
606 |     tmp:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor'))
607 |     return tmp
608 |     --]]
609 |     local newNet = net:clone()
610 | 
611 |     -- does the network already contain any copy layers?
612 |     local containsCopyLayers = nn_utils.containsCopyLayers(newNet)
613 | 
614 |     -- no copy layers in the network yet
615 |     -- add them at the start and end
616 |     if not containsCopyLayers then
617 |         local tmp = nn.Sequential()
618 |         tmp:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor'))
619 |         tmp:add(newNet)
620 |         tmp:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor'))
621 |         newNet:cuda()
622 |         newNet = tmp
623 |     end
624 | 
625 |     --[[
626 |     local firstCopyFound = false
627 |     local lastCopyFound = false
628 |     modules = newNet:listModules()
629 |     for i=1,#modules do
630 |         print("module "..i.." " .. torch.type(modules[i]))
631 |         local t = torch.type(modules[i])
632 |         if string.find(t, "Copy") ~= nil then
633 |             if not firstCopyFound then
634 |                 firstCopyFound = true
635 |                 modules[i]:cuda()
636 |                 modules[i].intype = 'torch.FloatTensor'
637 |                 modules[i].outtype = 'torch.CudaTensor'
638 |             else
639 |                 -- last copy found
640 |                 lastCopyFound = true
641 |                 modules[i]:float()
642 |                 modules[i].intype = 'torch.CudaTensor'
643 |                 modules[i].outtype = 'torch.FloatTensor'
644 |             end
645 |         elseif lastCopyFound then
646 |             print("calling float() A")
647 |             modules[i]:float()
648 |         elseif firstCopyFound then
649 |             print("calling cuda()")
650 |             modules[i]:cuda()
651 |         else
652 |             print("calling float() B")
653 |             modules[i]:float()
654 |         end
655 |     end
656 |     --]]
657 | 
658 |     return newNet
659 | end
660 | 
661 | -- Creates an average rating (0 to 1) for a list of images.
662 | -- 1 is best.
663 | -- @param images List of image tensors.
664 | -- @returns float
665 | function nn_utils.rateWithV(images)
666 |     local imagesTensor
667 |     local N
668 |     if type(images) == 'table' then
669 |         N = #images
670 |         imagesTensor = torch.Tensor(N, IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3])
671 |         for i=1,N do
672 |             imagesTensor[i] = images[i]
673 |         end
674 |     else
675 |         N = images:size(1)
676 |         imagesTensor = images
677 |     end
678 | 
679 |     local predictions = MODEL_V:forward(imagesTensor)
680 |     local sm = 0
681 |     for i=1,N do
682 |         -- first neuron in V signals whether the image is fake (1=yes, 0=no)
683 |         sm = sm + predictions[i][1]
684 |     end
685 | 
686 |     local fakiness = sm / N
687 | 
688 |     -- higher values for better images
689 |     return (1 - fakiness)
690 | end
691 | 
692 | return nn_utils
693 | 
--------------------------------------------------------------------------------