├── Datasets
│   ├── IMDB
│   │   ├── 1.jpg
│   │   ├── 2.jpg
│   │   └── 3.jpg
│   └── RealImg
│       ├── 1.jpg
│       ├── 2.jpg
│       ├── 3.jpg
│       └── 4.jpg
├── README.md
├── data2
│   ├── dataC.lua
│   ├── dataset.lua
│   └── donkey_folderC.lua
├── imgs
│   ├── IMDb
│   │   ├── 1_1.jpg
│   │   ├── 1_2.jpg
│   │   ├── 1_3.jpg
│   │   ├── 1_4.jpg
│   │   ├── 2_1.jpg
│   │   ├── 2_2.jpg
│   │   ├── 2_3.jpg
│   │   └── 2_4.jpg
│   ├── architecture
│   │   └── pipeline.jpg
│   ├── realresults
│   │   └── 1.jpg
│   └── warpface
│       └── warp.jpg
├── test.lua
└── util
    ├── cudnn_convert_custom.lua
    └── util.lua
/Datasets/IMDB/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/IMDB/1.jpg
--------------------------------------------------------------------------------
/Datasets/IMDB/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/IMDB/2.jpg
--------------------------------------------------------------------------------
/Datasets/IMDB/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/IMDB/3.jpg
--------------------------------------------------------------------------------
/Datasets/RealImg/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/RealImg/1.jpg
--------------------------------------------------------------------------------
/Datasets/RealImg/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/RealImg/2.jpg
--------------------------------------------------------------------------------
/Datasets/RealImg/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/RealImg/3.jpg
--------------------------------------------------------------------------------
/Datasets/RealImg/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/Datasets/RealImg/4.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [GFRNet](https://arxiv.org/abs/1804.04829)
2 | Torch implementation of [Learning Warped Guidance for Blind Face Restoration](https://arxiv.org/abs/1804.04829)
3 |
4 | # GFRNet framework
5 | Overview of our GFRNet. The WarpNet takes the degraded observation and the guided image as input to predict a dense flow field, which is used to deform the guided image into the warped guidance. Since the warped guidance is expected to be spatially well aligned with the ground truth, the RecNet then takes the warped guidance and the degraded observation as input to produce the restoration result.
6 |
7 | [figure: ./imgs/architecture/pipeline.jpg (overview of the GFRNet pipeline)]
8 |
9 |
10 | # Testing
11 |
12 | ```bash
13 | th test.lua
14 | ```
15 | # Models
16 | Download the pre-trained model from one of the following links and put it into ./checkpoints/FaceRestoration/.
17 | - [BaiduNetDisk](https://pan.baidu.com/s/1q96l3qmTf5Luh-nlqot6Xw)
18 | - [GoogleDrive](https://drive.google.com/open?id=1PhE3Gi9-eHrofyR3LhqEhuVnzh9D7IsX)
19 |
20 | # Results
21 | ## Restoration on real low quality images
22 | The first row shows real low-quality images (the close-up at the bottom right is the guided image). The second row shows the corresponding GFRNet results.
23 |
24 | [figure: ./imgs/realresults/1.jpg (restoration results on real low-quality images)]
25 |
26 | ## Warped guidance
27 |
28 | [figure: ./imgs/warpface/warp.jpg (guided image warped toward the degraded input)]
29 |
30 | ## IMDB results
31 | The content marked with a green box is the restoration result produced by our GFRNet. All of these images are collected from the [Internet Movie Database (IMDb)](https://www.imdb.com/).
32 |
33 |
34 | Input | Guided Image | Bicubic | GFRNet Results
35 | [image table removed in this dump: each example is a row of four images, ./imgs/IMDb/1_1.jpg through 1_4.jpg and ./imgs/IMDb/2_1.jpg through 2_4.jpg]
66 | - [More IMDB results can be found here](http://csxmli.xin/GFRNet/).
67 | - [Poster can be found here](http://csxmli.xin/GFRNet/poster.pdf).
68 |
69 | # Requirements and Dependencies
70 |
71 | - [Torch](https://github.com/torch/distro)
72 | - [CUDA 8.0](https://developer.nvidia.com/cuda-toolkit-archive)
73 | - [Stn](https://github.com/qassemoquab/stnbhwd)
74 |
75 | # Acknowledgments
76 |
77 | Code borrows heavily from [pix2pix](https://github.com/phillipi/pix2pix). Thanks for their excellent work!
78 |
79 | # Citation
80 |
81 | ```
82 | @InProceedings{Li_2018_ECCV,
83 | author = {Li, Xiaoming and Liu, Ming and Ye, Yuting and Zuo, Wangmeng and Lin, Liang and Yang, Ruigang},
84 | title = {Learning Warped Guidance for Blind Face Restoration},
85 | booktitle = {The European Conference on Computer Vision (ECCV)},
86 | month = {September},
87 | year = {2018}
88 | }
89 | ```
90 |
--------------------------------------------------------------------------------
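The WarpNet/RecNet pipeline described in the README can be exercised end to end with the files below. A minimal inference sketch, assuming the pre-trained netG.t7 from the Models section, a GPU, and a side-by-side [degraded | guidance] test pair in the layout data2/donkey_folderC.lua expects; output indices 1 and 3 follow how test.lua unpacks the generator outputs:

```lua
require 'torch'
require 'nn'
require 'cudnn'
require 'cunn'
require 'nngraph'
require 'image'
require 'stn'
local util = paths.dofile('util/util.lua')

-- load the released generator (checkpoints/FaceRestoration/netG.t7)
local netG = util.load('./checkpoints/FaceRestoration/netG.t7',
                       {cudnn = 1, gpu = 1}):cuda()

-- split the side-by-side pair, resize to 256x256, swap RGB -> BGR (values stay in [0,1])
local pair = image.load('./Datasets/RealImg/1.jpg', 3, 'float')
local h, w = pair:size(2), pair:size(3)
local A = util.preprocess(image.scale(image.crop(pair, 0, 0, w/2, h), 256, 256)):reshape(1, 3, 256, 256)
local B = util.preprocess(image.scale(image.crop(pair, w/2, 0, w, h), 256, 256)):reshape(1, 3, 256, 256)

local out = netG:forward({A:cuda(), B:cuda()})
local warped   = util.deprocess(out[1][1]:float()) -- warped guidance from the WarpNet
local restored = util.deprocess(out[3][1]:float()) -- restoration result from the RecNet
image.save('restored.jpg', restored)
```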
/data2/dataC.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | This data loader is a modified version of the one from dcgan.torch
3 | (see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua).
4 |
5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
6 | ]]--
7 |
8 | local Threads = require 'threads'
9 | Threads.serialization('threads.sharedserialize')
10 |
11 | local data = {}
12 |
13 | local result = {}
14 | local unpack = unpack and unpack or table.unpack
15 |
16 | function data.new(n, opt_)
17 | opt_ = opt_ or {}
18 | local self = {}
19 | for k,v in pairs(data) do
20 | self[k] = v
21 | end
22 |
23 | local donkey_file = 'donkey_folderC.lua'
24 | -- print('n..' .. n)
25 | if n > 0 then
26 | local options = opt_
27 | self.threads = Threads(n,
28 | function() require 'torch' end,
29 | function(idx)
30 | opt = options
31 | tid = idx
32 | local seed = (opt.manualSeed and opt.manualSeed or 0) + idx
33 | torch.manualSeed(seed)
34 | torch.setnumthreads(1)
35 | print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
36 | assert(options, 'options not found')
37 | assert(opt, 'opt not given')
38 | print(opt)
39 | paths.dofile(donkey_file)
40 | end
41 |
42 | )
43 | else
44 | if donkey_file then paths.dofile(donkey_file) end
45 | -- print('empty threads')
46 | self.threads = {}
47 | function self.threads:addjob(f1, f2) f2(f1()) end
48 | function self.threads:dojob() end
49 | function self.threads:synchronize() end
50 | end
51 |
52 | local nSamples = 0
53 | self.threads:addjob(function() return trainLoader:size() end,
54 | function(c) nSamples = c end)
55 | self.threads:synchronize()
56 | self._size = nSamples
57 |
58 | for i = 1, n do
59 | self.threads:addjob(self._getFromThreads,
60 | self._pushResult)
61 | end
62 | -- print(self.threads)
63 | return self
64 | end
65 |
66 | function data._getFromThreads()
67 | assert(opt.batchSize, 'opt.batchSize not found')
68 | return trainLoader:sample(opt.batchSize)
69 | end
70 |
71 | function data._pushResult(...)
72 | local res = {...}
73 |    if res == nil then -- vestigial guard from dcgan.torch: a vararg table {...} is never nil
74 | self.threads:synchronize()
75 | end
76 | result[1] = res
77 |
78 | end
79 |
80 |
81 |
82 | function data:getBatch()
83 | -- queue another job
84 | -- print(self.threads)
85 | self.threads:addjob(self._getFromThreads, self._pushResult)
86 | self.threads:dojob()
87 | local res = result[1]
88 |
89 | -- print(res)
90 | -- print('result')
91 | -- print(res)
92 | -- os.exit()
93 | -- paths = results[3]
94 | -- print(paths)
95 |
96 | img_data = res[1]
97 | img_paths = res[3]
98 | -- print(img_data:size())
99 | -- print(type(img_data))
100 | -- print(img_paths)
101 | -- print(type(img_paths))
102 | -- result[3] = nil
103 | -- print(type(res))
104 |
105 | result[1] = nil
106 | if torch.type(img_data) == 'table' then
107 | img_data = unpack(img_data)
108 | end
109 |
110 |
111 | return img_data, img_paths,res[4]
112 | end
113 |
114 | function data:size()
115 | return self._size
116 | end
117 |
118 | return data
119 |
--------------------------------------------------------------------------------
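A minimal usage sketch of this loader (the option values are assumptions chosen to match test.lua): data.new starts n donkey threads, each of which runs data2/donkey_folderC.lua to build a global trainLoader, and getBatch pops one prefetched result while queueing the next job.

```lua
require 'torch'
local opt = {DATA_ROOT = './Datasets', phase = 'RealImg', batchSize = 1,
             loadSize = 256, fineSize = 256, flip = 0, input_nc = 3,
             output_nc = 3, serial_batches = 1, preprocess = 'regular',
             manualSeed = 123}
local data_loader = paths.dofile('data2/dataC.lua')
local data = data_loader.new(1, opt) -- one donkey thread
print('dataset size: ' .. data:size())

local batch, batch_paths, flip_flags = data:getBatch()
print(batch:size())                  -- batchSize x 6 x fineSize x fineSize (A and B stacked)
print(batch_paths[1], flip_flags[1]) -- source path and whether it was flipped
```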
/data2/dataset.lua:
--------------------------------------------------------------------------------
1 | --[[
2 | Copyright (c) 2015-present, Facebook, Inc.
3 | All rights reserved.
4 |
5 | This source code is licensed under the BSD-style license found in the
6 | LICENSE file in the root directory of this source tree. An additional grant
7 | of patent rights can be found in the PATENTS file in the same directory.
8 | ]]--
9 |
10 | require 'torch'
11 | torch.setdefaulttensortype('torch.FloatTensor')
12 | local ffi = require 'ffi'
13 | local class = require('pl.class')
14 | local dir = require 'pl.dir'
15 | local tablex = require 'pl.tablex'
16 | local argcheck = require 'argcheck'
17 | require 'sys'
18 | require 'xlua'
19 | require 'image'
20 |
21 | local dataset = torch.class('dataLoader')
22 |
23 | local initcheck = argcheck{
24 | pack=true,
25 | help=[[
26 | A dataset class for images in a flat folder structure (folder-name is class-name).
27 | Optimized for extremely large datasets (upwards of 14 million images).
28 | Tested only on Linux (as it uses command-line linux utilities to scale up)
29 | ]],
30 | {check=function(paths)
31 | local out = true;
32 | for k,v in ipairs(paths) do
33 | if type(v) ~= 'string' then
34 | print('paths can only be of string input');
35 | out = false
36 | end
37 | end
38 | return out
39 | end,
40 | name="paths",
41 | type="table",
42 | help="Multiple paths of directories with images"},
43 |
44 | {name="sampleSize",
45 | type="table",
46 | help="a consistent sample size to resize the images"},
47 |
48 | {name="split",
49 | type="number",
50 | help="Percentage of split to go to Training"
51 | },
52 | {name="serial_batches",
53 | type="number",
54 |     help="if 1, sample training images serially instead of randomly"},
55 |
56 | {name="samplingMode",
57 | type="string",
58 | help="Sampling mode: random | balanced ",
59 | default = "balanced"},
60 |
61 | {name="verbose",
62 | type="boolean",
63 | help="Verbose mode during initialization",
64 | default = false},
65 |
66 | {name="loadSize",
67 | type="table",
68 | help="a size to load the images to, initially",
69 | opt = true},
70 |
71 | {name="forceClasses",
72 | type="table",
73 | help="If you want this loader to map certain classes to certain indices, "
74 | .. "pass a classes table that has {classname : classindex} pairs."
75 | .. " For example: {3 : 'dog', 5 : 'cat'}"
76 | .. "This function is very useful when you want two loaders to have the same "
77 | .. "class indices (trainLoader/testLoader for example)",
78 | opt = true},
79 |
80 | {name="sampleHookTrain",
81 | type="function",
82 | help="applied to sample during training(ex: for lighting jitter). "
83 | .. "It takes the image path as input",
84 | opt = true},
85 |
86 | {name="sampleHookTest",
87 | type="function",
88 | help="applied to sample during testing",
89 | opt = true},
90 | }
91 |
92 | function dataset:__init(...)
93 |
94 | -- argcheck
95 | local args = initcheck(...)
96 | print(args)
97 | for k,v in pairs(args) do self[k] = v end
98 |
99 | if not self.loadSize then self.loadSize = self.sampleSize; end
100 |
101 | if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
102 | if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end
103 | self.image_count = 1
104 | -- print('image_count_init', self.image_count)
105 | -- find class names
106 | self.classes = {}
107 | local classPaths = {}
108 | if self.forceClasses then
109 | for k,v in pairs(self.forceClasses) do
110 | self.classes[k] = v
111 | classPaths[k] = {}
112 | end
113 | end
114 | local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end
115 | -- loop over each paths folder, get list of unique class names,
116 | -- also store the directory paths per class
117 | -- for each class,
118 | for k,path in ipairs(self.paths) do
119 | -- print('path', path)
120 | local dirs = {} -- hack
121 | dirs[1] = path
122 | -- local dirs = dir.getdirectories(path);
123 | for k,dirpath in ipairs(dirs) do
124 | local class = paths.basename(dirpath)
125 | local idx = tableFind(self.classes, class)
126 | -- print(class)
127 | -- print(idx)
128 | if not idx then
129 | table.insert(self.classes, class)
130 | idx = #self.classes
131 | classPaths[idx] = {}
132 | end
133 | if not tableFind(classPaths[idx], dirpath) then
134 | table.insert(classPaths[idx], dirpath);
135 | end
136 | end
137 | end
138 |
139 | self.classIndices = {}
140 | for k,v in ipairs(self.classes) do
141 | self.classIndices[v] = k
142 | end
143 |
144 | -- define command-line tools, try your best to maintain OSX compatibility
145 | local wc = 'wc'
146 | local cut = 'cut'
147 | local find = 'find -H' -- if folder name is symlink, do find inside it after dereferencing
148 | if jit.os == 'OSX' then
149 | wc = 'gwc'
150 | cut = 'gcut'
151 | find = 'gfind'
152 | end
153 | ----------------------------------------------------------------------
154 | -- Options for the GNU find command
155 | local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
156 | local findOptions = ' -iname "*.' .. extensionList[1] .. '"'
157 | for i=2,#extensionList do
158 | findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"'
159 | end
160 |
161 | -- find the image path names
162 | self.imagePath = torch.CharTensor() -- path to each image in dataset
163 | self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes)
164 | self.classList = {} -- index of imageList to each image of a particular class
165 | self.classListSample = self.classList -- the main list used when sampling data
166 |
167 | print('running "find" on each class directory, and concatenate all'
168 | .. ' those filenames into a single file containing all image paths for a given class')
169 | -- so, generates one file per class
170 | local classFindFiles = {}
171 | for i=1,#self.classes do
172 | classFindFiles[i] = os.tmpname()
173 | end
174 | local combinedFindList = os.tmpname();
175 |
176 | local tmpfile = os.tmpname()
177 | local tmphandle = assert(io.open(tmpfile, 'w'))
178 | -- iterate over classes
179 | for i, class in ipairs(self.classes) do
180 | -- iterate over classPaths
181 | for j,path in ipairs(classPaths[i]) do
182 | local command = find .. ' "' .. path .. '" ' .. findOptions
183 | .. ' >>"' .. classFindFiles[i] .. '" \n'
184 | tmphandle:write(command)
185 | end
186 | end
187 | io.close(tmphandle)
188 | os.execute('bash ' .. tmpfile)
189 | os.execute('rm -f ' .. tmpfile)
190 |
191 | print('now combine all the files to a single large file')
192 | local tmpfile = os.tmpname()
193 | local tmphandle = assert(io.open(tmpfile, 'w'))
194 | -- concat all finds to a single large file in the order of self.classes
195 | for i=1,#self.classes do
196 | local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n'
197 | tmphandle:write(command)
198 | end
199 | io.close(tmphandle)
200 | os.execute('bash ' .. tmpfile)
201 | os.execute('rm -f ' .. tmpfile)
202 |
203 | --==========================================================================
204 | print('load the large concatenated list of sample paths to self.imagePath')
205 | local cmd = wc .. " -L '"
206 | .. combinedFindList .. "' |"
207 | .. cut .. " -f1 -d' '"
208 | print('cmd..' .. cmd)
209 | local maxPathLength = tonumber(sys.fexecute(wc .. " -L '"
210 | .. combinedFindList .. "' |"
211 | .. cut .. " -f1 -d' '")) + 1
212 | local length = tonumber(sys.fexecute(wc .. " -l '"
213 | .. combinedFindList .. "' |"
214 | .. cut .. " -f1 -d' '"))
215 | assert(length > 0, "Could not find any image file in the given input paths")
216 | assert(maxPathLength > 0, "paths of files are length 0?")
217 | self.imagePath:resize(length, maxPathLength):fill(0)
218 | local s_data = self.imagePath:data()
219 | local count = 0
220 | for line in io.lines(combinedFindList) do
221 | ffi.copy(s_data, line)
222 | s_data = s_data + maxPathLength
223 | if self.verbose and count % 10000 == 0 then
224 | xlua.progress(count, length)
225 | end;
226 | count = count + 1
227 | end
228 |
229 | self.numSamples = self.imagePath:size(1)
230 | if self.verbose then print(self.numSamples .. ' samples found.') end
231 | --==========================================================================
232 | print('Updating classList and imageClass appropriately')
233 | self.imageClass:resize(self.numSamples)
234 | local runningIndex = 0
235 | for i=1,#self.classes do
236 | if self.verbose then xlua.progress(i, #(self.classes)) end
237 | local length = tonumber(sys.fexecute(wc .. " -l '"
238 | .. classFindFiles[i] .. "' |"
239 | .. cut .. " -f1 -d' '"))
240 | if length == 0 then
241 | error('Class has zero samples')
242 | else
243 | self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long()
244 | self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i)
245 | end
246 | runningIndex = runningIndex + length
247 | end
248 |
249 | --==========================================================================
250 | -- clean up temporary files
251 | print('Cleaning up temporary files')
252 | local tmpfilelistall = ''
253 | for i=1,#(classFindFiles) do
254 | tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"'
255 | if i % 1000 == 0 then
256 | os.execute('rm -f ' .. tmpfilelistall)
257 | tmpfilelistall = ''
258 | end
259 | end
260 | os.execute('rm -f ' .. tmpfilelistall)
261 | os.execute('rm -f "' .. combinedFindList .. '"')
262 | --==========================================================================
263 |
264 | if self.split == 100 then
265 | self.testIndicesSize = 0
266 | else
267 | print('Splitting training and test sets to a ratio of '
268 | .. self.split .. '/' .. (100-self.split))
269 | self.classListTrain = {}
270 | self.classListTest = {}
271 | self.classListSample = self.classListTrain
272 | local totalTestSamples = 0
273 | -- split the classList into classListTrain and classListTest
274 | for i=1,#self.classes do
275 | local list = self.classList[i]
276 | local count = self.classList[i]:size(1)
277 | local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round
278 | local perm = torch.randperm(count)
279 | self.classListTrain[i] = torch.LongTensor(splitidx)
280 | for j=1,splitidx do
281 | self.classListTrain[i][j] = list[perm[j]]
282 | end
283 | if splitidx == count then -- all samples were allocated to train set
284 | self.classListTest[i] = torch.LongTensor()
285 | else
286 | self.classListTest[i] = torch.LongTensor(count-splitidx)
287 | totalTestSamples = totalTestSamples + self.classListTest[i]:size(1)
288 | local idx = 1
289 | for j=splitidx+1,count do
290 | self.classListTest[i][idx] = list[perm[j]]
291 | idx = idx + 1
292 | end
293 | end
294 | end
295 | -- Now combine classListTest into a single tensor
296 | self.testIndices = torch.LongTensor(totalTestSamples)
297 | self.testIndicesSize = totalTestSamples
298 | local tdata = self.testIndices:data()
299 | local tidx = 0
300 | for i=1,#self.classes do
301 | local list = self.classListTest[i]
302 | if list:dim() ~= 0 then
303 | local ldata = list:data()
304 | for j=0,list:size(1)-1 do
305 | tdata[tidx] = ldata[j]
306 | tidx = tidx + 1
307 | end
308 | end
309 | end
310 | end
311 | end
312 |
313 | -- size(), size(class)
314 | function dataset:size(class, list)
315 | list = list or self.classList
316 | if not class then
317 | return self.numSamples
318 | elseif type(class) == 'string' then
319 | return list[self.classIndices[class]]:size(1)
320 | elseif type(class) == 'number' then
321 | return list[class]:size(1)
322 | end
323 | end
324 |
325 | -- getByClass
326 | function dataset:getByClass(class)
327 | local index = 0
328 | if self.serial_batches == 1 then
329 | index = math.fmod(self.image_count-1, self.classListSample[class]:nElement())+1
330 | self.image_count = self.image_count +1
331 | else
332 | index = math.ceil(torch.uniform() * self.classListSample[class]:nElement())
333 | end
334 | -- print('serial_batches: ', self.serial_batches)
335 | -- print('max_index:, ', self.classListSample[class]:nElement())
336 | -- print('index: ', index)
337 | -- print('image_count',
338 | local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]]))
339 |
340 | local imgAB, flip_flag
341 | imgAB, flip_flag=self:sampleHookTrain(imgpath)
342 | return imgAB, imgpath, flip_flag
343 | end
344 |
345 | -- converts a table of samples (and corresponding labels) to a clean tensor
346 | local function tableToOutput(self, dataTable, scalarTable)
347 | local data, scalarLabels, labels
348 | local quantity = #scalarTable
349 |    -- print(dataTable[1]:size())
350 | assert(dataTable[1]:dim() == 3)
351 | -- print(quantity)
352 | -- print(self.sampleSize[1])
353 | -- print(self.sampleSize[2])
354 | -- print(self.sampleSize[3])
355 | data = torch.Tensor(quantity,
356 | self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])
357 | -- print(data:size())
358 | scalarLabels = torch.LongTensor(quantity):fill(-1111)
359 | for i=1,#dataTable do
360 | data[i]:copy(dataTable[i])
361 | scalarLabels[i] = scalarTable[i]
362 | end
363 | return data, scalarLabels
364 | end
365 |
366 | -- sampler, samples from the training set.
367 | function dataset:sample(quantity)
368 | assert(quantity)
369 | local dataTable = {}
370 | local scalarTable = {}
371 | local samplePaths = {}
372 |    local flipTable = {}
373 | for i=1,quantity do
374 | local class = torch.random(1, #self.classes)
375 | -- print(class)
376 | local out, imgpath,flip_flag = self:getByClass(class)
377 | table.insert(dataTable, out)
378 | table.insert(scalarTable, class)
379 |       table.insert(flipTable, flip_flag)
380 | samplePaths[i] = imgpath
381 | -- print(imgpath)
382 | -- table.insert(pathTable, imgpath)
383 | -- table.insert()
384 | -- print('out', out:size())
385 | end
386 | -- print('table')
387 | -- print(table)
388 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
389 |    return data, scalarLabels, samplePaths, flipTable -- data, class labels, file paths, flip flags
390 | end
391 |
392 | function dataset:get(i1, i2)
393 | local indices = torch.range(i1, i2);
394 | local quantity = i2 - i1 + 1;
395 | assert(quantity > 0)
396 | -- now that indices has been initialized, get the samples
397 | local dataTable = {}
398 | local scalarTable = {}
399 | for i=1,quantity do
400 | -- load the sample
401 | local imgpath = ffi.string(torch.data(self.imagePath[indices[i]]))
402 | local out = self:sampleHookTest(imgpath)
403 | table.insert(dataTable, out)
404 | table.insert(scalarTable, self.imageClass[indices[i]])
405 | end
406 | local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
407 | return data, scalarLabels
408 | end
409 |
410 | return dataset
411 |
--------------------------------------------------------------------------------
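The dataLoader class can also be driven on its own. A small sketch (the hook below is hypothetical, mirroring the side-by-side convention the rest of this repo uses): each entry of `paths` becomes one class, images are discovered with find, and the installed sampleHookTrain decides what a sample looks like.

```lua
require 'torch'
require 'image'
paths.dofile('data2/dataset.lua') -- registers the global dataLoader class

local loader = dataLoader{
   paths = {'./Datasets/RealImg'},  -- folder name doubles as the class name
   loadSize = {3, 256, 256},
   sampleSize = {6, 256, 256},      -- degraded input and guidance stacked channel-wise
   split = 100,                     -- keep everything in the training split
   serial_batches = 1,              -- walk the images in order
   verbose = true,
}

-- hook: load a side-by-side pair, split it, return the 6-channel stack plus a flip flag
loader.sampleHookTrain = function(self, path)
   local im = image.load(path, 3, 'float')
   local h, w = im:size(2), im:size(3)
   local A = image.scale(image.crop(im, 0, 0, w/2, h), 256, 256)
   local B = image.scale(image.crop(im, w/2, 0, w, h), 256, 256)
   return torch.cat(A, B, 1), 0
end

local samples, labels, sample_paths, flips = loader:sample(1)
print(samples:size()) -- 1 x 6 x 256 x 256
```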
/data2/donkey_folderC.lua:
--------------------------------------------------------------------------------
1 |
2 | --[[
3 | This data loader is a modified version of the one from dcgan.torch
4 | (see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua).
5 | Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
6 | Copyright (c) 2015-present, Facebook, Inc.
7 | All rights reserved.
8 | This source code is licensed under the BSD-style license found in the
9 | LICENSE file in the root directory of this source tree. An additional grant
10 | of patent rights can be found in the PATENTS file in the same directory.
11 | ]]--
12 |
13 | require 'image'
14 | paths.dofile('dataset.lua')
15 | -- This file contains the data-loading logic and details.
16 | -- It is run by each data-loader thread.
17 | ------------------------------------------
18 | -------- COMMON CACHES and PATHS
19 | -- Check for existence of opt.data
20 |
21 | if opt.DATA_ROOT then
22 | opt.data=paths.concat(opt.DATA_ROOT, opt.phase)
23 | else
24 |
25 | opt.data=paths.concat(os.getenv('DATA_ROOT'), opt.phase)
26 | end
27 |
28 | if not paths.dirp(opt.data) then
29 | error('Did not find directory: ' .. opt.data)
30 | end
31 |
32 | -- a cache file of the training metadata (if doesnt exist, will be created)
33 | local cache = "cache"
34 | local cache_prefix = opt.data:gsub('/', '_')
35 | os.execute('mkdir -p cache')
36 | local trainCache = paths.concat(cache, cache_prefix .. '_trainCache.t7')
37 |
38 | --------------------------------------------------------------------------------------------
39 | local input_nc = opt.input_nc -- input channels
40 | local output_nc = opt.output_nc
41 | local loadSize = {input_nc/3, opt.loadSize}
42 | local sampleSize = {input_nc/3, opt.fineSize}
43 |
44 | local preprocessAandBC = function(imA, imB)
45 | imA = image.scale(imA, loadSize[2], loadSize[2])
46 | imB = image.scale(imB, loadSize[2], loadSize[2])
47 | local perm = torch.LongTensor{3, 2, 1}
48 |
49 | -- imA = imA:index(1, perm)--:mul(256.0): brg, rgb
50 |   -- imA = imA:mul(2):add(-1) -- changed here: range is kept in [0,1] instead of [-1,1]
51 | -- imB = imB:index(1, perm)
52 | -- imB = imB:mul(2):add(-1)
53 | -- imC = imC:index(1, perm)
54 | -- imC = imC:mul(2):add(-1)
55 |
56 | -- assert(imA:max()<=1,"A: badly scaled inputs")
57 | -- assert(imA:min()>=-1,"A: badly scaled inputs")
58 | -- assert(imB:max()<=1,"B: badly scaled inputs")
59 | -- assert(imB:min()>=-1,"B: badly scaled inputs")
60 | -- assert(imC:max()<=1,"C: badly scaled inputs")
61 | -- assert(imC:min()>=-1,"C: badly scaled inputs")
62 |
63 |   -- changed here: the values stay in [0,1] instead of being rescaled to [-1,1]
64 |   imA = imA:index(1, perm) -- RGB -> BGR
65 | -- imA = imA:mul(2):add(-1)
66 | imB = imB:index(1, perm)
67 | -- imB = imB:mul(2):add(-1)
68 |
69 |
70 | assert(imA:max()<=1,"A: badly scaled inputs")
71 | assert(imA:min()>=0,"A: badly scaled inputs")
72 | assert(imB:max()<=1,"B: badly scaled inputs")
73 | assert(imB:min()>=0,"B: badly scaled inputs")
74 |
75 |
76 |
77 | local oW = sampleSize[2]
78 | local oH = sampleSize[2]
79 | local iH = imA:size(2)
80 | local iW = imA:size(3)
81 |
82 |   local h1, w1 = 0, 0 -- crop offsets; remain 0 along an axis that needs no crop
83 |   if iH ~= oH then
84 |     h1 = math.ceil(torch.uniform(1e-2, iH-oH))
85 |   end
86 |
87 |   if iW ~= oW then
88 |     w1 = math.ceil(torch.uniform(1e-2, iW-oW))
89 |   end
90 |   if iH ~= oH or iW ~= oW then
91 |     imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH)
92 |     imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
93 |   end
94 |
95 |
96 |
97 | local flip_flag=0
98 | if opt.flip == 1 and torch.uniform() > 0.5 then --
99 | imA = image.hflip(imA)
100 | imB = image.hflip(imB)
101 | flip_flag=1
102 | end
103 |
104 | return imA, imB,flip_flag
105 | end
106 |
107 |
108 |
109 | local function loadImageChannel(path)
110 | local input = image.load(path, 3, 'float')
111 | input = image.scale(input, loadSize[2], loadSize[2])
112 |
113 | local oW = sampleSize[2]
114 | local oH = sampleSize[2]
115 | local iH = input:size(2)
116 | local iW = input:size(3)
117 |
118 |    local h1, w1 = 0, 0 -- crop offsets; remain 0 when no crop is needed
119 |    if iH ~= oH then
120 |      h1 = math.ceil(torch.uniform(1e-2, iH-oH))
121 |    end
122 |    if iW ~= oW then
123 |      w1 = math.ceil(torch.uniform(1e-2, iW-oW))
124 |    end
125 |    if iH ~= oH or iW ~= oW then
126 |      input = image.crop(input, w1, h1, w1 + oW, h1 + oH)
127 |    end
128 |
129 |
130 | if opt.flip == 1 and torch.uniform() > 0.5 then
131 | input = image.hflip(input)
132 | end
133 |
134 | -- print(input:mean(), input:min(), input:max())
135 | local input_lab = image.rgb2lab(input)
136 | -- print(input_lab:size())
137 | -- os.exit()
138 | local imA = input_lab[{{1}, {}, {} }]:div(50.0) - 1.0
139 | local imB = input_lab[{{2,3},{},{}}]:div(110.0)
140 |
141 | local imAB = torch.cat(imA, imB, 1)
142 | assert(imAB:max()<=1,"A: badly scaled inputs")
143 | assert(imAB:min()>=-1,"A: badly scaled inputs")
144 |
145 | return imAB
146 | end
147 |
148 | --local function loadImage
149 |
150 | local function loadImage(path)
151 | local input = image.load(path, 3, 'float')
152 | local h = input:size(2)
153 | local w = input:size(3)
154 |
155 | local imA = image.crop(input, 0, 0, w/2, h)
156 | local imB = image.crop(input, w/2, 0, w, h)
157 |
158 |
159 | return imA, imB
160 | end
161 |
162 | local function loadImageInpaint(path)
163 | local imB = image.load(path, 3, 'float')
164 | imB = image.scale(imB, loadSize[2], loadSize[2])
165 | local perm = torch.LongTensor{3, 2, 1}
166 |   imB = imB:index(1, perm) -- RGB -> BGR
167 | imB = imB:mul(2):add(-1)
168 | assert(imB:max()<=1,"A: badly scaled inputs")
169 | assert(imB:min()>=-1,"A: badly scaled inputs")
170 | local oW = sampleSize[2]
171 | local oH = sampleSize[2]
172 | local iH = imB:size(2)
173 | local iW = imB:size(3)
174 |   local h1, w1 = 0, 0 -- crop offsets; remain 0 when no crop is needed
175 |   if iH ~= oH then
176 |     h1 = math.ceil(torch.uniform(1e-2, iH-oH))
177 |   end
178 |   if iW ~= oW then
179 |     w1 = math.ceil(torch.uniform(1e-2, iW-oW))
180 |   end
181 |   if iH ~= oH or iW ~= oW then
182 |     imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
183 |   end
184 | local imA = imB:clone()
185 | imA[{{},{1 + oH/4, oH/2 + oH/4},{1 + oW/4, oW/2 + oW/4}}] = 1.0
186 | if opt.flip == 1 and torch.uniform() > 0.5 then
187 | imA = image.hflip(imA)
188 | imB = image.hflip(imB)
189 | end
190 | imAB = torch.cat(imA, imB, 1)
191 | return imAB
192 | end
193 |
194 | -- channel-wise mean and std. Calculate or load them from disk later in the script.
195 | local mean,std
196 | --------------------------------------------------------------------------------
197 | -- Hooks that are used for each image that is loaded
198 |
199 | -- function to load the image, jitter it appropriately (random crops etc.)
200 | local trainHook = function(self, path)
201 | collectgarbage()
202 | local flip_flag
203 | if opt.preprocess == 'regular' then
204 | -- print('process regular')
205 | local imA, imB = loadImage(path)
206 |
207 | imA, imB,flip_flag = preprocessAandBC(imA, imB)
208 |
209 | imAB = torch.cat(imA, imB, 1)
210 |
211 |
212 | --print('image C size')
213 | --print(imAB:size())
214 | end
215 |
216 | if opt.preprocess == 'colorization' then
217 | -- print('process colorization')
218 | imAB = loadImageChannel(path)
219 | end
220 |
221 | if opt.preprocess == 'inpaint' then
222 | -- print('process inpaint')
223 | imAB = loadImageInpaint(path)
224 | end
225 | -- print('image AB size')
226 | -- print(imAB:size())
227 | return imAB,flip_flag
228 | end
229 |
230 | --------------------------------------
231 | -- trainLoader
232 | print('trainCache', trainCache)
233 | --if paths.filep(trainCache) then
234 | -- print('Loading train metadata from cache')
235 | -- trainLoader = torch.load(trainCache)
236 | -- trainLoader.sampleHookTrain = trainHook
237 | -- trainLoader.loadSize = {input_nc, opt.loadSize, opt.loadSize}
238 | -- trainLoader.sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]}
239 | -- trainLoader.serial_batches = opt.serial_batches
240 | -- trainLoader.split = 100
241 | --else
242 | print('Creating train metadata')
243 | -- print(opt.data)
244 | print('serial batch:, ', opt.serial_batches)
245 | trainLoader = dataLoader{
246 | paths = {opt.data},
247 | loadSize = {input_nc, loadSize[2], loadSize[2]},
248 | sampleSize = {input_nc+output_nc, sampleSize[2], sampleSize[2]},
249 | split = 100,
250 | serial_batches = opt.serial_batches,
251 | verbose = true
252 | }
253 | -- print('finish')
254 | --torch.save(trainCache, trainLoader)
255 | --print('saved metadata cache at', trainCache)
256 | trainLoader.sampleHookTrain = trainHook
257 |
258 | --end
259 | collectgarbage()
260 |
261 | -- do some sanity checks on trainLoader
262 | do
263 | local class = trainLoader.imageClass
264 | local nClasses = #trainLoader.classes
265 | assert(class:max() <= nClasses, "class logic has error")
266 | assert(class:min() >= 1, "class logic has error")
267 | end
268 |
--------------------------------------------------------------------------------
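The donkey is normally dofile'd inside a worker thread by dataC.lua, but it can be driven standalone for debugging. A sketch with assumed option values: it expects a global `opt` before being loaded, creates ./cache/, and leaves a global trainLoader behind with trainHook installed as the sampling hook.

```lua
require 'torch'
opt = {DATA_ROOT = './Datasets', phase = 'RealImg', loadSize = 256,
       fineSize = 256, flip = 0, input_nc = 3, output_nc = 3,
       serial_batches = 1, preprocess = 'regular'}
paths.dofile('data2/donkey_folderC.lua')

-- the hook splits a side-by-side pair and keeps BGR values in [0,1]
local imAB, flip_flag = trainLoader:sampleHookTrain('./Datasets/RealImg/1.jpg')
print(imAB:size()) -- 6 x 256 x 256
print(flip_flag)   -- 0, since flip is disabled above
```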
/imgs/IMDb/1_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/1_1.jpg
--------------------------------------------------------------------------------
/imgs/IMDb/1_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/1_2.jpg
--------------------------------------------------------------------------------
/imgs/IMDb/1_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/1_3.jpg
--------------------------------------------------------------------------------
/imgs/IMDb/1_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/1_4.jpg
--------------------------------------------------------------------------------
/imgs/IMDb/2_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/2_1.jpg
--------------------------------------------------------------------------------
/imgs/IMDb/2_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/2_2.jpg
--------------------------------------------------------------------------------
/imgs/IMDb/2_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/2_3.jpg
--------------------------------------------------------------------------------
/imgs/IMDb/2_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/IMDb/2_4.jpg
--------------------------------------------------------------------------------
/imgs/architecture/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/architecture/pipeline.jpg
--------------------------------------------------------------------------------
/imgs/realresults/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/realresults/1.jpg
--------------------------------------------------------------------------------
/imgs/warpface/warp.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csxmli2016/GFRNet/d76a096a324653862ee0ef74cf621d50e8d5b431/imgs/warpface/warp.jpg
--------------------------------------------------------------------------------
/test.lua:
--------------------------------------------------------------------------------
1 |
2 | require 'torch'
3 | require 'nn'
4 | require 'cudnn'
5 | require 'cunn'
6 | require 'nngraph'
7 | require 'optim'
8 | util = paths.dofile('util/util.lua')
9 | require 'image'
10 | require 'stn'
11 |
12 |
13 | opt = {
14 | DATA_ROOT = './Datasets', --DataRoot
15 | batchSize = 1, -- # images in batch
16 | loadSize = 256, -- scale images to this size
17 | fineSize = 256, -- then crop to this size
18 | flip=0, -- horizontal mirroring data augmentation
19 | gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X (cpu untested)
20 | which_direction = 'AtoB', -- AtoB or BtoA
21 | phase = 'RealImg', -- test dataset name
22 | preprocess = 'regular', -- for special purpose preprocessing, e.g., for colorization, change this (selects preprocessing functions in util.lua)
23 | aspect_ratio = 1.0, -- aspect ratio of result images
24 |   name = 'FaceRestoration', -- name of the experiment; selects which model to run and should generally be passed on the command line
25 | input_nc = 3, -- # of input image channels
26 | output_nc = 3, -- # of output image channels
27 | serial_batches = 1, -- if 1, takes images in order to make batches, otherwise takes them randomly
28 | serial_batch_iter = 1, -- iter into serial image list
29 | cudnn = 0, -- set to 0 to not use cudnn (untested)
30 | checkpoints_dir = './checkpoints', -- loads models from here
31 | results_dir='./results', -- saves results here
32 |   which_model = 'netG', -- which saved model to load: checkpoints/<name>/<which_model>.t7
33 |
34 | }
35 |
36 |
37 | -- one-line argument parser. parses environment variables to override the defaults
38 | for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
39 | opt.nThreads = 1 -- test only works with 1 thread...
40 | print(opt)
41 |
42 |
43 | opt.manualSeed = torch.random(1, 10000) -- set seed
44 | print("Random Seed: " .. opt.manualSeed)
45 | torch.manualSeed(opt.manualSeed)
46 | torch.setdefaulttensortype('torch.FloatTensor')
47 |
48 | opt.netG_name = opt.name .. '/' .. opt.which_model
49 | local data_loader = paths.dofile('data2/dataC.lua')
50 | print('#threads...' .. opt.nThreads)
51 | local data = data_loader.new(opt.nThreads, opt)
52 | print("Dataset Size: ", data:size())
53 | opt.how_many=data:size()
54 | ---------------------------------------------------------------------------------------------------
55 |
56 | local input = torch.FloatTensor(opt.batchSize,6,opt.fineSize,opt.fineSize)
57 | local output = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
58 | local guidance = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
59 | local outputface = torch.FloatTensor(opt.batchSize,3,opt.fineSize,opt.fineSize)
60 |
61 | local netG = util.load(paths.concat(opt.checkpoints_dir, opt.netG_name .. '.t7'), opt):cuda()
62 |
63 | netG:apply(printNet) -- print the type of each network module
64 | local filepaths = {} -- paths to images tested on
65 | local filenames ={}
66 | function TableConcat(t1,t2)
67 | for i=1,#t2 do
68 | t1[#t1+1] = t2[i]
69 | end
70 | return t1
71 | end
72 |
73 |
74 | for n=1,math.floor(opt.how_many/opt.batchSize) do
75 | print('processing batch ' .. n)
76 |
77 | local real_data, filepaths_curr, flips = data:getBatch()
78 | local imgname2 = filepaths_curr[1]
79 |
80 | filepaths_curr = util.basename_batch(filepaths_curr)
81 | print('filepaths_curr: ', filepaths_curr)
82 |
83 | real_A = real_data[{ {}, {1,3}, {}, {} }]:clone() -- degraded (blurred) input
84 | real_B = real_data[{ {}, {4,6}, {}, {} }]:clone() -- guided image
85 |
86 | local outputss = netG:forward({real_A:cuda(),real_B:cuda()})
87 | real_WC=outputss[1]:clone()--warped guidance
88 | fake_gout=outputss[3]:clone()--restoration result
89 |
90 | input = util.deprocess_batch(real_A):float()
91 | outputwarp = util.deprocess_batch(real_WC):float()
92 | guidance = util.deprocess_batch(real_B):float()
93 | outputface = util.deprocess_batch(fake_gout):float()
94 |
95 |
96 | -- save images
97 | paths.mkdir(paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase))
98 | local image_dir = paths.concat(opt.results_dir, opt.netG_name .. '_' .. opt.phase, 'images')
99 | paths.mkdir(image_dir)
100 | paths.mkdir(paths.concat(image_dir,'Input'))
101 | paths.mkdir(paths.concat(image_dir,'WarpGuidance'))
102 | paths.mkdir(paths.concat(image_dir,'Guidance'))
103 | paths.mkdir(paths.concat(image_dir,'Output'))
104 |
105 | for i=1, opt.batchSize do
106 | image.save(paths.concat(image_dir,'Input',filepaths_curr[i]), image.scale(input[i],input[i]:size(2),input[i]:size(3)/opt.aspect_ratio))
107 | image.save(paths.concat(image_dir,'WarpGuidance',filepaths_curr[i]), image.scale(outputwarp[i],outputwarp[i]:size(2),outputwarp[i]:size(3)/opt.aspect_ratio))
108 | image.save(paths.concat(image_dir,'Guidance',filepaths_curr[i]), image.scale(guidance[i],guidance[i]:size(2),guidance[i]:size(3)/opt.aspect_ratio))
109 | image.save(paths.concat(image_dir,'Output',filepaths_curr[i]), image.scale(outputface[i],outputface[i]:size(2),outputface[i]:size(3)/opt.aspect_ratio))
110 | end
111 | print('Saved images to: ', image_dir)
112 | BB=string.split(imgname2,'/')
113 | local iename = BB[#BB]
114 | local ImgName={}
115 | ImgName[1]=iename
116 |
117 | filepaths = TableConcat(filepaths, filepaths_curr)
118 | filenames = TableConcat(filenames, ImgName)
119 | end
120 |
121 | -- make webpage
122 | io.output(paths.concat(opt.results_dir,opt.netG_name .. '_' .. opt.phase, 'index.html'))
123 | io.write('<table style="text-align:center;">')
124 | io.write('<tr><td>Image #</td><td>Input</td><td>Guidance</td><td>Warped Guidance</td><td>Output</td></tr>')
125 | for i=1, #filepaths do
126 |   io.write('<tr>')
127 |   io.write('<td>' .. filenames[i] .. '</td>')
128 |   io.write('<td><img src="./images/Input/' .. filepaths[i] .. '"/></td>')
129 |   io.write('<td><img src="./images/Guidance/' .. filepaths[i] .. '"/></td>')
130 |   io.write('<td><img src="./images/WarpGuidance/' .. filepaths[i] .. '"/></td>')
131 |   io.write('<td><img src="./images/Output/' .. filepaths[i] .. '"/></td>')
132 |   io.write('</tr>')
133 | end
134 | io.write('</table>')
135 |
--------------------------------------------------------------------------------
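Because of the one-line parser near the top of test.lua, every key in `opt` doubles as an environment variable: numeric strings are converted with tonumber first, everything else is kept as a string. A tiny sketch of that precedence (the table here is illustrative, not part of the repo):

```lua
local defaults = {batchSize = 1, phase = 'RealImg', DATA_ROOT = './Datasets'}
for k, v in pairs(defaults) do
   defaults[k] = tonumber(os.getenv(k)) or os.getenv(k) or defaults[k]
end
-- so `DATA_ROOT=./MyData phase=RealImg batchSize=2 th test.lua` overrides
-- those fields before the data loader and model are constructed
print(defaults.batchSize, defaults.phase, defaults.DATA_ROOT)
```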
/util/cudnn_convert_custom.lua:
--------------------------------------------------------------------------------
1 | -- modified from https://github.com/NVIDIA/torch-cudnn/blob/master/convert.lua
2 | -- removed error on nngraph
3 |
4 | -- modules that can be converted to nn seamlessly
5 | local layer_list = {
6 | 'BatchNormalization',
7 | 'SpatialBatchNormalization',
8 | 'SpatialConvolution',
9 | 'SpatialCrossMapLRN',
10 | 'SpatialFullConvolution',
11 | 'SpatialMaxPooling',
12 | 'SpatialAveragePooling',
13 | 'ReLU',
14 | 'Tanh',
15 | 'Sigmoid',
16 | 'SoftMax',
17 | 'LogSoftMax',
18 | 'VolumetricBatchNormalization',
19 | 'VolumetricConvolution',
20 | 'VolumetricFullConvolution',
21 | 'VolumetricMaxPooling',
22 | 'VolumetricAveragePooling',
23 | }
24 |
25 | -- goes over a given net and converts all layers to dst backend
26 | -- for example: net = cudnn_convert_custom(net, cudnn)
27 | -- same as cudnn.convert with gModule check commented out
28 | function cudnn_convert_custom(net, dst, exclusion_fn)
29 | return net:replace(function(x)
30 | --if torch.type(x) == 'nn.gModule' then
31 | -- io.stderr:write('Warning: cudnn.convert does not work with nngraph yet. Ignoring nn.gModule')
32 | -- return x
33 | --end
34 | local y = 0
35 | local src = dst == nn and cudnn or nn
36 | local src_prefix = src == nn and 'nn.' or 'cudnn.'
37 | local dst_prefix = dst == nn and 'nn.' or 'cudnn.'
38 |
39 | local function convert(v)
40 | local y = {}
41 | torch.setmetatable(y, dst_prefix..v)
42 | if v == 'ReLU' then y = dst.ReLU() end -- because parameters
43 | for k,u in pairs(x) do y[k] = u end
44 | if src == cudnn and x.clearDesc then x.clearDesc(y) end
45 | if src == cudnn and v == 'SpatialAveragePooling' then
46 | y.divide = true
47 |         y.count_include_pad = x.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING' -- read the mode from the module x; v is just the name string
48 | end
49 | return y
50 | end
51 |
52 | if exclusion_fn and exclusion_fn(x) then
53 | return x
54 | end
55 | local t = torch.typename(x)
56 | if t == 'nn.SpatialConvolutionMM' then
57 | y = convert('SpatialConvolution')
58 | elseif t == 'inn.SpatialCrossResponseNormalization' then
59 | y = convert('SpatialCrossMapLRN')
60 | else
61 | for i,v in ipairs(layer_list) do
62 | if torch.typename(x) == src_prefix..v then
63 | y = convert(v)
64 | end
65 | end
66 | end
67 | return y == 0 and x or y
68 | end)
69 | end
--------------------------------------------------------------------------------
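Typical use of the helper above, as a sketch (the checkpoint path is taken from the README; note the stn layers are CUDA kernels, so a fully CPU-only run may still require replacing them): convert a cudnn-trained net, nngraph gModules included, to the plain nn backend.

```lua
require 'torch'
require 'nn'
require 'cudnn'
require 'nngraph'
paths.dofile('util/cudnn_convert_custom.lua')

local net = torch.load('./checkpoints/FaceRestoration/netG.t7')
net = cudnn_convert_custom(net, nn) -- dst backend: plain nn
net:float()                         -- move the converted weights to the CPU
```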
/util/util.lua:
--------------------------------------------------------------------------------
1 | --
2 | -- code derived from https://github.com/soumith/dcgan.torch
3 | --
4 |
5 | local util = {}
6 |
7 | require 'torch'
8 | -- container-walking printer; note it is shadowed by the per-module printNet
9 | -- defined at the bottom of this file (the one test.lua uses with net:apply)
10 | function printNet(net)
11 |   for i = 1, net:size() do
12 |     print(string.format("%d: %s", i, net.modules[i]))
13 |   end
14 | end
15 | function util.normalize(img)
16 | -- rescale image to 0 .. 1
17 | local min = img:min()
18 | local max = img:max()
19 |
20 | img = torch.FloatTensor(img:size()):copy(img)
21 | img:add(-min):mul(1/(max-min))
22 | return img
23 | end
24 |
25 | function util.normalizeBatch(batch)
26 | for i = 1, batch:size(1) do
27 | batch[i] = util.normalize(batch[i]:squeeze())
28 | end
29 | return batch
30 | end
31 |
32 | function util.basename_batch(batch)
33 | for i = 1, #batch do
34 | batch[i] = paths.basename(batch[i])
35 | end
36 | return batch
37 | end
38 |
39 |
40 |
41 | -- default preprocessing
42 | --
43 | -- Preprocesses an image before passing it to a net
44 | -- Converts from RGB to BGR and rescales from [0,1] to [-1,1]
45 | function util.preprocess(img)
46 | -- RGB to BGR
47 | local perm = torch.LongTensor{3, 2, 1}
48 | img = img:index(1, perm)
49 |
50 | -- -- [0,1] to [-1,1]
51 | -- img = img:mul(2):add(-1)
52 |
53 | -- -- check that input is in expected range
54 | -- assert(img:max()<=1,"badly scaled inputs")
55 | -- assert(img:min()>=-1,"badly scaled inputs")
56 |
57 | return img
58 | end
59 |
60 | -- Undo the above preprocessing.
61 | function util.deprocess(img)
62 | -- BGR to RGB
63 | local perm = torch.LongTensor{3, 2, 1}
64 | img = img:index(1, perm)
65 |
66 | -- [-1,1] to [0,1]
67 |    -- changed here too: values stay in [0,1], so no rescaling is needed
68 | -- img = img:add(1):div(2)
69 |
70 | return img
71 | end
72 | function util.deprocess_vgg(img)
73 | -- BGR to RGB
74 | -- [-1,1] to [0,1]
75 |
76 | -- img = img:add(1):div(2)
77 |
78 | return img
79 | end
80 | function util.preprocess_batch(batch)
81 | for i = 1, batch:size(1) do
82 | batch[i] = util.preprocess(batch[i]:squeeze())
83 | end
84 | return batch
85 | end
86 |
87 | function util.deprocess_batch(batch)
88 | for i = 1, batch:size(1) do
89 | batch[i] = util.deprocess(batch[i]:squeeze())
90 | end
91 | return batch
92 | end
93 |
94 |
95 | function util.deprocess_batch_vgg(batch)
96 | for i = 1, batch:size(1) do
97 | batch[i] = util.deprocess_vgg(batch[i]:squeeze())
98 | end
99 | return batch
100 | end
101 | -- preprocessing specific to colorization
102 |
103 | function util.deprocessLAB(L, AB)
104 | local L2 = torch.Tensor(L:size()):copy(L)
105 | if L2:dim() == 3 then
106 | L2 = L2[{1, {}, {} }]
107 | end
108 | local AB2 = torch.Tensor(AB:size()):copy(AB)
109 | AB2 = torch.clamp(AB2, -1.0, 1.0)
110 | -- local AB2 = AB
111 | L2 = L2:add(1):mul(50.0)
112 | AB2 = AB2:mul(110.0)
113 |
114 | L2 = L2:reshape(1, L2:size(1), L2:size(2))
115 |
116 | im_lab = torch.cat(L2, AB2, 1)
117 | im_rgb = torch.clamp(image.lab2rgb(im_lab):mul(255.0), 0.0, 255.0)/255.0
118 |
119 | return im_rgb
120 | end
121 |
122 | function util.deprocessL(L)
123 | local L2 = torch.Tensor(L:size()):copy(L)
124 | L2 = L2:add(1):mul(255.0/2.0)
125 |
126 | if L2:dim()==2 then
127 | L2 = L2:reshape(1,L2:size(1),L2:size(2))
128 | end
129 | L2 = L2:repeatTensor(L2,3,1,1)/255.0
130 |
131 | return L2
132 | end
133 |
134 | function util.deprocessL_batch(batch)
135 | local batch_new = {}
136 | for i = 1, batch:size(1) do
137 | batch_new[i] = util.deprocessL(batch[i]:squeeze())
138 | end
139 | return batch_new
140 | end
141 |
142 | function util.deprocessLAB_batch(batchL, batchAB)
143 | local batch = {}
144 |
145 | for i = 1, batchL:size(1) do
146 | batch[i] = util.deprocessLAB(batchL[i]:squeeze(), batchAB[i]:squeeze())
147 | end
148 |
149 | return batch
150 | end
151 |
152 |
153 | function util.scaleBatch(batch,s1,s2)
154 | local scaled_batch = torch.Tensor(batch:size(1),batch:size(2),s1,s2)
155 | for i = 1, batch:size(1) do
156 | scaled_batch[i] = image.scale(batch[i],s1,s2):squeeze()
157 | end
158 | return scaled_batch
159 | end
160 |
161 |
162 |
163 | function util.toTrivialBatch(input)
164 | return input:reshape(1,input:size(1),input:size(2),input:size(3))
165 | end
166 | function util.fromTrivialBatch(input)
167 | return input[1]
168 | end
169 |
170 |
171 |
172 | function util.scaleImage(input, loadSize)
173 | -- replicate bw images to 3 channels
174 | if input:size(1)==1 then
175 | input = torch.repeatTensor(input,3,1,1)
176 | end
177 |
178 | input = image.scale(input, loadSize, loadSize)
179 |
180 | return input
181 | end
182 |
183 | function util.getAspectRatio(path)
184 | local input = image.load(path, 3, 'float')
185 | local ar = input:size(3)/input:size(2)
186 | return ar
187 | end
188 |
189 | function util.loadImage(path, loadSize, nc)
190 | local input = image.load(path, 3, 'float')
191 | input= util.preprocess(util.scaleImage(input, loadSize))
192 |
193 | if nc == 1 then
194 | input = input[{{1}, {}, {}}]
195 | end
196 |
197 | return input
198 | end
199 |
200 |
201 |
202 | -- TO DO: loading code is rather hacky; clean it up and make sure it works on all types of nets / cpu/gpu configurations
203 | function util.load(filename, opt)
204 | if opt.cudnn>0 then
205 | require 'cudnn'
206 | end
207 |
208 | if opt.gpu > 0 then
209 | require 'cunn'
210 | end
211 |
212 | local net = torch.load(filename)
213 |
214 | if opt.gpu > 0 then
215 | net:cuda()
216 |
217 | -- calling cuda on cudnn saved nngraphs doesn't change all variables to cuda, so do it below
218 | if net.forwardnodes then
219 | for i=1,#net.forwardnodes do
220 | if net.forwardnodes[i].data.module then
221 | net.forwardnodes[i].data.module:cuda()
222 | end
223 | end
224 | end
225 | else
226 | net:float()
227 | end
228 | net:apply(function(m) if m.weight then
229 | m.gradWeight = m.weight:clone():zero();
230 | m.gradBias = m.bias:clone():zero(); end end)
231 | return net
232 | end
233 |
234 | function util.cudnn(net)
235 | require 'cudnn'
236 | require 'util/cudnn_convert_custom'
237 | return cudnn_convert_custom(net, cudnn)
238 | end
239 |
240 | function util.containsValue(table, value)
241 | for k, v in pairs(table) do
242 | if v == value then return true end
243 | end
244 | return false
245 | end
246 |
247 | function printNet(m)
248 | local name = torch.type(m)
249 | print(name)
250 | end
251 |
252 | return util
253 |
--------------------------------------------------------------------------------
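Since this fork keeps images in [0,1] and util.preprocess/util.deprocess only swap the channel order, the two functions are exact inverses of each other (the permutation {3,2,1} is an involution). A quick self-check sketch:

```lua
require 'torch'
local util = paths.dofile('util/util.lua')

local img  = torch.rand(3, 8, 8)
local back = util.deprocess(util.preprocess(img:clone()))
print((back - img):abs():max()) -- prints 0
```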