├── logs
│   └── README.md
├── images
│   ├── G.png
│   ├── skies.jpg
│   ├── baubles.jpg
│   ├── cat-faces.jpg
│   ├── human-faces.jpg
│   ├── christmas-trees.jpg
│   ├── snowy-landscapes.jpg
│   └── G.xml
├── samples
│   └── README.md
├── .gitignore
├── show_model_content.lua
├── LICENSE
├── weight-init.lua
├── models_rgb.lua
├── sample.lua
├── dataset_rgb.lua
├── README.md
├── train_rgb.lua
├── adversarial_rgb.lua
└── utils
    └── nn_utils.lua
/logs/README.md:
--------------------------------------------------------------------------------
1 | This directory will contain all generated models.
2 |
--------------------------------------------------------------------------------
/images/G.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/G.png
--------------------------------------------------------------------------------
/images/skies.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/skies.jpg
--------------------------------------------------------------------------------
/images/baubles.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/baubles.jpg
--------------------------------------------------------------------------------
/images/cat-faces.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/cat-faces.jpg
--------------------------------------------------------------------------------
/samples/README.md:
--------------------------------------------------------------------------------
1 | This directory will be used to save samples generated via `sample.lua`.
2 |
--------------------------------------------------------------------------------
/images/human-faces.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/human-faces.jpg
--------------------------------------------------------------------------------
/images/christmas-trees.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/christmas-trees.jpg
--------------------------------------------------------------------------------
/images/snowy-landscapes.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aleju/colorizer/HEAD/images/snowy-landscapes.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/*.net
2 | logs/*.net.old
3 | logs/images/*.jpg
4 | logs/images/*.png
5 | logs/images_good/*.jpg
6 | logs/images_good/*.png
7 | logs/images_bad/*.jpg
8 | logs/images_bad/*.png
9 | samples/*.jpg
10 | samples/*.png
11 | *_vgg.lua
12 | *_coco.lua
13 |
14 | # Compiled Lua sources
15 | luac.out
16 |
17 | # luarocks build files
18 | *.src.rock
19 | *.zip
20 | *.tar.gz
21 |
22 | # Object files
23 | *.o
24 | *.os
25 | *.ko
26 | *.obj
27 | *.elf
28 |
29 | # Precompiled Headers
30 | *.gch
31 | *.pch
32 |
33 | # Libraries
34 | *.lib
35 | *.a
36 | *.la
37 | *.lo
38 | *.def
39 | *.exp
40 |
41 | # Shared objects (inc. Windows DLLs)
42 | *.dll
43 | *.so
44 | *.so.*
45 | *.dylib
46 |
47 | # Executables
48 | *.exe
49 | *.out
50 | *.app
51 | *.i*86
52 | *.x86_64
53 | *.hex
54 |
55 |
--------------------------------------------------------------------------------
/show_model_content.lua:
--------------------------------------------------------------------------------
1 | require 'paths'
2 | require 'nn'
3 | require 'cutorch'
4 | require 'cunn'
5 | require 'cudnn'
6 | require 'dpnn'
7 | require 'pl'  -- provides `lapp` used below (may already be preloaded by the th interpreter)
8 |
9 | OPT = lapp[[
10 |     --save          (default "logs")             subdirectory in which the model is saved
11 |     --network       (default "adversarial.net")  name of the model file
12 | ]]
13 |
14 | local filepath = paths.concat(OPT.save, OPT.network)
15 | local tmp = torch.load(filepath)
16 | if tmp.epoch then print("") print("Epoch:") print(tmp.epoch) end
17 | if tmp.opt then print("") print("OPT:") print(tmp.opt) end
18 | if tmp.G then print("") print("G:") print(tmp.G) end
19 | if tmp.G1 then print("") print("G1:") print(tmp.G1) end
20 | if tmp.G2 then print("") print("G2:") print(tmp.G2) end
21 | if tmp.G3 then print("") print("G3:") print(tmp.G3) end
22 | if tmp.D then print("") print("D:") print(tmp.D) end
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 Alexander Jung
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/images/G.xml:
--------------------------------------------------------------------------------
1 | 7Vvfc5s4EP5r/HpjIDjpY51rc3Mz7XQuc3N3jyooWFOBPBgnzv31XaFdAxZusQsEjP1iadHP7/tWSAvMvPt495Cy9eqTCrmcufNwN/N+n7mu4/ge/GnLq7Hc3s2NIUpFiIUKw6P4n6ORim1FyDeVgplSMhPrqjFQScKDrGJjaapeqsWelKz2umYR9VgYHgMmbes/IsxWxnrnLgr7H1xEK+rZWbwzV76y4FuUqm2C/c1c7yn/mcsxo7byiXofAMRUKWhGp+LdPZcaSMLIoPHxyNX9IFOe4EB+XAEu6QrPTG5xnp+V2HBj3WSvNPuXlcj445oFOv8CDM+85SqLJeQcSEaSbTS+c0hvslR922Okrz4JKe+VVCnkE5VAE0smRZRANoBxcrAvcRw8zThJxp5LbsKJPHAV8yx9hSJYwb3zTRVU2AKV81LQdYumVYkpD20MBRLtGy5wgwRCVw8jarsE41IC7bpVoB2GoOHT5WOtMnchoe/l1xRSkU6NEm7nBoE7Dje5eNtw31hw36vkGSwOeCPUnns7YEQnUi63JrX8fCm4E6qIu0OiLgG/56Zt5NHDbOQ9ME8NeVpce0Ee7zMl5D+xHRi+wE1QJNGlQuw3hJio+BWIb4+JewELzuTE/a5Hcd9NUtxe0/WjDXHjbrTmnqnRn5q6vZrdSmfqpiOMjb3rT3DDckN49II9Dveqe4M95XvB3j5d/r3esBgABOnD+n6ZEDddWtpY1R375DnhPYvv4/mkF3Xbp9AJqNv3GkLcirqvx81ygIVU24u67fMmQX+A/EZEsYJqFwp6rzdM+wRqgciT8L0OZxcwlCA0sBFMEGme5z+4wnci+1eD+5uPuf+wDjT4EbBF4E1/PLSC4QcAwpjUNs2ZpPinjWkJMzrHlyEjG/guy8Rztcc6HLGHL0rAWArK6uOQ1IAZJ9Yph8IPmtkPh+Jqh4xmLI14ZjWUk7qfdDOe7fPuOHjOh3nluTHP9qH7RzwHegEUQWOqExhIiWud1WSfRq+ZbSmu/lb0/pSXpgSTMKidzuilLU4XbnyWs46euhuCtHvfpFlOnryfYn4ueYsOPc+OLLS6sBJ/utzY+LNgP/fG2CV/dthiHBsgE28ZyA7IGQPTdpRk0FsgE3IYyo30bF8+2AMd7pVb5NcO0fTE73FHP8+1UahlMbypFtpy9v6kcFr0YsBSwNhXWQr53K5SaCqF0wIcA5YCTqQshfxQf5VCUym8WQykbSngRCq7heKd3asWGrwme1rAZMBaIKgrYhhSdHQEYjgtADNkMdSdI940ljo+MXQc0OlRDHggqogh31BexdBUDHZ06E9o1hIEPOXVz5mPSgDDRk2eHOtnxgI+9nmPF2IRhrqbZd0zawWln2QuxBWU41Ch8vC6jW9ZQLcV0mpeFe3qrQuv9vOKgBlaJ0KAQ68/79VuE1Dns60Q0G1Mpe0Fb1CLW1tPhW/JuzpY3uxXa/56APAbfx12yW5HHzMRDTXrXt1HfGe4HWSL7ywNjcWXq96H7w==
--------------------------------------------------------------------------------
/weight-init.lua:
--------------------------------------------------------------------------------
1 | -- Source:
2 | -- https://github.com/e-lab/torch-toolbox/blob/master/Weight-init/weight-init.lua
3 |
4 | --
5 | -- Different weight initialization methods
6 | --
7 | -- > model = require('weight-init')(model, 'heuristic')
8 | --
9 | require("nn")
10 |
11 |
12 | -- "Efficient backprop"
13 | -- Yann Lecun, 1998
14 | local function w_init_heuristic(fan_in, fan_out)
15 | return math.sqrt(1/(3*fan_in))
16 | end
17 |
18 |
19 | -- "Understanding the difficulty of training deep feedforward neural networks"
20 | -- Xavier Glorot, 2010
21 | local function w_init_xavier(fan_in, fan_out)
22 | return math.sqrt(2/(fan_in + fan_out))
23 | end
24 |
25 |
26 | -- "Understanding the difficulty of training deep feedforward neural networks"
27 | -- Xavier Glorot, 2010
28 | local function w_init_xavier_caffe(fan_in, fan_out)
29 | return math.sqrt(1/fan_in)
30 | end
31 |
32 |
33 | -- "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification"
34 | -- Kaiming He, 2015
35 | local function w_init_kaiming(fan_in, fan_out)
36 | return math.sqrt(4/(fan_in + fan_out))
37 | end
38 |
39 |
40 | local function w_init(net, arg)
41 | -- choose initialization method
42 | local method = nil
43 | if arg == 'heuristic' then method = w_init_heuristic
44 | elseif arg == 'xavier' then method = w_init_xavier
45 | elseif arg == 'xavier_caffe' then method = w_init_xavier_caffe
46 | elseif arg == 'kaiming' then method = w_init_kaiming
47 | else
48 | assert(false, "unknown weight initialization method: " .. tostring(arg))
49 | end
50 |
51 | -- loop over all convolutional modules
52 | for i = 1, #net.modules do
53 | local m = net.modules[i]
54 | if m.__typename == 'nn.SpatialConvolution' then
55 | m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
56 | elseif m.__typename == 'nn.SpatialConvolutionMM' then
57 | m:reset(method(m.nInputPlane*m.kH*m.kW, m.nOutputPlane*m.kH*m.kW))
58 | elseif m.__typename == 'nn.LateralConvolution' then
59 | m:reset(method(m.nInputPlane*1*1, m.nOutputPlane*1*1))
60 | elseif m.__typename == 'nn.VerticalConvolution' then
61 | m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
62 | elseif m.__typename == 'nn.HorizontalConvolution' then
63 | m:reset(method(1*m.kH*m.kW, 1*m.kH*m.kW))
64 | elseif m.__typename == 'nn.Linear' then
65 | m:reset(method(m.weight:size(2), m.weight:size(1)))
66 | elseif m.__typename == 'nn.TemporalConvolution' then
67 | m:reset(method(m.weight:size(2), m.weight:size(1)))
68 | end
69 |
70 | if m.bias then
71 | m.bias:zero()
72 | end
73 | end
74 | return net
75 | end
76 |
77 |
78 | return w_init
79 |
--------------------------------------------------------------------------------
/models_rgb.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'dpnn'
4 | require 'cudnn'
5 |
6 | local models = {}
7 |
8 | -- Creates the generator model (G).
9 | -- @param dimensions The dimensions of each image as {channels, height, width}.
10 | -- @param cuda Whether to activate GPU mode for the model.
11 | -- @returns nn.Sequential
12 | function models.create_G(dimensions, cuda)
13 | local model = nn.Sequential()
14 |
15 | model:add(nn.JoinTable(2, 2))
16 |
17 | if cuda then
18 | model:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor', true, true))
19 | end
20 |
21 | local inner = nn.Sequential()
22 | local conc = nn.Concat(2)
23 | local left = nn.Sequential()
24 | local right = nn.Sequential()
25 |
26 | left:add(nn.Identity())
27 |
28 | right:add(cudnn.SpatialConvolution(1+1, 16, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
29 | right:add(nn.SpatialBatchNormalization(16))
30 | right:add(cudnn.ReLU(true))
31 |
32 | right:add(cudnn.SpatialConvolution(16, 32, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
33 | right:add(nn.SpatialBatchNormalization(32))
34 | right:add(cudnn.ReLU(true))
35 | right:add(nn.SpatialMaxPooling(2, 2))
36 |
37 | right:add(cudnn.SpatialConvolution(32, 64, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
38 | right:add(nn.SpatialBatchNormalization(64))
39 | right:add(cudnn.ReLU(true))
40 | right:add(nn.SpatialMaxPooling(2, 2))
41 |
42 | right:add(cudnn.SpatialConvolution(64, 128, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
43 | right:add(nn.SpatialBatchNormalization(128))
44 | right:add(cudnn.ReLU(true))
45 |
46 | right:add(cudnn.SpatialConvolution(128, 256, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
47 | right:add(nn.SpatialBatchNormalization(256))
48 | right:add(cudnn.ReLU(true))
49 |
50 | right:add(nn.SpatialUpSamplingNearest(2))
51 | right:add(cudnn.SpatialConvolution(256, 128, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
52 | right:add(nn.SpatialBatchNormalization(128))
53 | right:add(cudnn.ReLU(true))
54 |
55 | right:add(nn.SpatialUpSamplingNearest(2))
56 | right:add(cudnn.SpatialConvolution(128, 64, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
57 | right:add(nn.SpatialBatchNormalization(64))
58 | right:add(cudnn.ReLU(true))
59 |
60 | conc:add(left)
61 | conc:add(right)
62 | inner:add(conc)
63 |
64 | inner:add(cudnn.SpatialConvolution(2+64, 32, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
65 | inner:add(nn.SpatialBatchNormalization(32))
66 | inner:add(cudnn.ReLU(true))
67 |
68 | inner:add(cudnn.SpatialConvolution(32, 3, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
69 | inner:add(nn.Sigmoid())
70 |
71 | model:add(inner)
72 |
73 | if cuda then
74 | model:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true))
75 | inner:cuda()
76 | end
77 |
78 | model = require('weight-init')(model, 'heuristic')
79 |
80 | return model
81 | end
82 |
83 | -- Creates the discriminator model (D).
84 | -- @param dimensions The dimensions of each image as {channels, height, width}.
85 | -- @param cuda Whether to activate GPU mode for the model.
86 | -- @returns nn.Sequential
87 | function models.create_D(dimensions, cuda)
88 | local model = nn.Sequential()
89 |
90 | --model:add(nn.CAddTable())
91 | model:add(nn.JoinTable(2, 2))
92 |
93 | if cuda then
94 | model:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor', true, true))
95 | end
96 |
97 | local inner = nn.Sequential()
98 |
99 | -- 64x64
100 | inner:add(nn.SpatialConvolution(3+1, 64, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
101 | inner:add(cudnn.ReLU(true))
102 | inner:add(nn.SpatialDropout(0.25))
103 | inner:add(nn.SpatialAveragePooling(2, 2, 2, 2))
104 |
105 | -- 32x32
106 | inner:add(nn.SpatialConvolution(64, 128, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
107 | inner:add(cudnn.ReLU(true))
108 | inner:add(nn.SpatialDropout(0.25))
109 |
110 | -- 32x32
111 | inner:add(nn.SpatialConvolution(128, 256, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
112 | inner:add(cudnn.ReLU(true))
113 | inner:add(nn.SpatialDropout(0.25))
114 | inner:add(nn.SpatialMaxPooling(2, 2))
115 |
116 | -- 16x16
117 | inner:add(nn.SpatialConvolution(256, 256, 3, 3, 1, 1, (3-1)/2, (3-1)/2))
118 | inner:add(cudnn.ReLU(true))
119 | inner:add(nn.SpatialDropout(0.5))
120 | inner:add(nn.SpatialMaxPooling(2, 2))
121 |
122 | local height = dimensions[2] * 0.5 * 0.5 * 0.5
123 | local width = dimensions[3] * 0.5 * 0.5 * 0.5
124 |
125 | -- 8x8
126 | inner:add(nn.View(256*height*width))
127 | inner:add(nn.Linear(256*height*width, 128))
128 | inner:add(nn.PReLU())
129 | inner:add(nn.Dropout(0.5))
130 | inner:add(nn.Linear(128, 1))
131 | inner:add(nn.Sigmoid())
132 |
133 | model:add(inner)
134 |
135 | if cuda then
136 | model:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor', true, true))
137 | inner:cuda()
138 | end
139 |
140 | model = require('weight-init')(model, 'heuristic')
141 |
142 | return model
143 | end
144 |
145 | return models
146 |
--------------------------------------------------------------------------------
/sample.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'image'
3 | require 'paths'
4 | require 'pl'
5 | require 'cudnn'
6 | NN_UTILS = require 'utils.nn_utils'
7 | DATASET = require 'dataset_rgb'
8 |
9 | OPT = lapp[[
10 | --save (default "logs") Directory in which the networks are stored.
11 | --network (default "adversarial.net") Filename of the network to use.
12 | --neighbours Whether to search for nearest neighbours of generated images in the dataset (takes long)
13 | --writeto (default "samples") Directory to save the images to
14 | --seed (default 1) Random number seed to use.
15 | --gpu (default 0) GPU to run on
16 | --runs (default 1) How often to sample and save images
17 | --noiseDim (default 100) Noise vector size.
18 | --batchSize (default 16) Sizes of batches.
19 | --dataset (default "NONE") Directory that contains *.jpg images
20 | ]]
21 |
22 | if OPT.gpu < 0 then
23 | print("[ERROR] Sample script currently only runs on GPU, set --gpu=x where x is between 0 and 3.")
24 | os.exit()
25 | end
26 |
27 | -- Start GPU mode
28 | print("Starting gpu support...")
29 | require 'cutorch'
30 | require 'cunn'
31 | torch.setdefaulttensortype('torch.FloatTensor')
32 | cutorch.setDevice(OPT.gpu + 1)
33 |
34 | -- initialize seeds
35 | math.randomseed(OPT.seed)
36 | torch.manualSeed(OPT.seed)
37 | cutorch.manualSeed(OPT.seed)
38 |
39 | -- Initialize dataset
40 | DATASET.setFileExtension("jpg")
41 |
42 | -- Main function that runs the sampling
43 | function main()
44 | -- Load all models
45 | local G, D, height, width, dataset = loadModels()
46 |
47 | -- Image dimensions
48 | IMG_DIMENSIONS = {3, height, width}
49 | NOISE_DIM = {1, height, width}
50 |
51 | DATASET.setHeight(height)
52 | DATASET.setWidth(width)
53 | if OPT.dataset ~= "NONE" or dataset == nil then
54 | DATASET.setDirs({OPT.dataset})
55 | else
56 | DATASET.setDirs({dataset})
57 | end
58 |
59 | print("Sampling...")
60 | for run=1,OPT.runs do
61 | -- save 64 randomly selected images from the training set
62 | local imagesTrainList = DATASET.loadRandomImages(64)
63 | -- don't use nn_utils.toImageTensor here, because the metatable of imagesTrainList was changed
64 | local imagesTrain = torch.Tensor(#imagesTrainList, imagesTrainList[1].grayscale:size(1), imagesTrainList[1].grayscale:size(2), imagesTrainList[1].grayscale:size(3))
65 | for i=1,#imagesTrainList do
66 | imagesTrain[i] = imagesTrainList[i].grayscale
67 | end
68 | image.save(paths.concat(OPT.writeto, string.format('trainset_s1_%04d_base.jpg', run)), toGrid(imagesTrainList, nil, 8))
69 |
70 | -- sample 64 colorizations from G
71 | local noise = torch.Tensor(64, NOISE_DIM[1], NOISE_DIM[2], NOISE_DIM[3])
72 | noise:uniform(0, 1)
73 | local imagesGenerated = G:forward({noise, imagesTrain})
74 |
75 | -- validate image dimensions
76 | if imagesGenerated[1]:size(1) ~= IMG_DIMENSIONS[1] or imagesGenerated[1]:size(2) ~= IMG_DIMENSIONS[2] or imagesGenerated[1]:size(3) ~= IMG_DIMENSIONS[3] then
77 | print("[WARNING] dimension mismatch between images generated by base G and command line parameters")
78 | print("Dimension G:", images[1]:size())
79 | print("Settings:", IMG_DIMENSIONS)
80 | end
81 |
82 | -- save one big grid image of those 64 images (grayscale, original color, colorized)
83 | image.save(paths.concat(OPT.writeto, string.format('random64_%04d.jpg', run)), toGrid(imagesTrainList, imagesGenerated, 12))
84 |
85 | xlua.progress(run, OPT.runs)
86 | end
87 |
88 | print("Finished.")
89 | end
90 |
91 | -- Converts images to one image grid with a set number of rows (nrow).
92 | -- @param imagesOriginal List of original images (each with .grayscale and .color tensors)
93 | -- @param imagesGenerated Tensor of images generated by G, or nil. @param nrow Number of rows.
94 | -- @return Tensor
95 | function toGrid(imagesOriginal, imagesGenerated, nrow)
96 | local N = 2 * imagesOriginal:size()
97 | if imagesGenerated ~= nil then N = N + imagesGenerated:size(1) end
98 | local images = torch.Tensor(N, 3, imagesOriginal[1].color:size(2), imagesOriginal[1].color:size(3))
99 | local idx = 1
100 | for i=1,#imagesOriginal do
101 | images[{{idx}, {1}, {}, {}}] = imagesOriginal[i].grayscale
102 | images[{{idx}, {2}, {}, {}}] = imagesOriginal[i].grayscale
103 | images[{{idx}, {3}, {}, {}}] = imagesOriginal[i].grayscale
104 | images[idx+1] = imagesOriginal[i].color
105 | if imagesGenerated ~= nil then
106 | images[idx+2] = imagesGenerated[i]
107 | idx = idx + 1
108 | end
109 | idx = idx + 2
110 | end
111 |
112 | return image.toDisplayTensor{input=images, nrow=nrow}
113 | end
114 |
115 | -- Loads all necessary models/networks and returns them.
116 | -- @returns G, D, height, width, dataset directory
117 | function loadModels()
118 | local file = torch.load(paths.concat(OPT.save, OPT.network))
119 |
120 | local G = file.G
121 | local D = file.D
122 | local opt_loaded = file.opt
123 | G:evaluate()
124 | D:evaluate()
125 |
126 | return G, D, opt_loaded.height, opt_loaded.width, opt_loaded.dataset
127 | end
128 |
129 | main()
130 |
--------------------------------------------------------------------------------
/dataset_rgb.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'image'
3 | require 'paths'
4 |
5 | local dataset = {}
6 |
7 | -- load data from these directories
8 | dataset.dirs = {}
9 |
10 | -- load only images with this file extension
11 | dataset.fileExtension = ""
12 |
13 | -- expected original height/width of images
14 | dataset.originalHeight = 64
15 | dataset.originalWidth = 64
16 |
17 | -- desired height/width of images
18 | dataset.height = 32
19 | dataset.width = 32
20 |
21 | --dataset.colorSpace = "rgb"
22 |
23 | -- cache for filepaths to all images
24 | dataset.paths = nil
25 |
26 | -- Set directories to load images from
27 | -- @param dirs List of paths to directories
28 | function dataset.setDirs(dirs)
29 | dataset.dirs = dirs
30 | end
31 |
32 | -- Set file extension that images to load must have
33 | -- @param fileExtension the file extension of the images
34 | function dataset.setFileExtension(fileExtension)
35 | dataset.fileExtension = fileExtension
36 | end
37 |
38 | -- Set desired height of the images (they will be resized if necessary)
39 | -- @param height The height of the images
40 | function dataset.setHeight(height)
41 | dataset.height = height
42 | end
43 |
44 | -- Set desired width of the images (they will be resized if necessary)
45 | -- @param width The width of the images
46 | function dataset.setWidth(width)
47 | dataset.width = width
48 | end
49 |
50 | -- Set desired number of channels for the images (1=grayscale, 3=color)
51 | -- @param nbChannels The number of channels
52 | function dataset.setNbChannels(nbChannels)
53 | dataset.nbChannels = nbChannels
54 | end
55 |
56 | -- Loads the paths of all images in the defined directories
57 | -- (that have the defined file extension)
58 | function dataset.loadPaths()
59 | local files = {}
60 | local dirs = dataset.dirs
61 | local ext = dataset.fileExtension
62 |
63 | for i=1, #dirs do
64 | local dir = dirs[i]
65 | -- Go over all files in directory. We use an iterator, paths.files().
66 | for file in paths.files(dir) do
67 | -- We only load files that match the extension
68 | if file:find(ext .. '$') then
69 | -- and insert the ones we care about in our table
70 | table.insert(files, paths.concat(dir,file))
71 | end
72 | end
73 |
74 | -- sort for reproducibility
75 | table.sort(files, function (a,b) return a < b end)
76 |
77 | -- Check files
78 | if #files == 0 then
79 | error("given directory doesn't contain any files of type: " .. ext)
80 | end
81 | end
82 |
83 | dataset.paths = files
84 | end
85 |
86 | -- Load images from the dataset.
87 | -- @param startAt Number of the first image.
88 | -- @param count Count of the images to load.
89 | -- @return Table of images. You can call :size() on that table to get the number of loaded images.
90 | function dataset.loadImages(startAt, count)
91 | --local endBefore = startAt + count
92 | if dataset.paths == nil then
93 | dataset.loadPaths()
94 | end
95 |
96 | local N = math.min(count, #dataset.paths)
97 | local images = torch.FloatTensor(N, 3, dataset.height, dataset.width)
98 | for i=0,(N-1) do
99 | local img = image.load(dataset.paths[startAt + i], dataset.nbChannels, "float")
100 | img = image.scale(img, dataset.width, dataset.height)
101 | --print(img, startAt, startAt+i, count, endBefore)
102 | images[i+1] = img
103 | end
104 | images = NN_UTILS.rgbToColorSpace(images, dataset.colorSpace) -- relies on the global NN_UTILS set by train_rgb.lua/sample.lua; note that dataset.colorSpace stays nil unless set (see the commented-out default above)
105 |
106 | local result = {}
107 | result.data = images
108 |
109 | function result:size()
110 | return N
111 | end
112 |
113 | setmetatable(result, {
114 | __index = function(self, index) return self.data[index] end,
115 | __len = function(self) return self.data:size(1) end
116 | })
117 |
118 | return result
119 | end
120 |
121 | -- Loads a defined number of randomly selected images from
122 | -- the cached paths (cached in loadPaths()).
123 | -- @param count Number of random images.
124 | -- @return List of Tensors
125 | function dataset.loadRandomImages(count)
126 | local images = dataset.loadRandomImagesFromPaths(count)
127 | local data = torch.FloatTensor(#images, 3, dataset.height, dataset.width)
128 | for i=1, #images do
129 | --data[i] = image.scale(images[i], dataset.width, dataset.height)
130 | data[i] = images[i]
131 | end
132 | local data_yuv = NN_UTILS.rgbToColorSpace(data, "yuv")
133 |
134 | local N = data:size(1)
135 | local result = {}
136 | result.color = data
137 | result.uv = data_yuv[{{}, {2,3}, {}, {}}] -- remove y channel from yuv
138 | result.grayscale = data_yuv[{{}, {1}, {}, {}}] -- only y channel from yuv
139 |
140 | --[[
141 | image.display(NN_UTILS.switchColorSpace(result.color, "yuv", "rgb")[1])
142 | image.display(result.uv[1])
143 | image.display(result.grayscale[1])
144 | io.read()
145 | --]]
146 |
147 | function result:size()
148 | return N
149 | end
150 |
151 | setmetatable(result, {
152 | __index = function(self, index)
153 | return {
154 | color = result.color[index],
155 | grayscale = result.grayscale[index],
156 | uv = result.uv[index]
157 | }
158 | end,
159 | __len = function(self) return self.color:size(1) end
160 | })
161 |
162 | return result
163 | end
164 |
165 | -- Loads randomly selected images from the cached paths.
166 | -- TODO: merge with loadRandomImages()
167 | -- @param count Number of images to load
168 | -- @returns List of Tensors
169 | function dataset.loadRandomImagesFromPaths(count)
170 | if dataset.paths == nil then
171 | dataset.loadPaths()
172 | end
173 |
174 | local shuffle = torch.randperm(#dataset.paths)
175 |
176 | local images = {}
177 | for i=1,math.min(shuffle:size(1), count) do
178 | -- load each image
179 | --table.insert(images, image.load(dataset.paths[shuffle[i]], 3, "float"))
180 | local fp = dataset.paths[shuffle[i]]
181 | local img = image.load(fp, 3, "float")
182 | img = image.scale(img, dataset.width, dataset.height)
183 | table.insert(images, img)
184 | end
185 |
186 | return images
187 | end
188 |
189 | return dataset
190 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # About
2 |
3 | This project uses GANs ([generative adversarial networks](http://papers.nips.cc/paper/5423-generative-adversarial-nets)) to add color to black and white images.
4 | For each such image, the generator network (G) receives its black and white version and outputs a full RGB version of the image (i.e. the black and white image with color added to it).
5 | That RGB version is then rated (with regard to its quality) by the discriminator (D).
6 | The quality measure is backpropagated through D and then through G.
7 | Thereby G can learn to correctly colorize images.
8 | The architectures used are modifications of the [DCGAN](http://arxiv.org/abs/1511.06434) schema.
9 | See [this blog post](http://tinyclouds.org/colorize/) for an alternative version, which uses standard convnets (i.e. no GANs) with VGG features.
10 |
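The following is a minimal sketch of that data flow (not a file of this repository), using the model definitions from `models_rgb.lua`; it assumes a working Torch install with CUDA and cudnn, since the models use cudnn layers:

```lua
require 'torch'
require 'cutorch'
require 'cunn'
require 'cudnn'
torch.setdefaulttensortype('torch.FloatTensor')
local MODELS = require 'models_rgb'

local H, W = 64, 64
local G = MODELS.create_G({3, H, W}, true)  -- true = GPU mode
local D = MODELS.create_D({3, H, W}, true)

local batchSize = 4
local noise = torch.FloatTensor(batchSize, 1, H, W):uniform(0, 1)
-- stand-in for the Y channels of real images:
local grayscale = torch.FloatTensor(batchSize, 1, H, W):uniform(0, 1)

-- G colorizes: {noise, grayscale} -> RGB images of shape (batchSize, 3, H, W)
local rgb = G:forward({noise, grayscale})
-- D rates each colorization, conditioned on the grayscale image: one value in (0, 1) per image
local quality = D:forward({grayscale, rgb})
print(rgb:size(), quality:size())
```
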
11 | Key results:
12 | * If a dataset of images can be generated by a GAN, then a GAN can also learn to add colors to it.
13 | * The task of adding colors seems to be a bit easier than the full generation of images.
14 | * G did not learn to add colors to rather rare and small elements (e.g. when coloring images of Christmas trees it didn't add color to presents below the trees, small baubles or the clothes of people in the image). This might partly be a limitation of the architecture, which uses pooling layers in G (hence small elements might get lost).
15 | * G did not learn to correctly add colors to datasets with high variance (heterogeneous collections of images). It would resort to mostly just adding one or two colors everywhere.
16 | * I experimented with using VGG features but didn't have much success with those. G didn't seem to learn more than without VGG features. My tests were limited though due to hardware constraints (VGG + G + D = three big networks in memory). I did not try the hypercolumn approach that was used in the [previously mentioned blog post](http://tinyclouds.org/colorize/).
17 | * Producing UV values in G and combining them with Y into a YUV image (which was then fed into D) failed: G just wouldn't learn anything. G had to output full RGB images to learn successfully. I am not sure whether there was a bug somewhere or whether there is a good reason for that effect.
18 |
19 | # Images
20 |
21 | Colorizers were trained on multiple image datasets which were reused from previous projects. (I.e. *multiple* GANs were trained, not just one for all images. That's due to GANs not being very good at handling heterogeneous datasets.)
22 | Besides the datasets shown below, the MSCOCO 2014 validation dataset was also used, but G failed to learn much on that one (it mostly just added 1-3 uniform colors per image), hence the results of that run are not shown.
23 |
24 | Notes:
25 | * There were no splits into training and validation sets (partly due to laziness, partly because GANs in my experience basically never just memorize the training set). Note how the coloring in the images below is often different from the original coloring.
26 | * Training times were usually quite fast (<=2 hours per dataset).
27 | * All generated color images were a little bit blurry, probably because G generated full RGB images instead of just adding color (UV in YUV). As such, it had to learn to copy the Y channel information correctly while still adding colors.
28 |
29 | ## Human faces
30 |
31 | This dataset worked fairly well. Notice the image in the 10th row at the far right. G assigns a skin color to the microphone. Also notice how G usually doesn't add red color to the lips. Maybe they get lost during the pooling...?
32 |
33 | 
34 |
35 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
36 |
37 | ## Cat faces
38 |
39 | This dataset worked fairly well.
40 |
41 | 
42 |
43 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
44 |
45 | ## Skies
46 |
47 | Here G sometimes created weird mixtures of blue and orange. These were not visible in earlier epochs, but those epochs had weird vertical stripes around the borders of the images.
48 |
49 | 
50 |
51 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
52 |
53 | ## Baubles
54 |
55 | This dataset already caused problems when I tried to generate it (i.e. full image generation, not just colorization). It didn't work too well here either. Baubles often remained colorless. I had to carefully select the optimal epoch to generate half decent images. There are blue blobs in some of the images. These blobs become bigger if the experiment is run longer.
56 |
57 | 
58 |
59 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
60 |
61 | ## Snowy landscapes
62 |
63 | Here G only had to either keep the black and white image as-is or add some blue color, which is a fairly easy task. It mostly learned that and sometimes exaggerated the blue (e.g. by adding it to trees).
64 |
65 | 
66 |
67 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
68 |
69 | ## Christmas trees
70 |
71 | This dataset worked fairly well. When zooming in you can see that G doesn't add color to presents, baubles and people's clothing. E.g. for baubles look at row=9, col=3; for clothes at row=8, col=1; and for presents at row=1, col=2 (indices starting at 1).
72 |
73 | 
74 |
75 | *For each tuple: (left) Original image in black and white, (middle) original image in color, (right) Color added by G.*
76 |
77 |
78 | # Architecture
79 |
80 | The architecture of D was a standard convolutional neural net with one small fully connected layer at the end, mostly ReLU activations and some spatial dropout.
81 |
82 | G is an upsampling generator, similar to what is described in the DCGAN paper. Before the upsampling part it has an image analysis part, similar to a standard convolutional network. The analysis part takes the black and white image and tries to make sense of it with some convolutional and pooling layers. The black and white image is fed into the network a second time towards the end, so that the analysis and upsampling parts can focus on the color and don't have to transfer the Y channel information through the layers. The noise layer is fed into the network both at the start and the end for technical simplicity (it just gets joined with the black and white image).
83 |
84 | 
85 |
86 | *Architecture of G. The formulation "Conv K, 3x3, relu, BN" denotes a convolutional layer with K kernels/planes, each with filter size 3x3, ReLU activation and batch normalization. Max pooling is over a 2x2 area. Upsampling layers increase height and width each by a factor of 2. The noise layer and the black and white image usually both have size 1x64x64.*
87 |
88 | D had about 3 million parameters (exact value depends on the input image size), G about 0.8 million.
89 |
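Both counts are printed by `train_rgb.lua` at startup; a standalone way to reproduce them (a sketch, with the same CUDA/cudnn assumptions as the snippet above) is:

```lua
require 'torch'
require 'cutorch'
require 'cunn'
require 'cudnn'
NN_UTILS = require 'utils.nn_utils'
local MODELS = require 'models_rgb'

local G = MODELS.create_G({3, 64, 64}, true)
local D = MODELS.create_D({3, 64, 64}, true)
print(string.format('Number of free parameters in D: %d', NN_UTILS.getNumberOfParameters(D)))
print(string.format('Number of free parameters in G: %d', NN_UTILS.getNumberOfParameters(G)))
```
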
90 | # Usage
91 |
92 | Requirements are:
93 | * Torch
94 | * Required packages (most of them should be part of the default torch install; install missing ones with `luarocks install packageName`): `nn`, `pl`, `paths`, `image`, `optim`, `cutorch`, `cunn`, `cudnn`, `dpnn`, `display`
95 | * Image datasets have to be downloaded from previous projects and will likely require Python 2.7. You can however use your own dataset, provided that its images are square or have an aspect ratio of 2.0 (e.g. height=64, width=32). Other ratios might work but haven't been tested.
96 | * [Human faces](https://github.com/aleju/face-generator)
97 | * [Cat faces](https://github.com/aleju/cat-generator)
98 | * [Skies](https://github.com/aleju/sky-generator)
99 | * [Christmas trees, baubles, snowy landscapes](https://github.com/aleju/christmas-generator)
100 | * NVIDIA GPU with cudnn3 and 4GB or more memory
101 |
102 | To train a network:
103 | * `~/.display/run.js &` - This will start `display`, which is used to plot results in the browser
104 | * Open http://localhost:8000/ in your browser (`display` interface)
105 | * Open a console in the repository directory and run `th train_rgb.lua --dataset="DATASET_PATH" --height=64 --width=64`, where `DATASET_PATH` is the filepath to the directory containing all your images (must be jpg). `height` and `width` define the size of the *generated* images. Your source images in that directory may be larger (e.g. 256x256). Only 32x32 (height x width), 64x64 and 32x64 were tested; other values might result in errors. Note: Training keeps running until stopped manually with ctrl+c.
106 |
107 | To continue a training session use `th train_rgb.lua --dataset="DATASET_PATH" --height=64 --width=64 --network="logs/adversarial.net"`.
108 | To sample images (i.e. colorize images from the training set) use `th sample.lua` (this should automatically reuse the dataset directory, height and width saved with the network).
109 |
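For example, `th sample.lua --network="adversarial.net" --dataset="DATASET_PATH" --writeto="samples" --runs=3` colorizes three batches of training set images and writes the resulting grids to `samples/` (all flags are defined at the top of `sample.lua`; `--dataset` is only needed to override the directory saved in the network file).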
--------------------------------------------------------------------------------
/train_rgb.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'image'
3 | require 'pl' -- this is somehow responsible for lapp working in qlua mode
4 | require 'paths'
5 | ok, DISP = pcall(require, 'display')
6 | if not ok then print('display not found. unable to plot') end
7 | ADVERSARIAL = require 'adversarial_rgb'
8 | DATASET = require 'dataset_rgb'
9 | NN_UTILS = require 'utils.nn_utils'
10 | MODELS = require 'models_rgb'
11 |
12 | ----------------------------------------------------------------------
13 | -- parse command-line options
14 | OPT = lapp[[
15 | --save (default "logs") subdirectory to save logs
16 | --saveFreq (default 30) save every saveFreq epochs
17 | --network (default "") reload pretrained network
18 | --G_pretrained_dir (default "logs")
19 |   --noplot                                don't plot while training
20 | --D_sgd_lr (default 0.02) D SGD learning rate
21 | --G_sgd_lr (default 0.02) G SGD learning rate
22 | --D_sgd_momentum (default 0) D SGD momentum
23 | --G_sgd_momentum (default 0) G SGD momentum
24 | --batchSize (default 32) batch size
25 | --N_epoch (default 30) Number of batches per epoch
26 | --G_L1 (default 0) L1 penalty on the weights of G
27 | --G_L2 (default 0e-6) L2 penalty on the weights of G
28 | --D_L1 (default 0e-7) L1 penalty on the weights of D
29 | --D_L2 (default 1e-4) L2 penalty on the weights of D
30 | --D_iterations (default 1) number of iterations to optimize D for
31 | --G_iterations (default 1) number of iterations to optimize G for
32 | --D_maxAcc (default 1.01) Deactivate learning of D while above this threshold
33 | --D_clamp (default 1) Clamp threshold for D's gradient (+/- N)
34 | --G_clamp (default 5) Clamp threshold for G's gradient (+/- N)
35 |   --D_optmethod     (default "adam")      sgd|adagrad|adadelta|adamax|adam|rmsprop
36 |   --G_optmethod     (default "adam")      sgd|adagrad|adadelta|adamax|adam|rmsprop
37 | --threads (default 4) number of threads
38 |   --gpu             (default 0)           gpu to run on (negative value = run on cpu)
39 | --noiseDim (default 100) dimensionality of noise vector
40 | --window (default 3) window id of sample image
41 | --seed (default 1) seed for the RNG
42 | --nopretraining Whether to deactivate loading of pretrained networks
43 | --height (default 64) Height of the training images
44 | --width (default 64) Width of the training images
45 | --dataset (default "NONE") Directory that contains *.jpg images
46 | ]]
47 |
48 | NORMALIZE = false
49 | START_TIME = os.time()
50 |
51 | if OPT.gpu < 0 or OPT.gpu > 3 then OPT.gpu = false end
52 | print(OPT)
53 |
54 | -- fix seed
55 | math.randomseed(OPT.seed)
56 | torch.manualSeed(OPT.seed)
57 |
58 | -- threads
59 | torch.setnumthreads(OPT.threads)
60 | print(' set nb of threads to ' .. torch.getnumthreads())
61 |
62 | -- possible outputs of the discriminator
63 | CLASSES = {"0", "1"}
64 | Y_GENERATOR = 0
65 | Y_NOT_GENERATOR = 1
66 |
67 | -- axis of images: 3 channels, height, width
68 | IMG_DIMENSIONS = {3, OPT.height, OPT.width}
69 | COND_DIM = {1, OPT.height, OPT.width}
70 | NOISE_DIM = {1, OPT.height, OPT.width}
71 |
72 | ----------------------------------------------------------------------
73 | -- get/create dataset
74 | ----------------------------------------------------------------------
75 | assert(OPT.dataset ~= "NONE")
76 | DATASET.setFileExtension("jpg")
77 | DATASET.setHeight(IMG_DIMENSIONS[2])
78 | DATASET.setWidth(IMG_DIMENSIONS[3])
79 | DATASET.setDirs({OPT.dataset})
80 | ----------------------------------------------------------------------
81 |
82 | -- run on gpu if chosen
83 | -- We have to load all kinds of libraries here, otherwise we risk crashes when loading
84 | -- saved networks afterwards
85 | print(" starting gpu support...")
86 | require 'nn'
87 | require 'cutorch'
88 | require 'cunn'
89 | require 'dpnn'
90 | if OPT.gpu then
91 | cutorch.setDevice(OPT.gpu + 1)
92 | cutorch.manualSeed(OPT.seed)
93 | print(string.format(" using gpu device %d", OPT.gpu))
94 | end
95 | torch.setdefaulttensortype('torch.FloatTensor')
96 |
97 | function main()
98 | ----------------------------------------------------------------------
99 | -- Load / Define network
100 | ----------------------------------------------------------------------
101 |
102 | -- load previous networks (D and G)
103 | -- or initialize them new
104 | if OPT.network ~= "" then
105 | print(string.format(" reloading previously trained network: %s", OPT.network))
106 | local tmp = torch.load(OPT.network)
107 | MODEL_D = tmp.D
108 | MODEL_G = tmp.G
109 | EPOCH = tmp.epoch + 1
110 | VIS_NOISE_INPUTS = tmp.vis_noise_inputs
111 | if NORMALIZE then
112 | NORMALIZE_MEAN = tmp.normalize_mean
113 | NORMALIZE_STD = tmp.normalize_std
114 | end
115 |
116 | if OPT.gpu == false then
117 | MODEL_D:float()
118 | MODEL_G:float()
119 | end
120 | else
121 | local pt_filename = paths.concat(OPT.save, string.format('pretrained_%dx%dx%d_nd%d.net', IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3], OPT.noiseDim))
122 | -- pretrained via pretrain_with_previous_net.lua ?
123 | if not OPT.nopretraining and paths.filep(pt_filename) then
124 | local tmp = torch.load(pt_filename)
125 | MODEL_D = tmp.D
126 | MODEL_G = tmp.G
127 | MODEL_D:training()
128 | MODEL_G:training()
129 | if OPT.gpu == false then
130 | MODEL_D:float()
131 | MODEL_G:float()
132 | end
133 | else
134 | --------------
135 | -- D
136 | --------------
137 | MODEL_D = MODELS.create_D(IMG_DIMENSIONS, OPT.gpu ~= false)
138 |
139 | --------------
140 | -- G
141 | --------------
142 | local g_pt_filename = paths.concat(OPT.G_pretrained_dir, string.format('g_pretrained_%dx%dx%d_nd%d.net', IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3], OPT.noiseDim))
143 | if not OPT.nopretraining and paths.filep(g_pt_filename) then
144 | -- Load a pretrained version of G
145 | print(" loading pretrained G...")
146 | local tmp = torch.load(g_pt_filename)
147 | MODEL_G = tmp.G
148 | MODEL_G:training()
149 | if OPT.gpu == false then
150 | MODEL_G:float()
151 | end
152 | else
153 | print(" Note: Did not find pretrained G")
154 | MODEL_G = MODELS.create_G(IMG_DIMENSIONS, OPT.gpu ~= false)
155 | end
156 | end
157 | end
158 |
159 | print(MODEL_G)
160 | print(MODEL_D)
161 |
162 | -- count free parameters in D/G
163 | print(string.format('Number of free parameters in D: %d', NN_UTILS.getNumberOfParameters(MODEL_D)))
164 | print(string.format('Number of free parameters in G: %d', NN_UTILS.getNumberOfParameters(MODEL_G)))
165 |
166 | -- loss function: negative log-likelihood
167 | CRITERION = nn.BCECriterion()
168 |
169 | -- retrieve parameters and gradients
170 | PARAMETERS_D, GRAD_PARAMETERS_D = MODEL_D:getParameters()
171 | PARAMETERS_G, GRAD_PARAMETERS_G = MODEL_G:getParameters()
172 |
173 | -- this matrix records the current confusion across classes
174 | CONFUSION = optim.ConfusionMatrix(CLASSES)
175 |
176 | -- Set optimizer states
177 | OPTSTATE = {
178 | adagrad = { D = {}, G = {} },
179 | adadelta = { D = {}, G = {} },
180 | adamax = { D = {}, G = {} },
181 | adam = { D = {}, G = {} },
182 | rmsprop = {D = {}, G = {}},
183 | sgd = {
184 | D = {learningRate = OPT.D_sgd_lr, momentum = OPT.D_sgd_momentum},
185 | G = {learningRate = OPT.G_sgd_lr, momentum = OPT.G_sgd_momentum}
186 | }
187 | }
188 |
189 | -- Whether to normalize the images. Not used for this project.
190 | if NORMALIZE then
191 | if NORMALIZE_MEAN == nil then
192 | TRAIN_DATA = DATASET.loadRandomImages(10000)
193 | NORMALIZE_MEAN, NORMALIZE_STD = TRAIN_DATA.normalize()
194 | end
195 | end
196 |
197 | if EPOCH == nil then
198 | EPOCH = 1
199 | end
200 |
201 | PLOT_DATA = {}
202 |
203 | -- Noise vectors. Not used for this project.
204 | if VIS_NOISE_INPUTS == nil then
205 | VIS_NOISE_INPUTS = NN_UTILS.createNoiseInputs(100)
206 | end
207 |
208 | -- Example images to use for plotting during training.
209 | EXAMPLE_IMAGES = DATASET.loadRandomImages(48)
210 |
211 | -- training loop
212 | while true do
213 | print('Loading new training data...')
214 | TRAIN_DATA = DATASET.loadRandomImages(OPT.N_epoch * OPT.batchSize)
215 | if NORMALIZE then
216 | TRAIN_DATA.normalize(NORMALIZE_MEAN, NORMALIZE_STD)
217 | end
218 |
219 | -- Show images and plots if requested
220 | if not OPT.noplot then
221 | --visualizeProgress(MODEL_G, MODEL_D, VIS_NOISE_INPUTS, TRAIN_DATA)
222 | visualizeProgressConditional()
223 | end
224 |
225 | -- Train D and G
226 | -- ... but train D only while having an accuracy below OPT.D_maxAcc
227 | -- over the last math.max(20, math.min(1000/OPT.batchSize, 250)) batches
228 | ADVERSARIAL.train(TRAIN_DATA, OPT.D_maxAcc, math.max(20, math.min(1000/OPT.batchSize, 250)))
229 |
230 | -- Save current net
231 | if EPOCH % OPT.saveFreq == 0 then
232 | local filename = paths.concat(OPT.save, 'adversarial.net')
233 | saveAs(filename)
234 | end
235 |
236 | EPOCH = EPOCH + 1
237 | end
238 | end
239 |
240 | -- Save the current models G and D to a file.
241 | -- @param filename The path to the file
242 | function saveAs(filename)
243 | os.execute(string.format("mkdir -p %s", sys.dirname(filename)))
244 | if paths.filep(filename) then
245 | os.execute(string.format("mv %s %s.old", filename, filename))
246 | end
247 | print(string.format(" saving network to %s", filename))
248 | NN_UTILS.prepareNetworkForSave(MODEL_G)
249 | NN_UTILS.prepareNetworkForSave(MODEL_D)
250 | torch.save(filename, {D = MODEL_D, G = MODEL_G, opt = OPT, plot_data = PLOT_DATA, epoch = EPOCH, vis_noise_inputs = VIS_NOISE_INPUTS, normalize_mean=NORMALIZE_MEAN, normalize_std=NORMALIZE_STD})
251 | end
252 |
253 | -- Get examples to plot.
254 | -- Returns a list of the pattern
255 | -- [i] Image, black and white.
256 | -- [i+1] Image, color.
257 | -- [i+2] Image, color added by G.
258 | function getSamples()
259 | local N = EXAMPLE_IMAGES:size()
260 | local ds = EXAMPLE_IMAGES
261 | local noiseInputs = torch.Tensor(N, NOISE_DIM[1], NOISE_DIM[2], NOISE_DIM[3])
262 | local condInputs = torch.Tensor(N, COND_DIM[1], COND_DIM[2], COND_DIM[3])
263 | local gt = torch.Tensor(N, IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3])
264 |
265 | -- Generate samples
266 | noiseInputs:uniform(0, 1)
267 | for i=1,N do
268 | --local idx = math.random(ds:size())
269 | local idx = i
270 | local example = ds[idx]
271 | condInputs[i] = example.grayscale:clone()
272 | gt[i] = example.color:clone()
273 | end
274 | local samples = MODEL_G:forward({noiseInputs, condInputs})
275 |
276 | local to_plot = {}
277 | for i=1,N do
278 | --local withColor = torch.cat(condInputs[i]:float(), samples[i]:float(), 1)
279 | local withColor = samples[i]:clone()
280 | to_plot[#to_plot+1] = NN_UTILS.switchColorSpaceSingle(condInputs[i]:float(), "y", "rgb")
281 | to_plot[#to_plot+1] = gt[i]:float()
282 | to_plot[#to_plot+1] = withColor
283 | end
284 | return to_plot
285 | end
286 |
287 | -- Updates the display plot.
288 | function visualizeProgressConditional()
289 | -- Show images and their refinements for the validation and training set
290 | --local toPlotVal = getSamples(VAL_DATA, 20)
291 | local toPlotTrain = getSamples()
292 | --DISP.image(toPlotVal, {win=OPT.window, width=2*10*IMG_DIMENSIONS[3], title=string.format("[VAL] Coarse, GT, G img, GT diff, G diff (%s epoch %d)", OPT.save, EPOCH)})
293 | DISP.image(toPlotTrain, {win=OPT.window+1, width=14*IMG_DIMENSIONS[3], title=string.format("[TRAIN] original grayscale, original color, auto-colorized (%s epoch %d)", OPT.save, EPOCH)})
294 | end
295 |
296 | main()
297 |
--------------------------------------------------------------------------------
/adversarial_rgb.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'optim'
3 | require 'pl'
4 | require 'image'
5 |
6 | local adversarial = {}
7 |
8 | -- this variable will save the accuracy values of D
9 | adversarial.accs = {}
10 |
11 | -- function to calculate the mean of a list of numbers
12 | function adversarial.mean(t)
13 | local sum = 0
14 | local count = 0
15 |
16 | for k,v in pairs(t) do
17 | if type(v) == 'number' then
18 | sum = sum + v
19 | count = count + 1
20 | end
21 | end
22 |
23 | return (sum / count)
24 | end
25 |
26 | -- main training function
27 | function adversarial.train(trainData, maxAccuracyD, accsInterval)
28 | EPOCH = EPOCH or 1
29 | local N_epoch = OPT.N_epoch
30 | if N_epoch <= 0 then
31 | N_epoch = 100
32 | end
33 | local dataBatchSize = OPT.batchSize / 2 -- size of a half-batch for D or G
34 | local time = sys.clock()
35 |
36 | -- variables to track D's accuracy and adjust learning rates
37 | local lastAccuracyD = 0.0
38 | local doTrainD = true
39 | local countTrainedD = 0
40 | local countNotTrainedD = 0
41 |
42 | local samples = nil
43 |
44 |
45 | -- do one epoch
46 | -- While this function is structured like one that picks example batches in consecutive order,
47 | -- in reality the examples (per batch) will be picked randomly
48 | print(string.format(" Epoch #%d [batchSize = %d]", EPOCH, OPT.batchSize))
49 | for batchIdx=1,N_epoch do
50 | -- size of this batch, will usually be dataBatchSize but can be lower at the end
51 | --local thisBatchSize = math.min(OPT.batchSize, N_epoch - t + 1)
52 |
53 | -- Inputs for D, either original or generated images
54 | local inputs = torch.Tensor(OPT.batchSize, 3, IMG_DIMENSIONS[2], IMG_DIMENSIONS[3])
55 |
56 | -- target y-values
57 | local targets = torch.Tensor(OPT.batchSize)
58 |
59 | -- tensor to use for noise for G
60 | --local noiseInputs = torch.Tensor(thisBatchSize, OPT.noiseDim)
61 | local noiseInputs = torch.Tensor(OPT.batchSize, NOISE_DIM[1], NOISE_DIM[2], NOISE_DIM[3])
62 | local condInputs = torch.Tensor(OPT.batchSize, COND_DIM[1], COND_DIM[2], COND_DIM[3])
63 |
64 | ----------------------------------------------------------------------
65 | -- create closure to evaluate f(X) and df/dX of D
66 | local fevalD = function(x)
67 | collectgarbage()
68 | local confusion_batch_D = optim.ConfusionMatrix(CLASSES)
69 | confusion_batch_D:zero()
70 |
71 | if x ~= PARAMETERS_D then -- get new parameters
72 | PARAMETERS_D:copy(x)
73 | end
74 |
75 | GRAD_PARAMETERS_D:zero() -- reset gradients
76 |
77 | -- forward pass
78 | -- condInputs = grayscale image (Y), inputs = full RGB image (this rgb variant feeds RGB, not UV, into D)
79 | local outputs = MODEL_D:forward({condInputs, inputs})
80 | local f = CRITERION:forward(outputs, targets)
81 |
82 | -- backward pass
83 | local df_do = CRITERION:backward(outputs, targets)
84 | MODEL_D:backward({condInputs, inputs}, df_do)
85 |
86 | -- penalties (L1 and L2):
87 | if OPT.D_L1 ~= 0 or OPT.D_L2 ~= 0 then
88 | -- Loss:
89 | f = f + OPT.D_L1 * torch.norm(PARAMETERS_D, 1)
90 | f = f + OPT.D_L2 * torch.norm(PARAMETERS_D, 2)^2/2
91 | -- Gradients:
92 | GRAD_PARAMETERS_D:add(torch.sign(PARAMETERS_D):mul(OPT.D_L1) + PARAMETERS_D:clone():mul(OPT.D_L2) )
93 | end
94 |
95 | -- update confusion (add 1 since targets are binary)
96 | for i=1,OPT.batchSize do
97 | local c
98 | if outputs[i][1] > 0.5 then c = 2 else c = 1 end
99 | CONFUSION:add(c, targets[i]+1)
100 | confusion_batch_D:add(c, targets[i]+1)
101 | end
102 |
103 | -- Clamp D's gradients
104 | -- This helps a bit against D suddenly giving up (only outputting y=1 or y=0)
105 | if OPT.D_clamp ~= 0 then
106 | GRAD_PARAMETERS_D:clamp((-1)*OPT.D_clamp, OPT.D_clamp)
107 | end
108 |
109 | -- Calculate accuracy of D on this batch
110 | confusion_batch_D:updateValids()
111 | local tV = confusion_batch_D.totalValid
112 |
113 | -- Add this batch's accuracy to the history of D's accuracies
114 | -- Also, keep that history to a fixed size
115 | adversarial.accs[#adversarial.accs+1] = tV
116 | if #adversarial.accs > accsInterval then
117 | table.remove(adversarial.accs, 1)
118 | end
119 |
120 | -- Mean accuracy of D over the last couple of batches
121 | local accAvg = adversarial.mean(adversarial.accs)
122 |
123 | -- We will only train D if its mean accuracy over the last couple of batches
124 | -- was below the defined maximum (maxAccuracyD). This protects a bit against
125 | -- G generating garbage.
126 | doTrainD = (accAvg < maxAccuracyD)
127 | lastAccuracyD = tV
128 | if doTrainD then
129 | countTrainedD = countTrainedD + 1
130 | return f,GRAD_PARAMETERS_D
131 | else
132 | countNotTrainedD = countNotTrainedD + 1
133 |
134 | -- The interruptable* optimizers don't train when false is returned. Note: the standard optim.* methods used here don't handle that; this branch is only safe because the default --D_maxAcc of 1.01 keeps doTrainD always true.
135 | -- Maybe that would be equivalent to just returning 0 for all gradients?
136 | return false,false
137 | end
138 | end
139 |
140 | ----------------------------------------------------------------------
141 | -- create closure to evaluate f(X) and df/dX of generator
142 | local fevalG_on_D = function(x)
143 | collectgarbage()
144 | if x ~= PARAMETERS_G then -- get new parameters
145 | PARAMETERS_G:copy(x)
146 | end
147 |
148 | GRAD_PARAMETERS_G:zero() -- reset gradients
149 |
150 | -- forward pass
151 | --local samples = NN_UTILS.createImagesFromNoise(noiseInputs, false, true)
152 | local samples = MODEL_G:forward({noiseInputs, condInputs})
153 | -- condInputs = grayscale image (Y), samples = generated RGB image
154 | local outputs = MODEL_D:forward({condInputs, samples})
155 | local f = CRITERION:forward(outputs, targets)
156 |
157 | -- backward pass
158 | local df_samples = CRITERION:backward(outputs, targets)
159 | MODEL_D:backward({condInputs, samples}, df_samples)
160 | local df_do = MODEL_D.modules[1].gradInput[2] -- 1=grad of condInput (Y), 2=grad of samples (RGB)
161 | MODEL_G:backward({noiseInputs, condInputs}, df_do)
162 |
163 | -- penalties (L1 and L2):
164 | if OPT.G_L1 ~= 0 or OPT.G_L2 ~= 0 then
165 | -- Loss:
166 | f = f + OPT.G_L1 * torch.norm(PARAMETERS_G, 1)
167 | f = f + OPT.G_L2 * torch.norm(PARAMETERS_G, 2)^2/2
168 | -- Gradients:
169 | GRAD_PARAMETERS_G:add(torch.sign(PARAMETERS_G):mul(OPT.G_L1) + PARAMETERS_G:clone():mul(OPT.G_L2))
170 | end
171 |
172 | -- clamp G's Gradient to the range of -1.0 to +1.0
173 | if OPT.G_clamp ~= 0 then
174 | GRAD_PARAMETERS_G:clamp((-1)*OPT.G_clamp, OPT.G_clamp)
175 | end
176 |
177 | return f,GRAD_PARAMETERS_G
178 | end
179 | ------------------- end of eval functions ---------------------------
180 |
181 | ----------------------------------------------------------------------
182 | -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
183 | -- Get half a minibatch of real, half fake
184 | for k=1, OPT.D_iterations do
185 | -- (1.1) Real data
186 | local inputIdx = 1
187 | local realDataSize = OPT.batchSize / 2
188 | for i = 1, realDataSize do
189 | local randomIdx = math.random(trainData:size())
190 | local trainingExample = trainData[randomIdx]
191 | --inputs[inputIdx] = trainingExample.uv:clone()
192 | inputs[inputIdx] = trainingExample.color:clone()
193 | condInputs[inputIdx] = trainingExample.grayscale:clone()
194 | targets[inputIdx] = Y_NOT_GENERATOR
195 | inputIdx = inputIdx + 1
196 | end
197 |
198 | -- (1.2) Sampled data
199 | noiseInputs:uniform(0, 1)
200 | for i = 1, realDataSize do
201 | local randomIdx = math.random(trainData:size())
202 | local trainingExample = trainData[randomIdx]
203 | condInputs[inputIdx] = trainingExample.grayscale:clone()
204 | inputIdx = inputIdx + 1
205 | end
206 | inputIdx = inputIdx - realDataSize
207 |
208 | local generatedUV = MODEL_G:forward({ -- despite the name, these are full RGB images in this variant
209 | noiseInputs[{{realDataSize+1,2*realDataSize}}],
210 | condInputs[{{realDataSize+1,2*realDataSize}}]
211 | })
212 | for i=1, realDataSize do
213 | inputs[inputIdx] = generatedUV[i]:clone()
214 | targets[inputIdx] = Y_GENERATOR
215 | inputIdx = inputIdx + 1
216 | end
217 |
218 | if OPT.D_optmethod == "sgd" then
219 | optim.sgd(fevalD, PARAMETERS_D, OPTSTATE.sgd.D)
220 | elseif OPT.D_optmethod == "adagrad" then
221 | optim.adagrad(fevalD, PARAMETERS_D, OPTSTATE.adagrad.D)
222 | elseif OPT.D_optmethod == "adadelta" then
223 | optim.adadelta(fevalD, PARAMETERS_D, OPTSTATE.adadelta.D)
224 | elseif OPT.D_optmethod == "adamax" then
225 | optim.adamax(fevalD, PARAMETERS_D, OPTSTATE.adamax.D)
226 | elseif OPT.D_optmethod == "adam" then
227 | optim.adam(fevalD, PARAMETERS_D, OPTSTATE.adam.D)
228 | elseif OPT.D_optmethod == "rmsprop" then
229 | optim.rmsprop(fevalD, PARAMETERS_D, OPTSTATE.rmsprop.D)
230 | else
231 | print("[Warning] Unknown optimizer method chosen for D.")
232 | end
233 | end
234 |
235 | ----------------------------------------------------------------------
236 | -- (2) Update G network: maximize log(D(G(z)))
237 | for k=1, OPT.G_iterations do
238 | noiseInputs:uniform(0, 1)
239 | targets:fill(Y_NOT_GENERATOR)
240 | for i=1,OPT.batchSize do
241 | local randomIdx = math.random(trainData:size())
242 | local trainingExample = trainData[randomIdx]
243 | condInputs[i] = trainingExample.grayscale:clone()
244 | end
245 |
246 | if OPT.G_optmethod == "sgd" then
247 | optim.sgd(fevalG_on_D, PARAMETERS_G, OPTSTATE.sgd.G)
248 | elseif OPT.G_optmethod == "adagrad" then
249 | optim.adagrad(fevalG_on_D, PARAMETERS_G, OPTSTATE.adagrad.G)
250 | elseif OPT.G_optmethod == "adadelta" then
251 | optim.adadelta(fevalG_on_D, PARAMETERS_G, OPTSTATE.adadelta.G)
252 | elseif OPT.G_optmethod == "adamax" then
253 | optim.adamax(fevalG_on_D, PARAMETERS_G, OPTSTATE.adamax.G)
254 | elseif OPT.G_optmethod == "adam" then
255 | optim.adam(fevalG_on_D, PARAMETERS_G, OPTSTATE.adam.G)
256 | elseif OPT.G_optmethod == "rmsprop" then
257 | optim.rmsprop(fevalG_on_D, PARAMETERS_G, OPTSTATE.rmsprop.G)
258 | else
259 | print("[Warning] Unknown optimizer method chosen for G.")
260 | end
261 | end
262 |
263 | -- display progress (batchIdx is the loop variable, so no manual increment is needed)
264 | xlua.progress(batchIdx * OPT.batchSize, N_epoch * OPT.batchSize)
265 |
266 | end
267 |
268 | -- time taken
269 | time = sys.clock() - time
270 | if maxAccuracyD < 1.0 then
271 | print(string.format(" trained D %d of %d times.", countTrainedD, countTrainedD + countNotTrainedD))
272 | end
273 |
274 | -- print confusion matrix
275 | print("Confusion of D:")
276 | print(CONFUSION)
277 | local tV = CONFUSION.totalValid
278 | CONFUSION:zero()
279 |
280 | return tV
281 | end
282 |
283 | -- Show the activity of a network in windows (i.e. windows full of blinking dots).
284 | -- The windows will automatically be reused.
285 | -- Only the activity of the layer types nn.SpatialConvolution and nn.Linear will be shown.
286 | -- Linear layers must have a minimum size to be shown (i.e. to not show the tiny output layers).
287 | --
288 | -- NOTE: This function can only visualize one network properly while the program runs.
289 | -- I.e. you can't call this function to show network A and then another time to show network B,
290 | -- because the function tries to reuse windows and that will not work correctly in such a case.
291 | --
292 | -- NOTE: Old function, probably doesn't work anymore.
293 | --
294 | -- @param net The network to visualize.
295 | -- @param minOutputs Minimum (output) size of a linear layer to be shown.
296 | function adversarial.visualizeNetwork(net, minOutputs)
297 | if minOutputs == nil then
298 | minOutputs = 150
299 | end
300 |
301 | -- (Global) Table to save the window ids in, so that we can reuse them between calls.
302 | netvis_windows = netvis_windows or {}
303 |
304 | local modules = net:listModules()
305 | local winIdx = 1
306 | -- last module seems to have no output?
307 | for i=1,(#modules-1) do
308 | local t = torch.type(modules[i])
309 | local showTensor = nil
310 | -- This function only shows the activity of 2d convolutions and linear layers
311 | if t == 'nn.SpatialConvolution' then
312 | showTensor = modules[i].output[1]
313 | elseif t == 'nn.Linear' then
314 | local output = modules[i].output
315 | local shape = output:size()
316 | local nbValues = shape[2]
317 |
318 |             if nbValues >= minOutputs then
319 | local nbRows = torch.floor(torch.sqrt(nbValues))
320 | while nbValues % nbRows ~= 0 and nbRows < nbValues do
321 | nbRows = nbRows + 1
322 | end
323 |
324 | if nbRows >= nbValues then
325 | showTensor = nil
326 | else
327 | showTensor = output[1]:view(nbRows, nbValues / nbRows)
328 | end
329 | end
330 | end
331 |
332 | -- Show the layer outputs in a window
333 | -- Note that windows are reused if possible
334 | if showTensor ~= nil then
335 | netvis_windows[winIdx] = image.display{
336 | image=showTensor, zoom=1, nrow=32,
337 | min=-1, max=1,
338 | win=netvis_windows[winIdx], legend=t .. ' (#' .. i .. ')',
339 | padding=1
340 | }
341 | winIdx = winIdx + 1
342 | end
343 | end
344 | end
345 |
346 | return adversarial
347 |
--------------------------------------------------------------------------------
/utils/nn_utils.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 |
3 | local nn_utils = {}
4 |
5 | -- Sets the weights of a layer to random values within a range.
6 | -- @param weights The weights module to change, e.g. mlp.modules[1].weight.
7 | -- @param range Scale (standard deviation) for the new values (single number, e.g. 0.005)
8 | function nn_utils.setWeights(weights, range)
9 |     -- sample from a normal distribution with mean 0 and standard deviation `range`
10 |     weights:normal(0.0, range)
11 | end
12 |
13 | -- Initializes all weights of a multi-layer network.
14 | -- @param model The nn.Sequential() model with one or more layers
15 | -- @param rangeWeights A scale for the new weight values (single number, e.g. 0.005)
16 | -- @param rangeBias A scale for the new bias values (single number, e.g. 0.005)
17 | function nn_utils.initializeWeights(model, rangeWeights, rangeBias)
18 | rangeWeights = rangeWeights or 0.005
19 | rangeBias = rangeBias or 0.001
20 |
21 | for m = 1, #model.modules do
22 | if model.modules[m].weight then
23 | nn_utils.setWeights(model.modules[m].weight, rangeWeights)
24 | end
25 | if model.modules[m].bias then
26 | nn_utils.setWeights(model.modules[m].bias, rangeBias)
27 | end
28 | end
29 | end
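    | -- Minimal usage sketch; `model` is a hypothetical nn.Sequential, not part of this repo:
    | --   local model = nn.Sequential()
    | --   model:add(nn.Linear(100, 256))
    | --   nn_utils.initializeWeights(model, 0.005, 0.001)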
30 |
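    | -- Runs `input` through `model` in chunks of `batchSize` examples and returns
    | -- the concatenated outputs as a single tensor, avoiding one huge forward pass
    | -- for large inputs.
    | -- @param model The network to run the forward passes on.
    | -- @param input Tensor of examples (first dimension = example index) or a table of such tensors.
    | -- @param batchSize Maximum number of examples per forward pass.
    | -- @returns Tensor of outputs with the same first-dimension size as the input.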
31 | function nn_utils.forwardBatched(model, input, batchSize)
32 |     -- convert tables of tensors to one tensor, so that the range-slicing below works
33 |     if not input.size then
34 |         input = nn_utils.toImageTensor(input)
35 |     end
36 |     local N = input:size(1)
38 |
39 | local output
40 | local nBatches = math.ceil(N/batchSize)
41 | for i=1,nBatches do
42 | local batchStart = 1 + (i-1) * batchSize
43 | local batchEnd = math.min(i*batchSize, N)
44 | local forwarded = model:forward(input[{{batchStart, batchEnd}}]):clone()
45 | if output == nil then
46 | local sizes = forwarded:size()
47 | sizes[1] = N
48 | output = torch.Tensor():resize(sizes)
49 | end
50 |         output[{{batchStart, batchEnd}}] = forwarded
51 | end
52 |
53 | return output
54 | end
55 |
56 | -- Creates a tensor of N vectors, each of dimension OPT.noiseDim with random values
57 | -- between -1 and +1.
58 | -- @param N Number of vectors to generate
59 | -- @returns Tensor of shape (N, OPT.noiseDim)
60 | function nn_utils.createNoiseInputs(N)
61 | local noiseInputs = torch.Tensor(N, OPT.noiseDim)
62 | noiseInputs:uniform(-1.0, 1.0)
63 | return noiseInputs
64 | end
65 |
66 | -- Feeds noise vectors into G or AE+G and returns the result.
67 | -- @param noiseInputs Tensor from createNoiseInputs()
68 | -- @param outputAsList Whether to return the images as one list or as a tensor.
69 | -- @returns Either list of images (as returned by G/AE) or tensor of images
70 | function nn_utils.createImagesFromNoise(noiseInputs, outputAsList)
71 | local images
72 | local N = noiseInputs:size(1)
73 | local nBatches = math.ceil(N/OPT.batchSize)
74 | for i=1,nBatches do
75 | local batchStart = 1 + (i-1)*OPT.batchSize
76 | local batchEnd = math.min(i*OPT.batchSize, N)
77 | local generated = MODEL_G:forward(noiseInputs[{{batchStart, batchEnd}}]):clone()
78 | if images == nil then
79 | local img = generated[1]
80 | images = torch.Tensor(N, img:size(1), img:size(2), img:size(3))
81 | end
82 | images[{{batchStart, batchEnd}, {}, {}, {}}] = generated
83 | end
84 |
85 | if outputAsList then
86 | local imagesList = {}
87 | for i=1, images:size(1) do
88 | imagesList[#imagesList+1] = images[i]:float()
89 | end
90 | return imagesList
91 | else
92 | return images
93 | end
94 | end
95 |
96 | -- Creates new random images with G or AE+G.
97 | -- @param N Number of images to create.
98 | -- @param outputAsList Whether to return the images as one list or as a tensor.
99 | -- @returns Either list of images (as returned by G/AE) or tensor of images
100 | function nn_utils.createImages(N, outputAsList)
101 | return nn_utils.createImagesFromNoise(nn_utils.createNoiseInputs(N), outputAsList)
102 | end
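    | -- Minimal usage sketch, assuming MODEL_G and OPT are initialized as in train_rgb.lua:
    | --   local images = nn_utils.createImages(64, false)  -- tensor of 64 generated images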
103 |
104 | -- Sorts images based on D's certainty that they are fake/real.
105 | -- Descending order starts at y=1 (Y_NOT_GENERATOR) and ends with y=0 (Y_GENERATOR).
106 | -- Therefore, in case of descending order, images for which D is very certain that they are real
107 | -- come first and images that seem to be fake (according to D) come last.
108 | -- @param images Tensor of the images to sort.
109 | -- @param ascending If true then images that seem most fake to D are placed at the start of the list.
110 | -- Otherwise the list starts with probably real images.
111 | -- @param nbMaxOut Maximum number of images to return (never more than provided).
112 | -- @return Tuple (list of images, list of predictions between 0.0 and 1.0)
113 | -- where 1.0 means "probably real"
114 | function nn_utils.sortImagesByPrediction(images, ascending, nbMaxOut)
115 | local predictions = torch.Tensor(images:size(1), 1)
116 | local nBatches = math.ceil(images:size(1)/OPT.batchSize)
117 | for i=1,nBatches do
118 | local batchStart = 1 + (i-1)*OPT.batchSize
119 | local batchEnd = math.min(i*OPT.batchSize, images:size(1))
120 | predictions[{{batchStart, batchEnd}, {1}}] = MODEL_D:forward(images[{{batchStart, batchEnd}, {}, {}, {}}]):clone()
121 | end
122 |
123 | local imagesWithPreds = {}
124 | for i=1,images:size(1) do
125 | table.insert(imagesWithPreds, {images[i], predictions[i][1]})
126 | end
127 |
128 | if ascending then
129 | table.sort(imagesWithPreds, function (a,b) return a[2] < b[2] end)
130 | else
131 | table.sort(imagesWithPreds, function (a,b) return a[2] > b[2] end)
132 | end
133 |
134 |     local resultImages = {}
135 |     local resultPredictions = {}
136 | for i=1,math.min(nbMaxOut,#imagesWithPreds) do
137 | resultImages[i] = imagesWithPreds[i][1]
138 | resultPredictions[i] = imagesWithPreds[i][2]
139 | end
140 |
141 | return resultImages, resultPredictions
142 | end
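    | -- Minimal usage sketch; `images` is a hypothetical tensor of generated images:
    | --   -- the 16 images that D rates as most likely real, plus their ratings
    | --   local best, preds = nn_utils.sortImagesByPrediction(images, false, 16)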
143 |
144 | function nn_utils.switchColorSpace(images, from, to)
145 | images = nn_utils.toRgb(images, from)
146 | images = nn_utils.rgbToColorSpace(images, to)
147 | return images
148 | end
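    | -- Minimal usage sketch; `batch` is a hypothetical (N, 3, H, W) tensor of YUV images:
    | --   batch = nn_utils.switchColorSpace(batch, "yuv", "hsl")  -- YUV -> RGB -> HSL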
149 |
150 | function nn_utils.switchColorSpaceSingle(image, from, to)
151 | local images = nn_utils.toBatch(image)
152 | images = nn_utils.toRgb(images, from)
153 | images = nn_utils.rgbToColorSpace(images, to)
154 | return images[1]
155 | end
156 |
157 | function nn_utils.toRgb(images, from)
158 | local images = nn_utils.toImageTensor(images)
159 | if from == "rgb" then
160 | return images
161 | elseif from == "y" then
162 | --[[
163 | local imagesTmp
164 | if images:size(4) == nil then
165 | imagesTmp = images:clone()
166 | else
167 | imagesTmp = images:clone():squeeze(2)
168 | end
169 |
170 | local N = imagesTmp:size(1)
171 | local height = imagesTmp:size(2)
172 | local width = imagesTmp:size(3)
173 | --]]
174 | return torch.repeatTensor(images, 1, 3, 1, 1)
175 | elseif from == "hsl" then
176 | local out = torch.Tensor(images:size(1), 3, images:size(3), images:size(4))
177 | for i=1,images:size(1) do
178 | out[i] = image.hsl2rgb(images[i])
179 | end
180 | return out
181 | elseif from == "yuv" then
182 | local out = torch.Tensor(images:size(1), 3, images:size(3), images:size(4))
183 | for i=1,images:size(1) do
184 | out[i] = image.yuv2rgb(images[i])
185 | end
186 | return out
187 |     else
188 |         print("[WARNING] unknown color space: '" .. from .. "'")
    |         return images
189 |     end
190 | end
191 |
192 | function nn_utils.rgbToColorSpace(images, colorSpace)
193 | if colorSpace == "rgb" then
194 | return images
195 | else
196 | if colorSpace == "y" then
197 | local out = torch.Tensor(images:size(1), 1, images:size(3), images:size(4))
198 | for i=1,images:size(1) do
199 | out[i] = nn_utils.rgb2y(images[i])
200 | end
201 | return out
202 | elseif colorSpace == "hsl" then
203 | local out = torch.Tensor(images:size(1), 3, images:size(3), images:size(4))
204 | for i=1,images:size(1) do
205 | out[i] = image.rgb2hsl(images[i])
206 | end
207 | return out
208 | elseif colorSpace == "yuv" then
209 | local out = torch.Tensor(images:size(1), 3, images:size(3), images:size(4))
210 | for i=1,images:size(1) do
211 | out[i] = image.rgb2yuv(images[i])
212 | end
213 | return out
214 |         else
215 |             print("[WARNING] unknown color space in rgbToColorSpace: '" .. colorSpace .. "'")
    |             return images
216 |         end
217 | end
218 | end
219 |
220 | -- convert rgb to grayscale via a weighted (luminosity) sum of the channels
221 | -- https://gist.github.com/jkrish/29ca7302e98554dd0fcb
222 | function nn_utils.rgb2y(im, threeChannels)
223 | -- Image.rgb2y uses a different weight mixture
224 |     local dim, height, width = im:size()[1], im:size()[2], im:size()[3]
225 |     if dim ~= 3 then
226 |         print("[WARNING] rgb2y expected 3 channels, got " .. dim)
227 | return im
228 | end
229 |
230 | -- a cool application of tensor:select
231 | local r = im:select(1, 1)
232 | local g = im:select(1, 2)
233 | local b = im:select(1, 3)
234 |
235 |     local z = torch.Tensor(1, height, width):zero()
236 |
237 |     -- z = 0.21*r + 0.72*g + 0.07*b (luminosity weights)
238 | z = z:add(0.21, r)
239 | z = z:add(0.72, g)
240 | z = z:add(0.07, b)
241 |
242 | if threeChannels == true then
243 | z = torch.repeatTensor(z, 3, 1, 1)
244 | end
245 |
246 | return z
247 | end
248 |
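    | -- Adds a batch dimension of size 1 to a single image.
    | -- Works for 2D (height, width) and 3D (channels, height, width) images.
    | -- @param img The image tensor to wrap in a batch.
    | -- @returns Tensor with one additional leading dimension of size 1.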
249 | function nn_utils.toBatch(img)
250 |     if img:dim() == 2 then
251 |         local tnsr = torch.Tensor():resize(1, img:size(1), img:size(2))
252 |         tnsr[1] = img
253 |         return tnsr
254 |     else
255 |         local tnsr = torch.Tensor():resize(1, img:size(1), img:size(2), img:size(3))
256 |         tnsr[1] = img
257 |         return tnsr
258 |     end
259 | end
260 |
261 | -- Convert a list (table) of images to a Tensor.
262 | -- If the parameter is already a tensor, it will be returned unchanged
    | -- (unless forceChannel is set and a channel dimension has to be added).
263 | -- @param imageList A non-empty list/table or tensor of images (each being a tensor).
    | -- @param forceChannel If true, grayscale images lacking a channel dimension get one added.
264 | -- @returns A tensor of shape (N, channels, height, width)
265 | function nn_utils.toImageTensor(imageList, forceChannel)
266 | if imageList.size ~= nil then
267 | if not forceChannel or (#imageList:size() == 4) then
268 | return imageList
269 | else
270 | -- forceChannel activated and images lack channel dimension
271 | -- add it
272 | local tens = torch.Tensor(imageList:size(1), 1, imageList:size(2), imageList:size(3))
273 | for i=1,imageList:size(1) do
274 | tens[i][1] = imageList[i]
275 | end
276 | return tens
277 | end
278 | else
279 | if forceChannel == nil then
280 | forceChannel = false
281 | end
282 |
283 | local hasChannel = (#imageList[1]:size() == 3)
284 |
285 | local tens
286 | if hasChannel then
287 | tens = torch.Tensor(#imageList, imageList[1]:size(1), imageList[1]:size(2), imageList[1]:size(3))
288 | elseif not hasChannel and forceChannel then
289 | tens = torch.Tensor(#imageList, 1, imageList[1]:size(1), imageList[1]:size(2))
290 | else
291 | tens = torch.Tensor(#imageList, imageList[1]:size(1), imageList[1]:size(2))
292 | end
293 |
294 | for i=1,#imageList do
295 | if (not hasChannel and forceChannel) then
296 | tens[i][1] = imageList[i]
297 | else
298 | tens[i] = imageList[i]
299 | end
300 | end
301 | return tens
302 | end
303 | end
304 |
305 | function nn_utils.toImageList(imageTensor, forceChannel)
306 | local tens = nn_utils.toImageTensor(imageTensor, forceChannel)
307 | local lst = {}
308 | for i=1,tens:size(1) do
309 | table.insert(lst, tens[i])
310 | end
311 | return lst
312 | end
313 |
314 | -- Switch networks to training mode (activate Dropout)
315 | function nn_utils.switchToTrainingMode()
316 | if MODEL_AE then
317 | MODEL_AE:training()
318 | end
319 | MODEL_G:training()
320 | MODEL_D:training()
321 | end
322 |
323 | -- Switch networks to evaluation mode (deactivate Dropout)
324 | function nn_utils.switchToEvaluationMode()
325 | if MODEL_AE then
326 | MODEL_AE:evaluate()
327 | end
328 | MODEL_G:evaluate()
329 | MODEL_D:evaluate()
330 | end
331 |
332 | -- Normalize given images, currently to range -1.0 (black) to +1.0 (white), assuming that
333 | -- the input images are normalized to range 0.0 (black) to +1.0 (white).
334 | -- @param data Tensor of images
335 | -- @param mean_ Currently ignored.
336 | -- @param std_ Currently ignored.
337 | -- @return (mean, std), both currently always 0.5 dummy values
338 | function nn_utils.normalize(data, mean_, std_)
339 | -- Code to normalize to zero-mean and unit-variance.
340 | --[[
341 | local mean = mean_ or data:mean(1)
342 | local std = std_ or data:std(1, true)
343 | local eps = 1e-7
344 | local N
345 | if data.size ~= nil then
346 | N = data:size(1)
347 | else
348 | N = #data
349 | end
350 |
351 | for i=1,N do
352 | data[i]:add(-1, mean)
353 | data[i]:cdiv(std + eps)
354 | end
355 |
356 | return mean, std
357 | --]]
358 |
359 | -- Code to normalize to range -1.0 to +1.0, where -1.0 is black and 1.0 is the maximum
360 | -- value in this image.
361 | --[[
362 | local N
363 | if data.size ~= nil then
364 | N = data:size(1)
365 | else
366 | N = #data
367 | end
368 |
369 | for i=1,N do
370 | local m = torch.max(data[i])
371 | data[i]:div(m * 0.5)
372 | data[i]:add(-1.0)
373 | data[i] = torch.clamp(data[i], -1.0, 1.0)
374 | end
375 | --]]
376 |
377 | -- Normalize to range -1.0 to +1.0, where -1.0 is black and +1.0 is white.
378 | local N
379 | if data.size ~= nil then
380 | N = data:size(1)
381 | else
382 | N = #data
383 | end
384 |
385 | for i=1,N do
386 | data[i]:mul(2)
387 | data[i]:add(-1.0)
388 | data[i] = torch.clamp(data[i], -1.0, 1.0)
389 | end
390 |
391 | -- Dummy return values
392 | return 0.5, 0.5
393 | end
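    | -- Minimal usage sketch; `data` is a hypothetical tensor of images in range [0, 1]:
    | --   nn_utils.normalize(data)  -- afterwards the values lie in range [-1, 1]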
394 |
395 | -- from https://github.com/torch/DEPRECEATED-torch7-distro/issues/47
396 | function nn_utils.zeroDataSize(data)
397 | if type(data) == 'table' then
398 | for i = 1, #data do
399 | data[i] = nn_utils.zeroDataSize(data[i])
400 | end
401 | elseif type(data) == 'userdata' then
402 | data = torch.Tensor():typeAs(data)
403 | end
404 | return data
405 | end
406 |
407 | -- from https://github.com/torch/DEPRECEATED-torch7-distro/issues/47
408 | -- Resize the output, gradInput, etc. temporary tensors to zero (so that the on-disk size is smaller).
409 | function nn_utils.prepareNetworkForSave(node)
410 | if node.output ~= nil then
411 | node.output = nn_utils.zeroDataSize(node.output)
412 | end
413 | if node.gradInput ~= nil then
414 | node.gradInput = nn_utils.zeroDataSize(node.gradInput)
415 | end
416 | if node.finput ~= nil then
417 | node.finput = nn_utils.zeroDataSize(node.finput)
418 | end
419 | -- Recurse on nodes with 'modules'
420 | if (node.modules ~= nil) then
421 | if (type(node.modules) == 'table') then
422 | for i = 1, #node.modules do
423 | local child = node.modules[i]
424 | nn_utils.prepareNetworkForSave(child)
425 | end
426 | end
427 | end
428 | collectgarbage()
429 | end
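    | -- Minimal usage sketch (hypothetical filepath), e.g. before saving the generator:
    | --   nn_utils.prepareNetworkForSave(MODEL_G)
    | --   torch.save("logs/g.net", MODEL_G)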
430 |
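    | -- Counts the learnable parameters in a network's weight matrices.
    | -- Note that bias vectors are not included in the count.
    | -- @param net The network to analyze.
    | -- @returns integer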
431 | function nn_utils.getNumberOfParameters(net)
432 | local nparams = 0
433 | local dModules = net:listModules()
434 | for i=1,#dModules do
435 | if dModules[i].weight ~= nil then
436 | nparams = nparams + dModules[i].weight:nElement()
437 | end
438 | end
439 | return nparams
440 | end
441 |
442 | -- Contains the pixels necessary to draw digits 0 to 9
443 | local CHAR_TENSORS = {}
444 | CHAR_TENSORS[0] = torch.Tensor({{1, 1, 1},
445 | {1, 0, 1},
446 | {1, 0, 1},
447 | {1, 0, 1},
448 | {1, 1, 1}})
449 | CHAR_TENSORS[1] = torch.Tensor({{0, 0, 1},
450 | {0, 0, 1},
451 | {0, 0, 1},
452 | {0, 0, 1},
453 | {0, 0, 1}})
454 | CHAR_TENSORS[2] = torch.Tensor({{1, 1, 1},
455 | {0, 0, 1},
456 | {1, 1, 1},
457 | {1, 0, 0},
458 | {1, 1, 1}})
459 | CHAR_TENSORS[3] = torch.Tensor({{1, 1, 1},
460 | {0, 0, 1},
461 | {0, 1, 1},
462 | {0, 0, 1},
463 | {1, 1, 1}})
464 | CHAR_TENSORS[4] = torch.Tensor({{1, 0, 1},
465 | {1, 0, 1},
466 | {1, 1, 1},
467 | {0, 0, 1},
468 | {0, 0, 1}})
469 | CHAR_TENSORS[5] = torch.Tensor({{1, 1, 1},
470 | {1, 0, 0},
471 | {1, 1, 1},
472 | {0, 0, 1},
473 | {1, 1, 1}})
474 | CHAR_TENSORS[6] = torch.Tensor({{1, 1, 1},
475 | {1, 0, 0},
476 | {1, 1, 1},
477 | {1, 0, 1},
478 | {1, 1, 1}})
479 | CHAR_TENSORS[7] = torch.Tensor({{1, 1, 1},
480 | {0, 0, 1},
481 | {0, 0, 1},
482 | {0, 0, 1},
483 | {0, 0, 1}})
484 | CHAR_TENSORS[8] = torch.Tensor({{1, 1, 1},
485 | {1, 0, 1},
486 | {1, 1, 1},
487 | {1, 0, 1},
488 | {1, 1, 1}})
489 | CHAR_TENSORS[9] = torch.Tensor({{1, 1, 1},
490 | {1, 0, 1},
491 | {1, 1, 1},
492 | {0, 0, 1},
493 | {1, 1, 1}})
494 |
495 | -- Converts a list of images to a grid of images that can be saved easily.
496 | -- It will also place the epoch number at the bottom of the image.
497 | -- At least parts of this function probably should have been a simple call
498 | -- to image.toDisplayTensor().
499 | -- @param images Tensor of image tensors
500 | -- @param height Height of the grid
501 | -- @param width Width of the grid
502 | -- @param epoch The epoch number to draw at the bottom of the grid
503 | -- @returns tensor
504 | function nn_utils.imagesToGridTensor(images, height, width, epoch)
505 | local imgChannels = images:size(2)
506 | local imgHeightPx = IMG_DIMENSIONS[2]
507 | local imgWidthPx = IMG_DIMENSIONS[3]
508 | local heightPx = height * imgHeightPx + (1 + 5 + 1)
509 | local widthPx = width * imgWidthPx
510 | local grid = torch.Tensor(imgChannels, heightPx, widthPx)
511 | grid:zero()
512 |
513 | -- add images to grid, one by one
514 | local yGridPos = 1
515 | local xGridPos = 1
516 | for i=1,math.min(images:size(1), height*width) do
517 | -- set pixels of image
518 | local yStart = 1 + ((yGridPos-1) * imgHeightPx)
519 | local yEnd = yStart + imgHeightPx - 1
520 | local xStart = 1 + ((xGridPos-1) * imgWidthPx)
521 | local xEnd = xStart + imgWidthPx - 1
522 | grid[{{1,imgChannels}, {yStart,yEnd}, {xStart,xEnd}}] = images[i]:float()
523 |
524 | -- move to next position in grid
525 | xGridPos = xGridPos + 1
526 | if xGridPos > width then
527 | xGridPos = 1
528 | yGridPos = yGridPos + 1
529 | end
530 | end
531 |
532 | -- add the epoch at the bottom of the image
533 | local epochStr = tostring(epoch)
534 | local pos = 1
535 | for i=epochStr:len(),1,-1 do
536 | local c = tonumber(epochStr:sub(i,i))
537 | for channel=1,imgChannels do
538 | local yStart = heightPx - 1 - 5 -- constant for all
539 | local yEnd = yStart + 5 - 1 -- constant for all
540 | local xStart = widthPx - 1 - pos*5 - pos
541 | local xEnd = xStart + 3 - 1
542 |
543 | grid[{{channel}, {yStart, yEnd}, {xStart, xEnd}}] = CHAR_TENSORS[c]
544 | end
545 | pos = pos + 1
546 | end
547 |
548 | return grid
549 | end
550 |
551 | -- Saves a list of images to the provided filepath (as a grid image).
552 | -- @param filepath Save the grid image to that filepath
553 | -- @param images List of image tensors
554 | -- @param height Height of the grid
555 | -- @param width Width of the grid
556 | -- @param epoch The epoch number to draw at the bottom of the grid
558 | function nn_utils.saveImagesAsGrid(filepath, images, height, width, epoch)
559 | local grid = nn_utils.imagesToGridTensor(images, height, width, epoch)
560 | os.execute(string.format("mkdir -p %s", sys.dirname(filepath)))
561 | image.save(filepath, grid)
562 | end
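    | -- Minimal usage sketch (hypothetical filepath and `images` tensor):
    | --   nn_utils.saveImagesAsGrid("samples/epoch42.png", images, 8, 8, 42)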
563 |
564 | -- Deactivates CUDA mode on a network and returns the result.
565 | -- Expects networks in CUDA mode to be a Sequential of the form
566 | -- [1] Copy layer [2] Sequential [3] Copy layer
567 | -- as created by activateCuda().
568 | -- @param net The network to deactivate CUDA mode on.
569 | -- @returns The CPU network
570 | function nn_utils.deactivateCuda(net)
571 | local newNet = net:clone()
572 | newNet:float()
573 | if torch.type(newNet:get(1)) == 'nn.Copy' then
574 | return newNet:get(2)
575 | else
576 | return newNet
577 | end
578 | end
579 |
580 | -- Returns whether a Sequential contains any copy layers.
581 | -- @param net The network to analyze.
582 | -- @return boolean
583 | function nn_utils.containsCopyLayers(net)
584 | local modules = net:listModules()
585 | for i=1,#modules do
586 | local t = torch.type(modules[i])
587 | if string.find(t, "Copy") ~= nil then
588 | return true
589 | end
590 | end
591 | return false
592 | end
593 |
594 | -- Activates CUDA mode on a network and returns the result.
595 | -- This adds Copy layers at the start and end of the network.
596 | -- Expects the default tensor to be FloatTensor.
597 | -- @param net The network to activate CUDA mode on.
598 | -- @returns The CUDA network
599 | function nn_utils.activateCuda(net)
600 | --[[
601 | local newNet = net:clone()
602 | newNet:cuda()
603 | local tmp = nn.Sequential()
604 | tmp:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor'))
605 | tmp:add(newNet)
606 | tmp:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor'))
607 | return tmp
608 | --]]
609 | local newNet = net:clone()
610 |
611 | -- does the network already contain any copy layers?
612 | local containsCopyLayers = nn_utils.containsCopyLayers(newNet)
613 |
614 | -- no copy layers in the network yet
615 | -- add them at the start and end
616 | if not containsCopyLayers then
617 | local tmp = nn.Sequential()
618 | tmp:add(nn.Copy('torch.FloatTensor', 'torch.CudaTensor'))
619 | tmp:add(newNet)
620 | tmp:add(nn.Copy('torch.CudaTensor', 'torch.FloatTensor'))
621 | newNet:cuda()
622 | newNet = tmp
623 | end
624 |
625 | --[[
626 | local firstCopyFound = false
627 | local lastCopyFound = false
628 | modules = newNet:listModules()
629 | for i=1,#modules do
630 | print("module "..i.." " .. torch.type(modules[i]))
631 | local t = torch.type(modules[i])
632 | if string.find(t, "Copy") ~= nil then
633 | if not firstCopyFound then
634 | firstCopyFound = true
635 | modules[i]:cuda()
636 | modules[i].intype = 'torch.FloatTensor'
637 | modules[i].outtype = 'torch.CudaTensor'
638 | else
639 | -- last copy found
640 | lastCopyFound = true
641 | modules[i]:float()
642 | modules[i].intype = 'torch.CudaTensor'
643 | modules[i].outtype = 'torch.FloatTensor'
644 | end
645 | elseif lastCopyFound then
646 | print("calling float() A")
647 | modules[i]:float()
648 | elseif firstCopyFound then
649 | print("calling cuda()")
650 | modules[i]:cuda()
651 | else
652 | print("calling float() B")
653 | modules[i]:float()
654 | end
655 | end
656 | --]]
657 |
658 | return newNet
659 | end
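    | -- Minimal usage sketch, assuming cunn is loaded and the default tensor type
    | -- is FloatTensor:
    | --   MODEL_D = nn_utils.activateCuda(MODEL_D)    -- wraps D in Copy layers, moves it to the GPU
    | --   -- ... training ...
    | --   MODEL_D = nn_utils.deactivateCuda(MODEL_D)  -- unwraps and returns the CPU network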
660 |
661 | -- Creates an average rating (0 to 1) for a list of images.
662 | -- 1 is best.
663 | -- @param images List of image tensors.
664 | -- @returns float
665 | function nn_utils.rateWithV(images)
666 | local imagesTensor
667 | local N
668 | if type(images) == 'table' then
669 | N = #images
670 | imagesTensor = torch.Tensor(N, IMG_DIMENSIONS[1], IMG_DIMENSIONS[2], IMG_DIMENSIONS[3])
671 | for i=1,N do
672 | imagesTensor[i] = images[i]
673 | end
674 | else
675 | N = images:size(1)
676 | imagesTensor = images
677 | end
678 |
679 | local predictions = MODEL_V:forward(imagesTensor)
680 | local sm = 0
681 | for i=1,N do
682 | -- first neuron in V signals whether the image is fake (1=yes, 0=no)
683 | sm = sm + predictions[i][1]
684 | end
685 |
686 | local fakiness = sm / N
687 |
688 | -- higher values for better images
689 | return (1 - fakiness)
690 | end
691 |
692 | return nn_utils
693 |
--------------------------------------------------------------------------------