├── Datasets
│   ├── IMDB
│   │   ├── 1.jpg
│   │   ├── 2.jpg
│   │   └── 3.jpg
│   └── RealImg
│       ├── 1.jpg
│       ├── 2.jpg
│       ├── 3.jpg
│       └── 4.jpg
├── README.md
├── data2
│   ├── dataC.lua
│   ├── dataset.lua
│   └── donkey_folderC.lua
├── imgs
│   ├── IMDb
│   │   ├── 1_1.jpg
│   │   ├── 1_2.jpg
│   │   ├── 1_3.jpg
│   │   ├── 1_4.jpg
│   │   ├── 2_1.jpg
│   │   ├── 2_2.jpg
│   │   ├── 2_3.jpg
│   │   └── 2_4.jpg
│   ├── architecture
│   │   └── pipeline.jpg
│   ├── realresults
│   │   └── 1.jpg
│   └── warpface
│       └── warp.jpg
├── test.lua
└── util
    ├── cudnn_convert_custom.lua
    └── util.lua
/Datasets/IMDB/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/IMDB/1.jpg
--------------------------------------------------------------------------------
/Datasets/IMDB/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/IMDB/2.jpg
--------------------------------------------------------------------------------
/Datasets/IMDB/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/IMDB/3.jpg
--------------------------------------------------------------------------------
/Datasets/RealImg/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/RealImg/1.jpg
--------------------------------------------------------------------------------
/Datasets/RealImg/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/RealImg/2.jpg
--------------------------------------------------------------------------------
/Datasets/RealImg/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/RealImg/3.jpg
--------------------------------------------------------------------------------
/Datasets/RealImg/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/RealImg/4.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [GFRNet](https://arxiv.org/abs/1804.04829)
2 | Torch implementation for [Learning Warped Guidance for Blind Face Restoration](https://arxiv.org/abs/1804.04829)
3 | 
4 | # GFRNet framework
5 | Overview of our GFRNet. The WarpNet takes the degraded observation and the guidance image as input and predicts a dense flow field, which is used to deform the guidance image into the warped guidance. The warped guidance is expected to be spatially well aligned with the ground-truth image, so the RecNet takes the warped guidance and the degraded observation as input to produce the restoration result.
6 | 
7 | <img src="./imgs/architecture/pipeline.jpg">
8 | 
9 | # Testing
10 | 
11 | ```bash
12 | th test.lua
13 | ```
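14 | 
15 | test.lua reads its test images from `<DATA_ROOT>/<phase>` (each image there stores the degraded input and its guidance image side by side; see `loadImage` in `data2/donkey_folderC.lua`), and every key of its `opt` table can be overridden through an environment variable of the same name via the one-line parser in test.lua. A minimal invocation that spells out the defaults:
16 | 
17 | ```bash
18 | DATA_ROOT=./Datasets phase=RealImg name=FaceRestoration gpu=1 th test.lua
19 | ```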
20 | 
21 | # Models
22 | Download the pre-trained model from one of the following links and put it into ./checkpoints/FaceRestoration/.
23 | - [BaiduNetDisk](https://pan.baidu.com/s/1q96l3qmTf5Luh-nlqot6Xw)
24 | - [GoogleDrive](https://drive.google.com/open?id=1PhE3Gi9-eHrofyR3LhqEhuVnzh9D7IsX)
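25 | 
26 | test.lua builds the model path as `<checkpoints_dir>/<name>/<which_model>.t7`, so with the default options the downloaded file is expected at:
27 | 
28 | ```bash
29 | mkdir -p ./checkpoints/FaceRestoration
30 | # expected location of the downloaded model: ./checkpoints/FaceRestoration/netG.t7
31 | ```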
32 | 
33 | # Results
34 | ## Restoration on real low-quality images
35 | The first row shows the real low-quality images (the close-up at the bottom right is the guidance image). The second row shows the GFRNet results.
36 | 
37 | <img src="./imgs/realresults/1.jpg">
38 | 
39 | ## Warped guidance
40 | 
41 | <img src="./imgs/warpface/warp.jpg">
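42 | 
43 | The warped guidance and the final restoration correspond to the first and third outputs of the network's forward pass in test.lua:
44 | 
45 | ```lua
46 | local outputss = netG:forward({real_A:cuda(), real_B:cuda()})
47 | real_WC = outputss[1]   -- warped guidance (WarpNet output)
48 | fake_gout = outputss[3] -- restoration result (RecNet output)
49 | ```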
50 | 
51 | ## IMDb results
52 | The content marked with a green box is the restoration result of our GFRNet. All of these images are collected from the [Internet Movie Database (IMDb)](https://www.imdb.com/).
53 | 
54 | <table>
55 | <tr>
56 | <td>Input</td>
57 | <td>Guided Image</td>
58 | <td>Bicubic</td>
59 | <td>GFRNet Results</td>
60 | </tr>
61 | <tr>
62 | <td><img src="./imgs/IMDb/1_1.jpg"></td>
63 | <td><img src="./imgs/IMDb/1_2.jpg"></td>
64 | <td><img src="./imgs/IMDb/1_3.jpg"></td>
65 | <td><img src="./imgs/IMDb/1_4.jpg"></td>
66 | </tr>
67 | <tr>
68 | <td><img src="./imgs/IMDb/2_1.jpg"></td>
69 | <td><img src="./imgs/IMDb/2_2.jpg"></td>
70 | <td><img src="./imgs/IMDb/2_3.jpg"></td>
71 | <td><img src="./imgs/IMDb/2_4.jpg"></td>
72 | </tr>
73 | </table>
74 | 
75 | - [More IMDb results can be found here](http://csxmli.xin/GFRNet/).
76 | - [The poster can be found here](http://csxmli.xin/GFRNet/poster.pdf).
77 | 
78 | # Requirements and Dependencies
79 | 
80 | - [Torch](https://github.com/torch/distro)
81 | - [Cuda](https://developer.nvidia.com/cuda-toolkit-archive)-8.0
82 | - [Stn](https://github.com/qassemoquab/stnbhwd)
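83 | 
84 | One possible way to set these up on top of an existing Torch install (test.lua additionally requires nngraph, cudnn and cunn; the exact Stn install command is an assumption based on standard luarocks usage, so follow each repository's README):
85 | 
86 | ```bash
87 | luarocks install nngraph
88 | git clone https://github.com/qassemoquab/stnbhwd.git
89 | cd stnbhwd && luarocks make
90 | ```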
91 | 
92 | # Acknowledgments
93 | 
94 | Code borrows heavily from [pix2pix](https://github.com/phillipi/pix2pix). Thanks for their excellent work!
95 | 
96 | # Citation
97 | 
98 | ```
99 | @InProceedings{Li_2018_ECCV,
100 | author = {Li, Xiaoming and Liu, Ming and Ye, Yuting and Zuo, Wangmeng and Lin, Liang and Yang, Ruigang},
101 | title = {Learning Warped Guidance for Blind Face Restoration},
102 | booktitle = {The European Conference on Computer Vision (ECCV)},
103 | month = {September},
104 | year = {2018}
105 | }
106 | ```
107 | 
--------------------------------------------------------------------------------
/data2/dataC.lua:
--------------------------------------------------------------------------------
1 | --[[
2 |    This data loader is a modified version of the one from dcgan.torch
3 |    (see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua).
4 | 
5 |    Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
6 | ]]--
7 | 
8 | local Threads = require 'threads'
9 | Threads.serialization('threads.sharedserialize')
10 | 
11 | local data = {}
12 | 
13 | local result = {}
14 | local unpack = unpack and unpack or table.unpack
15 | 
16 | function data.new(n, opt_)
17 |    opt_ = opt_ or {}
18 |    local self = {}
19 |    for k,v in pairs(data) do
20 |       self[k] = v
21 |    end
22 | 
23 |    local donkey_file = 'donkey_folderC.lua'
24 |    -- print('n..' .. n)
25 |    if n > 0 then
26 |       local options = opt_
27 |       self.threads = Threads(n,
28 |          function() require 'torch' end,
29 |          function(idx)
30 |             opt = options
31 |             tid = idx
32 |             local seed = (opt.manualSeed and opt.manualSeed or 0) + idx
33 |             torch.manualSeed(seed)
34 |             torch.setnumthreads(1)
35 |             print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
36 |             assert(options, 'options not found')
37 |             assert(opt, 'opt not given')
38 |             print(opt)
39 |             paths.dofile(donkey_file)
40 |          end
41 | 
42 |       )
43 |    else
44 |       if donkey_file then paths.dofile(donkey_file) end
45 |       -- print('empty threads')
46 |       self.threads = {}
47 |       function self.threads:addjob(f1, f2) f2(f1()) end
48 |       function self.threads:dojob() end
49 |       function self.threads:synchronize() end
50 |    end
51 | 
52 |    local nSamples = 0
53 |    self.threads:addjob(function() return trainLoader:size() end,
54 |                        function(c) nSamples = c end)
55 |    self.threads:synchronize()
56 |    self._size = nSamples
57 | 
58 |    for i = 1, n do
59 |       self.threads:addjob(self._getFromThreads,
60 |                           self._pushResult)
61 |    end
62 |    -- print(self.threads)
63 |    return self
64 | end
65 | 
66 | function data._getFromThreads()
67 |    assert(opt.batchSize, 'opt.batchSize not found')
68 |    return trainLoader:sample(opt.batchSize)
69 | end
70 | 
71 | function data._pushResult(...) -- packs the values returned by _getFromThreads into result[1]
72 | local res = {...} 73 | if res == nil then 74 | self.threads:synchronize() 75 | end 76 | result[1] = res 77 | 78 | end 79 | 80 | 81 | 82 | function data:getBatch() 83 | -- queue another job 84 | -- print(self.threads) 85 | self.threads:addjob(self._getFromThreads, self._pushResult) 86 | self.threads:dojob() 87 | local res = result[1] 88 | 89 | -- print(res) 90 | -- print('result') 91 | -- print(res) 92 | -- os.exit() 93 | -- paths = results[3] 94 | -- print(paths) 95 | 96 | img_data = res[1] 97 | img_paths = res[3] 98 | -- print(img_data:size()) 99 | -- print(type(img_data)) 100 | -- print(img_paths) 101 | -- print(type(img_paths)) 102 | -- result[3] = nil 103 | -- print(type(res)) 104 | 105 | result[1] = nil 106 | if torch.type(img_data) == 'table' then 107 | img_data = unpack(img_data) 108 | end 109 | 110 | 111 | return img_data, img_paths,res[4] 112 | end 113 | 114 | function data:size() 115 | return self._size 116 | end 117 | 118 | return data 119 | -------------------------------------------------------------------------------- /data2/dataset.lua: -------------------------------------------------------------------------------- 1 | --[[ 2 | Copyright (c) 2015-present, Facebook, Inc. 3 | All rights reserved. 4 | 5 | This source code is licensed under the BSD-style license found in the 6 | LICENSE file in the root directory of this source tree. An additional grant 7 | of patent rights can be found in the PATENTS file in the same directory. 8 | ]]-- 9 | 10 | require 'torch' 11 | torch.setdefaulttensortype('torch.FloatTensor') 12 | local ffi = require 'ffi' 13 | local class = require('pl.class') 14 | local dir = require 'pl.dir' 15 | local tablex = require 'pl.tablex' 16 | local argcheck = require 'argcheck' 17 | require 'sys' 18 | require 'xlua' 19 | require 'image' 20 | 21 | local dataset = torch.class('dataLoader') 22 | 23 | local initcheck = argcheck{ 24 | pack=true, 25 | help=[[ 26 | A dataset class for images in a flat folder structure (folder-name is class-name). 27 | Optimized for extremely large datasets (upwards of 14 million images). 28 | Tested only on Linux (as it uses command-line linux utilities to scale up) 29 | ]], 30 | {check=function(paths) 31 | local out = true; 32 | for k,v in ipairs(paths) do 33 | if type(v) ~= 'string' then 34 | print('paths can only be of string input'); 35 | out = false 36 | end 37 | end 38 | return out 39 | end, 40 | name="paths", 41 | type="table", 42 | help="Multiple paths of directories with images"}, 43 | 44 | {name="sampleSize", 45 | type="table", 46 | help="a consistent sample size to resize the images"}, 47 | 48 | {name="split", 49 | type="number", 50 | help="Percentage of split to go to Training" 51 | }, 52 | {name="serial_batches", 53 | type="number", 54 | help="if randomly sample training images"}, 55 | 56 | {name="samplingMode", 57 | type="string", 58 | help="Sampling mode: random | balanced ", 59 | default = "balanced"}, 60 | 61 | {name="verbose", 62 | type="boolean", 63 | help="Verbose mode during initialization", 64 | default = false}, 65 | 66 | {name="loadSize", 67 | type="table", 68 | help="a size to load the images to, initially", 69 | opt = true}, 70 | 71 | {name="forceClasses", 72 | type="table", 73 | help="If you want this loader to map certain classes to certain indices, " 74 | .. "pass a classes table that has {classname : classindex} pairs." 75 | .. " For example: {3 : 'dog', 5 : 'cat'}" 76 | .. "This function is very useful when you want two loaders to have the same " 77 | .. 
"class indices (trainLoader/testLoader for example)", 78 | opt = true}, 79 | 80 | {name="sampleHookTrain", 81 | type="function", 82 | help="applied to sample during training(ex: for lighting jitter). " 83 | .. "It takes the image path as input", 84 | opt = true}, 85 | 86 | {name="sampleHookTest", 87 | type="function", 88 | help="applied to sample during testing", 89 | opt = true}, 90 | } 91 | 92 | function dataset:__init(...) 93 | 94 | -- argcheck 95 | local args = initcheck(...) 96 | print(args) 97 | for k,v in pairs(args) do self[k] = v end 98 | 99 | if not self.loadSize then self.loadSize = self.sampleSize; end 100 | 101 | if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end 102 | if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end 103 | self.image_count = 1 104 | -- print('image_count_init', self.image_count) 105 | -- find class names 106 | self.classes = {} 107 | local classPaths = {} 108 | if self.forceClasses then 109 | for k,v in pairs(self.forceClasses) do 110 | self.classes[k] = v 111 | classPaths[k] = {} 112 | end 113 | end 114 | local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end 115 | -- loop over each paths folder, get list of unique class names, 116 | -- also store the directory paths per class 117 | -- for each class, 118 | for k,path in ipairs(self.paths) do 119 | -- print('path', path) 120 | local dirs = {} -- hack 121 | dirs[1] = path 122 | -- local dirs = dir.getdirectories(path); 123 | for k,dirpath in ipairs(dirs) do 124 | local class = paths.basename(dirpath) 125 | local idx = tableFind(self.classes, class) 126 | -- print(class) 127 | -- print(idx) 128 | if not idx then 129 | table.insert(self.classes, class) 130 | idx = #self.classes 131 | classPaths[idx] = {} 132 | end 133 | if not tableFind(classPaths[idx], dirpath) then 134 | table.insert(classPaths[idx], dirpath); 135 | end 136 | end 137 | end 138 | 139 | self.classIndices = {} 140 | for k,v in ipairs(self.classes) do 141 | self.classIndices[v] = k 142 | end 143 | 144 | -- define command-line tools, try your best to maintain OSX compatibility 145 | local wc = 'wc' 146 | local cut = 'cut' 147 | local find = 'find -H' -- if folder name is symlink, do find inside it after dereferencing 148 | if jit.os == 'OSX' then 149 | wc = 'gwc' 150 | cut = 'gcut' 151 | find = 'gfind' 152 | end 153 | ---------------------------------------------------------------------- 154 | -- Options for the GNU find command 155 | local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} 156 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"' 157 | for i=2,#extensionList do 158 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"' 159 | end 160 | 161 | -- find the image path names 162 | self.imagePath = torch.CharTensor() -- path to each image in dataset 163 | self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) 164 | self.classList = {} -- index of imageList to each image of a particular class 165 | self.classListSample = self.classList -- the main list used when sampling data 166 | 167 | print('running "find" on each class directory, and concatenate all' 168 | .. 
' those filenames into a single file containing all image paths for a given class') 169 | -- so, generates one file per class 170 | local classFindFiles = {} 171 | for i=1,#self.classes do 172 | classFindFiles[i] = os.tmpname() 173 | end 174 | local combinedFindList = os.tmpname(); 175 | 176 | local tmpfile = os.tmpname() 177 | local tmphandle = assert(io.open(tmpfile, 'w')) 178 | -- iterate over classes 179 | for i, class in ipairs(self.classes) do 180 | -- iterate over classPaths 181 | for j,path in ipairs(classPaths[i]) do 182 | local command = find .. ' "' .. path .. '" ' .. findOptions 183 | .. ' >>"' .. classFindFiles[i] .. '" \n' 184 | tmphandle:write(command) 185 | end 186 | end 187 | io.close(tmphandle) 188 | os.execute('bash ' .. tmpfile) 189 | os.execute('rm -f ' .. tmpfile) 190 | 191 | print('now combine all the files to a single large file') 192 | local tmpfile = os.tmpname() 193 | local tmphandle = assert(io.open(tmpfile, 'w')) 194 | -- concat all finds to a single large file in the order of self.classes 195 | for i=1,#self.classes do 196 | local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n' 197 | tmphandle:write(command) 198 | end 199 | io.close(tmphandle) 200 | os.execute('bash ' .. tmpfile) 201 | os.execute('rm -f ' .. tmpfile) 202 | 203 | --========================================================================== 204 | print('load the large concatenated list of sample paths to self.imagePath') 205 | local cmd = wc .. " -L '" 206 | .. combinedFindList .. "' |" 207 | .. cut .. " -f1 -d' '" 208 | print('cmd..' .. cmd) 209 | local maxPathLength = tonumber(sys.fexecute(wc .. " -L '" 210 | .. combinedFindList .. "' |" 211 | .. cut .. " -f1 -d' '")) + 1 212 | local length = tonumber(sys.fexecute(wc .. " -l '" 213 | .. combinedFindList .. "' |" 214 | .. cut .. " -f1 -d' '")) 215 | assert(length > 0, "Could not find any image file in the given input paths") 216 | assert(maxPathLength > 0, "paths of files are length 0?") 217 | self.imagePath:resize(length, maxPathLength):fill(0) 218 | local s_data = self.imagePath:data() 219 | local count = 0 220 | for line in io.lines(combinedFindList) do 221 | ffi.copy(s_data, line) 222 | s_data = s_data + maxPathLength 223 | if self.verbose and count % 10000 == 0 then 224 | xlua.progress(count, length) 225 | end; 226 | count = count + 1 227 | end 228 | 229 | self.numSamples = self.imagePath:size(1) 230 | if self.verbose then print(self.numSamples .. ' samples found.') end 231 | --========================================================================== 232 | print('Updating classList and imageClass appropriately') 233 | self.imageClass:resize(self.numSamples) 234 | local runningIndex = 0 235 | for i=1,#self.classes do 236 | if self.verbose then xlua.progress(i, #(self.classes)) end 237 | local length = tonumber(sys.fexecute(wc .. " -l '" 238 | .. classFindFiles[i] .. "' |" 239 | .. cut .. " -f1 -d' '")) 240 | if length == 0 then 241 | error('Class has zero samples') 242 | else 243 | self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long() 244 | self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i) 245 | end 246 | runningIndex = runningIndex + length 247 | end 248 | 249 | --========================================================================== 250 | -- clean up temporary files 251 | print('Cleaning up temporary files') 252 | local tmpfilelistall = '' 253 | for i=1,#(classFindFiles) do 254 | tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. 
'"' 255 | if i % 1000 == 0 then 256 | os.execute('rm -f ' .. tmpfilelistall) 257 | tmpfilelistall = '' 258 | end 259 | end 260 | os.execute('rm -f ' .. tmpfilelistall) 261 | os.execute('rm -f "' .. combinedFindList .. '"') 262 | --========================================================================== 263 | 264 | if self.split == 100 then 265 | self.testIndicesSize = 0 266 | else 267 | print('Splitting training and test sets to a ratio of ' 268 | .. self.split .. '/' .. (100-self.split)) 269 | self.classListTrain = {} 270 | self.classListTest = {} 271 | self.classListSample = self.classListTrain 272 | local totalTestSamples = 0 273 | -- split the classList into classListTrain and classListTest 274 | for i=1,#self.classes do 275 | local list = self.classList[i] 276 | local count = self.classList[i]:size(1) 277 | local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round 278 | local perm = torch.randperm(count) 279 | self.classListTrain[i] = torch.LongTensor(splitidx) 280 | for j=1,splitidx do 281 | self.classListTrain[i][j] = list[perm[j]] 282 | end 283 | if splitidx == count then -- all samples were allocated to train set 284 | self.classListTest[i] = torch.LongTensor() 285 | else 286 | self.classListTest[i] = torch.LongTensor(count-splitidx) 287 | totalTestSamples = totalTestSamples + self.classListTest[i]:size(1) 288 | local idx = 1 289 | for j=splitidx+1,count do 290 | self.classListTest[i][idx] = list[perm[j]] 291 | idx = idx + 1 292 | end 293 | end 294 | end 295 | -- Now combine classListTest into a single tensor 296 | self.testIndices = torch.LongTensor(totalTestSamples) 297 | self.testIndicesSize = totalTestSamples 298 | local tdata = self.testIndices:data() 299 | local tidx = 0 300 | for i=1,#self.classes do 301 | local list = self.classListTest[i] 302 | if list:dim() ~= 0 then 303 | local ldata = list:data() 304 | for j=0,list:size(1)-1 do 305 | tdata[tidx] = ldata[j] 306 | tidx = tidx + 1 307 | end 308 | end 309 | end 310 | end 311 | end 312 | 313 | -- size(), size(class) 314 | function dataset:size(class, list) 315 | list = list or self.classList 316 | if not class then 317 | return self.numSamples 318 | elseif type(class) == 'string' then 319 | return list[self.classIndices[class]]:size(1) 320 | elseif type(class) == 'number' then 321 | return list[class]:size(1) 322 | end 323 | end 324 | 325 | -- getByClass 326 | function dataset:getByClass(class) 327 | local index = 0 328 | if self.serial_batches == 1 then 329 | index = math.fmod(self.image_count-1, self.classListSample[class]:nElement())+1 330 | self.image_count = self.image_count +1 331 | else 332 | index = math.ceil(torch.uniform() * self.classListSample[class]:nElement()) 333 | end 334 | -- print('serial_batches: ', self.serial_batches) 335 | -- print('max_index:, ', self.classListSample[class]:nElement()) 336 | -- print('index: ', index) 337 | -- print('image_count', 338 | local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]])) 339 | 340 | local imgAB, flip_flag 341 | imgAB, flip_flag=self:sampleHookTrain(imgpath) 342 | return imgAB, imgpath, flip_flag 343 | end 344 | 345 | -- converts a table of samples (and corresponding labels) to a clean tensor 346 | local function tableToOutput(self, dataTable, scalarTable) 347 | local data, scalarLabels, labels 348 | local quantity = #scalarTable 349 | -- print(dataTable[1]:()) 350 | assert(dataTable[1]:dim() == 3) 351 | -- print(quantity) 352 | -- print(self.sampleSize[1]) 353 | -- print(self.sampleSize[2]) 354 | -- 
print(self.sampleSize[3])
355 |    data = torch.Tensor(quantity,
356 |                         self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])
357 |    -- print(data:size())
358 |    scalarLabels = torch.LongTensor(quantity):fill(-1111)
359 |    for i=1,#dataTable do
360 |       data[i]:copy(dataTable[i])
361 |       scalarLabels[i] = scalarTable[i]
362 |    end
363 |    return data, scalarLabels
364 | end
365 | 
366 | -- sampler, samples from the training set.
367 | function dataset:sample(quantity)
368 |    assert(quantity)
369 |    local dataTable = {}
370 |    local scalarTable = {}
371 |    local samplePaths = {}
372 |    local flipTable = {}
373 |    for i=1,quantity do
374 |       local class = torch.random(1, #self.classes)
375 |       -- print(class)
376 |       local out, imgpath, flip_flag = self:getByClass(class)
377 |       table.insert(dataTable, out)
378 |       table.insert(scalarTable, class)
379 |       table.insert(flipTable, flip_flag)
380 |       samplePaths[i] = imgpath
381 |       -- print(imgpath)
382 |       -- table.insert(pathTable, imgpath)
383 |       -- table.insert()
384 |       -- print('out', out:size())
385 |    end
386 |    -- print('table')
387 |    -- print(table)
388 |    local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
389 |    return data, scalarLabels, samplePaths, flipTable -- filePaths
390 | end
391 | 
392 | function dataset:get(i1, i2)
393 |    local indices = torch.range(i1, i2);
394 |    local quantity = i2 - i1 + 1;
395 |    assert(quantity > 0)
396 |    -- now that indices has been initialized, get the samples
397 |    local dataTable = {}
398 |    local scalarTable = {}
399 |    for i=1,quantity do
400 |       -- load the sample
401 |       local imgpath = ffi.string(torch.data(self.imagePath[indices[i]]))
402 |       local out = self:sampleHookTest(imgpath)
403 |       table.insert(dataTable, out)
404 |       table.insert(scalarTable, self.imageClass[indices[i]])
405 |    end
406 |    local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
407 |    return data, scalarLabels
408 | end
409 | 
410 | return dataset
--------------------------------------------------------------------------------
/data2/donkey_folderC.lua:
--------------------------------------------------------------------------------
1 | 
2 | --[[
3 |    This data loader is a modified version of the one from dcgan.torch
4 |    (see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua).
5 |    Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
6 |    Copyright (c) 2015-present, Facebook, Inc.
7 |    All rights reserved.
8 |    This source code is licensed under the BSD-style license found in the
9 |    LICENSE file in the root directory of this source tree. An additional grant
10 |    of patent rights can be found in the PATENTS file in the same directory.
11 | ]]--
12 | 
13 | require 'image'
14 | paths.dofile('dataset.lua')
15 | -- This file contains the data-loading logic and details.
16 | -- It is run by each data-loader thread.
17 | ------------------------------------------
18 | -------- COMMON CACHES and PATHS
19 | -- Check for existence of opt.data
20 | 
21 | if opt.DATA_ROOT then
22 |    opt.data = paths.concat(opt.DATA_ROOT, opt.phase)
23 | else
24 | 
25 |    opt.data = paths.concat(os.getenv('DATA_ROOT'), opt.phase)
26 | end
27 | 
28 | if not paths.dirp(opt.data) then
29 |    error('Did not find directory: ' .. opt.data)
30 | end
31 | 
32 | -- a cache file of the training metadata (if it doesn't exist, it will be created)
33 | local cache = "cache"
34 | local cache_prefix = opt.data:gsub('/', '_')
35 | os.execute('mkdir -p cache')
36 | local trainCache = paths.concat(cache, cache_prefix .. '_trainCache.t7')
37 | 
38 | --------------------------------------------------------------------------------------------
39 | local input_nc = opt.input_nc -- input channels
40 | local output_nc = opt.output_nc
41 | local loadSize = {input_nc/3, opt.loadSize}
42 | local sampleSize = {input_nc/3, opt.fineSize}
43 | 
44 | local preprocessAandBC = function(imA, imB)
45 |    imA = image.scale(imA, loadSize[2], loadSize[2])
46 |    imB = image.scale(imB, loadSize[2], loadSize[2])
47 |    local perm = torch.LongTensor{3, 2, 1}
48 | 
49 |    -- imA = imA:index(1, perm)--:mul(256.0): brg, rgb
50 |    -- imA = imA:mul(2):add(-1) -- changed the range here from [-1, 1] to [0, 1]
51 |    -- imB = imB:index(1, perm)
52 |    -- imB = imB:mul(2):add(-1)
53 |    -- imC = imC:index(1, perm)
54 |    -- imC = imC:mul(2):add(-1)
55 | 
56 |    -- assert(imA:max()<=1,"A: badly scaled inputs")
57 |    -- assert(imA:min()>=-1,"A: badly scaled inputs")
58 |    -- assert(imB:max()<=1,"B: badly scaled inputs")
59 |    -- assert(imB:min()>=-1,"B: badly scaled inputs")
60 |    -- assert(imC:max()<=1,"C: badly scaled inputs")
61 |    -- assert(imC:min()>=-1,"C: badly scaled inputs")
62 | 
63 |    -- the value range was changed here from [-1, 1] to [0, 1]
64 |    imA = imA:index(1, perm)--:mul(256.0): brg, rgb
65 |    -- imA = imA:mul(2):add(-1)
66 |    imB = imB:index(1, perm)
67 |    -- imB = imB:mul(2):add(-1)
68 | 
69 | 
70 |    assert(imA:max()<=1,"A: badly scaled inputs")
71 |    assert(imA:min()>=0,"A: badly scaled inputs")
72 |    assert(imB:max()<=1,"B: badly scaled inputs")
73 |    assert(imB:min()>=0,"B: badly scaled inputs")
74 | 
75 | 
76 | 
77 |    local oW = sampleSize[2]
78 |    local oH = sampleSize[2]
79 |    local iH = imA:size(2)
80 |    local iW = imA:size(3)
81 | 
82 |    if iH ~= oH then
83 |       h1 = math.ceil(torch.uniform(1e-2, iH-oH))
84 |    end
85 | 
86 |    if iW ~= oW then
87 |       w1 = math.ceil(torch.uniform(1e-2, iW-oW))
88 |    end
89 |    if iH ~= oH or iW ~= oW then
90 | 
91 |       imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH)
92 |       imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
93 | 
94 | 
95 |    end
96 | 
97 |    local flip_flag = 0
98 |    if opt.flip == 1 and torch.uniform() > 0.5 then
99 |       imA = image.hflip(imA)
100 |       imB = image.hflip(imB)
101 |       flip_flag = 1
102 |    end
103 | 
104 |    return imA, imB, flip_flag
105 | end
106 | 
107 | 
108 | 
109 | local function loadImageChannel(path)
110 |    local input = image.load(path, 3, 'float')
111 |    input = image.scale(input, loadSize[2], loadSize[2])
112 | 
113 |    local oW = sampleSize[2]
114 |    local oH = sampleSize[2]
115 |    local iH = input:size(2)
116 |    local iW = input:size(3)
117 | 
118 |    if iH ~= oH then
119 |       h1 = math.ceil(torch.uniform(1e-2, iH-oH))
120 |    end
121 | 
122 |    if iW ~= oW then
123 |       w1 = math.ceil(torch.uniform(1e-2, iW-oW))
124 |    end
125 |    if iH ~= oH or iW ~= oW then
126 |       input = image.crop(input, w1, h1, w1 + oW, h1 + oH)
127 |    end
128 | 
129 | 
130 |    if opt.flip == 1 and torch.uniform() > 0.5 then
131 |       input = image.hflip(input)
132 |    end
133 | 
134 |    -- print(input:mean(), input:min(), input:max())
135 |    local input_lab = image.rgb2lab(input)
136 |    -- print(input_lab:size())
137 |    -- os.exit()
138 |    local imA = input_lab[{{1}, {}, {} }]:div(50.0) - 1.0
139 |    local imB = input_lab[{{2,3},{},{}}]:div(110.0)
140 | 
141 |    local imAB = torch.cat(imA, imB, 1)
142 |    assert(imAB:max()<=1,"A: badly scaled inputs")
143 |    assert(imAB:min()>=-1,"A: badly scaled inputs")
144 | 
145 |    return imAB
146 | end
147 | 
148 | --local function loadImage
149 | 
150 | local function loadImage(path) -- loads a side-by-side pair: left half = degraded input (A), right half = guidance (B)
151 |    local input = image.load(path, 3, 'float')
152 |    local h = input:size(2)
153 |    local w = input:size(3)
154 | 
155 |    local imA = image.crop(input, 0, 0, w/2, h) -- left half
156 | local imB = image.crop(input, w/2, 0, w, h) 157 | 158 | 159 | return imA, imB 160 | end 161 | 162 | local function loadImageInpaint(path) 163 | local imB = image.load(path, 3, 'float') 164 | imB = image.scale(imB, loadSize[2], loadSize[2]) 165 | local perm = torch.LongTensor{3, 2, 1} 166 | imB = imB:index(1, perm)--:mul(256.0): brg, rgb 167 | imB = imB:mul(2):add(-1) 168 | assert(imB:max()<=1,"A: badly scaled inputs") 169 | assert(imB:min()>=-1,"A: badly scaled inputs") 170 | local oW = sampleSize[2] 171 | local oH = sampleSize[2] 172 | local iH = imB:size(2) 173 | local iW = imB:size(3) 174 | if iH~=oH then 175 | h1 = math.ceil(torch.uniform(1e-2, iH-oH)) 176 | end 177 | 178 | if iW~=oW then 179 | w1 = math.ceil(torch.uniform(1e-2, iW-oW)) 180 | end 181 | if iH ~= oH or iW ~= oW then 182 | imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH) 183 | end 184 | local imA = imB:clone() 185 | imA[{{},{1 + oH/4, oH/2 + oH/4},{1 + oW/4, oW/2 + oW/4}}] = 1.0 186 | if opt.flip == 1 and torch.uniform() > 0.5 then 187 | imA = image.hflip(imA) 188 | imB = image.hflip(imB) 189 | end 190 | imAB = torch.cat(imA, imB, 1) 191 | return imAB 192 | end 193 | 194 | -- channel-wise mean and std. Calculate or load them from disk later in the script. 195 | local mean,std 196 | -------------------------------------------------------------------------------- 197 | -- Hooks that are used for each image that is loaded 198 | 199 | -- function to load the image, jitter it appropriately (random crops etc.) 200 | local trainHook = function(self, path) 201 | collectgarbage() 202 | local flip_flag 203 | if opt.preprocess == 'regular' then 204 | -- print('process regular') 205 | local imA, imB = loadImage(path) 206 | 207 | imA, imB,flip_flag = preprocessAandBC(imA, imB) 208 | 209 | imAB = torch.cat(imA, imB, 1) 210 | 211 | 212 | --print('image C size') 213 | --print(imAB:size()) 214 | end 215 | 216 | if opt.preprocess == 'colorization' then 217 | -- print('process colorization') 218 | imAB = loadImageChannel(path) 219 | end 220 | 221 | if opt.preprocess == 'inpaint' then 222 | -- print('process inpaint') 223 | imAB = loadImageInpaint(path) 224 | end 225 | -- print('image AB size') 226 | -- print(imAB:size()) 227 | return imAB,flip_flag 228 | end 229 | 230 | -------------------------------------- 231 | -- trainLoader 232 | print('trainCache', trainCache) 233 | --if paths.filep(trainCache) then 234 | -- print('Loading train metadata from cache') 235 | -- trainLoader = torch.load(trainCache) 236 | -- trainLoader.sampleHookTrain = trainHook 237 | -- trainLoader.loadSize = {input_nc, opt.loadSize, opt.loadSize} 238 | -- trainLoader.sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]} 239 | -- trainLoader.serial_batches = opt.serial_batches 240 | -- trainLoader.split = 100 241 | --else 242 | print('Creating train metadata') 243 | -- print(opt.data) 244 | print('serial batch:, ', opt.serial_batches) 245 | trainLoader = dataLoader{ 246 | paths = {opt.data}, 247 | loadSize = {input_nc, loadSize[2], loadSize[2]}, 248 | sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]}, 249 | split = 100, 250 | serial_batches = opt.serial_batches, 251 | verbose = true 252 | } 253 | -- print('finish') 254 | --torch.save(trainCache, trainLoader) 255 | --print('saved metadata cache at', trainCache) 256 | trainLoader.sampleHookTrain = trainHook 257 | 258 | --end 259 | collectgarbage() 260 | 261 | -- do some sanity checks on trainLoader 262 | do 263 | local class = trainLoader.imageClass 264 | local nClasses = #trainLoader.classes 
265 | assert(class:max() <= nClasses, "class logic has error") 266 | assert(class:min() >= 1, "class logic has error") 267 | end 268 | -------------------------------------------------------------------------------- /imgs/IMDb/1_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/1_1.jpg -------------------------------------------------------------------------------- /imgs/IMDb/1_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/1_2.jpg -------------------------------------------------------------------------------- /imgs/IMDb/1_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/1_3.jpg -------------------------------------------------------------------------------- /imgs/IMDb/1_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/1_4.jpg -------------------------------------------------------------------------------- /imgs/IMDb/2_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/2_1.jpg -------------------------------------------------------------------------------- /imgs/IMDb/2_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/2_2.jpg -------------------------------------------------------------------------------- /imgs/IMDb/2_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/2_3.jpg -------------------------------------------------------------------------------- /imgs/IMDb/2_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/2_4.jpg -------------------------------------------------------------------------------- /imgs/architecture/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/architecture/pipeline.jpg -------------------------------------------------------------------------------- /imgs/realresults/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/realresults/1.jpg -------------------------------------------------------------------------------- /imgs/warpface/warp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/warpface/warp.jpg -------------------------------------------------------------------------------- /test.lua: 
--------------------------------------------------------------------------------
1 | 
2 | require 'torch'
3 | require 'nn'
4 | require 'cudnn'
5 | require 'cunn'
6 | require 'nngraph'
7 | require 'optim'
8 | util = paths.dofile('util/util.lua')
9 | require 'image'
10 | require 'stn'
11 | 
12 | 
13 | opt = {
14 |    DATA_ROOT = './Datasets',  -- data root directory
15 |    batchSize = 1,             -- # images in batch
16 |    loadSize = 256,            -- scale images to this size
17 |    fineSize = 256,            -- then crop to this size
18 |    flip = 0,                  -- horizontal mirroring data augmentation
19 |    gpu = 1,                   -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X (cpu untested)
20 |    which_direction = 'AtoB',  -- AtoB or BtoA
21 |    phase = 'RealImg',         -- test dataset name
22 |    preprocess = 'regular',    -- for special purpose preprocessing, e.g., for colorization, change this (selects preprocessing functions in util.lua)
23 |    aspect_ratio = 1.0,        -- aspect ratio of result images
24 |    name = 'FaceRestoration',  -- name of experiment, selects which model to run, should generally be passed on the command line
25 |    input_nc = 3,              -- # of input image channels
26 |    output_nc = 3,             -- # of output image channels
27 |    serial_batches = 1,        -- if 1, takes images in order to make batches, otherwise takes them randomly
28 |    serial_batch_iter = 1,     -- iter into serial image list
29 |    cudnn = 0,                 -- set to 0 to not use cudnn (untested)
30 |    checkpoints_dir = './checkpoints', -- loads models from here
31 |    results_dir = './results', -- saves results here
32 |    which_model = 'netG',      -- name of the model file to load (without the .t7 extension)
33 | 
34 | }
35 | 
36 | 
37 | -- one-line argument parser. parses environment variables to override the defaults
38 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
39 | opt.nThreads = 1 -- test only works with 1 thread...
40 | print(opt)
41 | 
42 | 
43 | opt.manualSeed = torch.random(1, 10000) -- set seed
44 | print("Random Seed: " .. opt.manualSeed)
45 | torch.manualSeed(opt.manualSeed)
46 | torch.setdefaulttensortype('torch.FloatTensor')
47 | 
48 | opt.netG_name = opt.name .. '/' .. opt.which_model
49 | local data_loader = paths.dofile('data2/dataC.lua')
50 | print('#threads...' .. opt.nThreads)
51 | local data = data_loader.new(opt.nThreads, opt)
52 | print("Dataset Size: ", data:size())
53 | opt.how_many = data:size()
54 | ---------------------------------------------------------------------------------------------------
55 | 
56 | local input = torch.FloatTensor(opt.batchSize,6,opt.fineSize,opt.fineSize)
57 | local output = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
58 | local guidance = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
59 | local outputface = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
60 | 
61 | local netG = util.load(paths.concat(opt.checkpoints_dir, opt.netG_name .. '.t7'), opt):cuda()
62 | 
63 | netG:apply(printNet) -- print network
64 | local filepaths = {} -- paths to images tested on
65 | local filenames = {}
66 | function TableConcat(t1,t2)
67 |    for i=1,#t2 do
68 |       t1[#t1+1] = t2[i]
69 |    end
70 |    return t1
71 | end
72 | 
73 | 
74 | for n=1,math.floor(opt.how_many/opt.batchSize) do
75 |    print('processing batch ' .. n)
76 | 
77 |    local real_data, filepaths_curr, flips = data:getBatch()
78 |    local imgname2 = filepaths_curr[1]
79 | 
80 |    filepaths_curr = util.basename_batch(filepaths_curr)
81 |    print('filepaths_curr: ', filepaths_curr)
82 | 
83 |    real_A = real_data[{ {}, {1,3}, {}, {} }]:clone() -- degraded (blurred) input
84 |    real_B = real_data[{ {}, {4,6}, {}, {} }]:clone() -- guidance
85 | 
86 |    local outputss = netG:forward({real_A:cuda(), real_B:cuda()})
87 |    real_WC = outputss[1]:clone()   -- warped guidance
88 |    fake_gout = outputss[3]:clone() -- restoration result
89 | 
90 |    input = util.deprocess_batch(real_A):float()
91 |    outputwarp = util.deprocess_batch(real_WC):float()
92 |    guidance = util.deprocess_batch(real_B):float()
93 |    outputface = util.deprocess_batch(fake_gout):float()
94 | 
95 | 
96 |    -- save images
97 |    paths.mkdir(paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase))
98 |    local image_dir = paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase, 'images')
99 |    paths.mkdir(image_dir)
100 |    paths.mkdir(paths.concat(image_dir, 'Input'))
101 |    paths.mkdir(paths.concat(image_dir, 'WarpGuidance'))
102 |    paths.mkdir(paths.concat(image_dir, 'Guidance'))
103 |    paths.mkdir(paths.concat(image_dir, 'Output'))
104 | 
105 |    for i=1, opt.batchSize do
106 |       image.save(paths.concat(image_dir, 'Input', filepaths_curr[i]), image.scale(input[i], input[i]:size(2), input[i]:size(3)/opt.aspect_ratio))
107 |       image.save(paths.concat(image_dir, 'WarpGuidance', filepaths_curr[i]), image.scale(outputwarp[i], outputwarp[i]:size(2), outputwarp[i]:size(3)/opt.aspect_ratio))
108 |       image.save(paths.concat(image_dir, 'Guidance', filepaths_curr[i]), image.scale(guidance[i], guidance[i]:size(2), guidance[i]:size(3)/opt.aspect_ratio))
109 |       image.save(paths.concat(image_dir, 'Output', filepaths_curr[i]), image.scale(outputface[i], outputface[i]:size(2), outputface[i]:size(3)/opt.aspect_ratio))
110 |    end
111 |    print('Saved images to: ', image_dir)
112 |    BB = string.split(imgname2, '/')
113 |    local iename = BB[#BB]
114 |    local ImgName = {}
115 |    ImgName[1] = iename
116 | 
117 |    filepaths = TableConcat(filepaths, filepaths_curr)
118 |    filenames = TableConcat(filenames, ImgName)
119 | end
120 | 
121 | -- make webpage
122 | io.output(paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase, 'index.html'))
123 | io.write('<table>')
124 | io.write('<tr><td>Image #</td><td>Input</td><td>Guidance</td><td>Warped Guidance</td><td>Output</td></tr>')
125 | for i=1, #filepaths do
126 |    io.write('<tr>')
127 |    io.write('<td>' .. filenames[i] .. '</td>')
128 |    io.write('<td><img src="./images/Input/' .. filepaths[i] .. '"/></td>')
129 |    io.write('<td><img src="./images/Guidance/' .. filepaths[i] .. '"/></td>')
130 |    io.write('<td><img src="./images/WarpGuidance/' .. filepaths[i] .. '"/></td>')
131 |    io.write('<td><img src="./images/Output/' .. filepaths[i] .. '"/></td>')
132 |    io.write('</tr>')
133 | end
134 | io.write('</table>')
135 | 
--------------------------------------------------------------------------------
/util/cudnn_convert_custom.lua:
--------------------------------------------------------------------------------
1 | -- modified from https://github.com/NVIDIA/torch-cudnn/blob/master/convert.lua
2 | -- removed error on nngraph
3 | 
4 | -- modules that can be converted to nn seamlessly
5 | local layer_list = {
6 |    'BatchNormalization',
7 |    'SpatialBatchNormalization',
8 |    'SpatialConvolution',
9 |    'SpatialCrossMapLRN',
10 |    'SpatialFullConvolution',
11 |    'SpatialMaxPooling',
12 |    'SpatialAveragePooling',
13 |    'ReLU',
14 |    'Tanh',
15 |    'Sigmoid',
16 |    'SoftMax',
17 |    'LogSoftMax',
18 |    'VolumetricBatchNormalization',
19 |    'VolumetricConvolution',
20 |    'VolumetricFullConvolution',
21 |    'VolumetricMaxPooling',
22 |    'VolumetricAveragePooling',
23 | }
24 | 
25 | -- goes over a given net and converts all layers to dst backend
26 | -- for example: net = cudnn_convert_custom(net, cudnn)
27 | -- same as cudnn.convert with gModule check commented out
28 | function cudnn_convert_custom(net, dst, exclusion_fn)
29 |    return net:replace(function(x)
30 |       --if torch.type(x) == 'nn.gModule' then
31 |       --   io.stderr:write('Warning: cudnn.convert does not work with nngraph yet. Ignoring nn.gModule')
32 |       --   return x
33 |       --end
34 |       local y = 0
35 |       local src = dst == nn and cudnn or nn
36 |       local src_prefix = src == nn and 'nn.' or 'cudnn.'
37 |       local dst_prefix = dst == nn and 'nn.' or 'cudnn.'
38 | 
39 |       local function convert(v)
40 |          local y = {}
41 |          torch.setmetatable(y, dst_prefix..v)
42 |          if v == 'ReLU' then y = dst.ReLU() end -- because parameters
43 |          for k,u in pairs(x) do y[k] = u end
44 |          if src == cudnn and x.clearDesc then x.clearDesc(y) end
45 |          if src == cudnn and v == 'SpatialAveragePooling' then
46 |             y.divide = true
47 |             y.count_include_pad = v.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
48 |          end
49 |          return y
50 |       end
51 | 
52 |       if exclusion_fn and exclusion_fn(x) then
53 |          return x
54 |       end
55 |       local t = torch.typename(x)
56 |       if t == 'nn.SpatialConvolutionMM' then
57 |          y = convert('SpatialConvolution')
58 |       elseif t == 'inn.SpatialCrossResponseNormalization' then
59 |          y = convert('SpatialCrossMapLRN')
60 |       else
61 |          for i,v in ipairs(layer_list) do
62 |             if torch.typename(x) == src_prefix..v then
63 |                y = convert(v)
64 |             end
65 |          end
66 |       end
67 |       return y == 0 and x or y
68 |    end)
69 | end
--------------------------------------------------------------------------------
/util/util.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- code derived from https://github.com/soumith/dcgan.torch
3 | --
4 | 
5 | local util = {}
6 | 
7 | require 'torch'
8 | function printNet(net)
9 | 
10 |    for i = 1, net:size(1) do
11 |       print(string.format("%d: %s", i, net.modules[i]))
12 |    end
13 | 
14 | end
15 | function util.normalize(img)
16 |    -- rescale image to 0 .. 1
17 |    local min = img:min()
18 |    local max = img:max()
19 | 
20 |    img = torch.FloatTensor(img:size()):copy(img)
21 |    img:add(-min):mul(1/(max-min))
22 |    return img
23 | end
24 | 
25 | function util.normalizeBatch(batch)
26 |    for i = 1, batch:size(1) do
27 |       batch[i] = util.normalize(batch[i]:squeeze())
28 |    end
29 |    return batch
30 | end
31 | 
32 | function util.basename_batch(batch)
33 |    for i = 1, #batch do
34 |       batch[i] = paths.basename(batch[i])
35 |    end
36 |    return batch
37 | end
38 | 
39 | 
40 | 
41 | -- default preprocessing
42 | --
43 | -- Preprocesses an image before passing it to a net
44 | -- Converts from RGB to BGR and rescales from [0,1] to [-1,1]
45 | function util.preprocess(img)
46 |    -- RGB to BGR
47 |    local perm = torch.LongTensor{3, 2, 1}
48 |    img = img:index(1, perm)
49 | 
50 |    -- -- [0,1] to [-1,1]
51 |    -- img = img:mul(2):add(-1)
52 | 
53 |    -- -- check that input is in expected range
54 |    -- assert(img:max()<=1,"badly scaled inputs")
55 |    -- assert(img:min()>=-1,"badly scaled inputs")
56 | 
57 |    return img
58 | end
59 | 
60 | -- Undo the above preprocessing.
61 | function util.deprocess(img)
62 |    -- BGR to RGB
63 |    local perm = torch.LongTensor{3, 2, 1}
64 |    img = img:index(1, perm)
65 | 
66 |    -- [-1,1] to [0,1]
67 |    -- also changed here: values are kept directly in [0, 1]
68 |    -- img = img:add(1):div(2)
69 | 
70 |    return img
71 | end
72 | function util.deprocess_vgg(img)
73 |    -- BGR to RGB
74 |    -- [-1,1] to [0,1]
75 | 
76 |    -- img = img:add(1):div(2)
77 | 
78 |    return img
79 | end
80 | function util.preprocess_batch(batch)
81 |    for i = 1, batch:size(1) do
82 |       batch[i] = util.preprocess(batch[i]:squeeze())
83 |    end
84 |    return batch
85 | end
86 | 
87 | function util.deprocess_batch(batch)
88 |    for i = 1, batch:size(1) do
89 |       batch[i] = util.deprocess(batch[i]:squeeze())
90 |    end
91 |    return batch
92 | end
93 | 
94 | 
95 | function util.deprocess_batch_vgg(batch)
96 |    for i = 1, batch:size(1) do
97 |       batch[i] = util.deprocess_vgg(batch[i]:squeeze())
98 |    end
99 |    return batch
100 | end
101 | -- preprocessing specific to colorization
102 | 
103 | function util.deprocessLAB(L, AB)
104 |    local L2 = torch.Tensor(L:size()):copy(L)
105 |    if L2:dim() == 3 then
106 |       L2 = L2[{1, {}, {} }]
107 |    end
108 |    local AB2 = torch.Tensor(AB:size()):copy(AB)
109 |    AB2 = torch.clamp(AB2, -1.0, 1.0)
110 |    -- local AB2 = AB
111 |    L2 = L2:add(1):mul(50.0)
112 |    AB2 = AB2:mul(110.0)
113 | 
114 |    L2 = L2:reshape(1, L2:size(1), L2:size(2))
115 | 
116 |    im_lab = torch.cat(L2, AB2, 1)
117 |    im_rgb = torch.clamp(image.lab2rgb(im_lab):mul(255.0), 0.0, 255.0)/255.0
118 | 
119 |    return im_rgb
120 | end
121 | 
122 | function util.deprocessL(L)
123 |    local L2 = torch.Tensor(L:size()):copy(L)
124 |    L2 = L2:add(1):mul(255.0/2.0)
125 | 
126 |    if L2:dim()==2 then
127 |       L2 = L2:reshape(1,L2:size(1),L2:size(2))
128 |    end
129 |    L2 = L2:repeatTensor(L2,3,1,1)/255.0
130 | 
131 |    return L2
132 | end
133 | 
134 | function util.deprocessL_batch(batch)
135 |    local batch_new = {}
136 |    for i = 1, batch:size(1) do
137 |       batch_new[i] = util.deprocessL(batch[i]:squeeze())
138 |    end
139 |    return batch_new
140 | end
141 | 
142 | function util.deprocessLAB_batch(batchL, batchAB)
143 |    local batch = {}
144 | 
145 |    for i = 1, batchL:size(1) do
146 |       batch[i] = util.deprocessLAB(batchL[i]:squeeze(), batchAB[i]:squeeze())
147 |    end
148 | 
149 |    return batch
150 | end
151 | 
152 | 
153 | function util.scaleBatch(batch,s1,s2)
154 |    local scaled_batch = torch.Tensor(batch:size(1),batch:size(2),s1,s2)
155 |    for i = 1, batch:size(1) do
156 |       scaled_batch[i] = 
image.scale(batch[i],s1,s2):squeeze() 157 | end 158 | return scaled_batch 159 | end 160 | 161 | 162 | 163 | function util.toTrivialBatch(input) 164 | return input:reshape(1,input:size(1),input:size(2),input:size(3)) 165 | end 166 | function util.fromTrivialBatch(input) 167 | return input[1] 168 | end 169 | 170 | 171 | 172 | function util.scaleImage(input, loadSize) 173 | -- replicate bw images to 3 channels 174 | if input:size(1)==1 then 175 | input = torch.repeatTensor(input,3,1,1) 176 | end 177 | 178 | input = image.scale(input, loadSize, loadSize) 179 | 180 | return input 181 | end 182 | 183 | function util.getAspectRatio(path) 184 | local input = image.load(path, 3, 'float') 185 | local ar = input:size(3)/input:size(2) 186 | return ar 187 | end 188 | 189 | function util.loadImage(path, loadSize, nc) 190 | local input = image.load(path, 3, 'float') 191 | input= util.preprocess(util.scaleImage(input, loadSize)) 192 | 193 | if nc == 1 then 194 | input = input[{{1}, {}, {}}] 195 | end 196 | 197 | return input 198 | end 199 | 200 | 201 | 202 | -- TO DO: loading code is rather hacky; clean it up and make sure it works on all types of nets / cpu/gpu configurations 203 | function util.load(filename, opt) 204 | if opt.cudnn>0 then 205 | require 'cudnn' 206 | end 207 | 208 | if opt.gpu > 0 then 209 | require 'cunn' 210 | end 211 | 212 | local net = torch.load(filename) 213 | 214 | if opt.gpu > 0 then 215 | net:cuda() 216 | 217 | -- calling cuda on cudnn saved nngraphs doesn't change all variables to cuda, so do it below 218 | if net.forwardnodes then 219 | for i=1,#net.forwardnodes do 220 | if net.forwardnodes[i].data.module then 221 | net.forwardnodes[i].data.module:cuda() 222 | end 223 | end 224 | end 225 | else 226 | net:float() 227 | end 228 | net:apply(function(m) if m.weight then 229 | m.gradWeight = m.weight:clone():zero(); 230 | m.gradBias = m.bias:clone():zero(); end end) 231 | return net 232 | end 233 | 234 | function util.cudnn(net) 235 | require 'cudnn' 236 | require 'util/cudnn_convert_custom' 237 | return cudnn_convert_custom(net, cudnn) 238 | end 239 | 240 | function util.containsValue(table, value) 241 | for k, v in pairs(table) do 242 | if v == value then return true end 243 | end 244 | return false 245 | end 246 | 247 | function printNet(m) 248 | local name = torch.type(m) 249 | print(name) 250 | end 251 | 252 | return util 253 | --------------------------------------------------------------------------------