├── .idea
│   ├── Image-Background-Generator.iml
│   ├── markdown-navigator.xml
│   ├── markdown-navigator
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── Inference.py
├── README.md
├── __pycache__
│   └── config.cpython-36.pyc
├── config.py
├── model
│   ├── __init__.py
│   ├── verifier.py
│   └── verifier_base.py
├── resource
│   ├── CV_paper.pdf
│   ├── Places365_val_00034821 change(1).jpg
│   ├── color150.mat
│   ├── demo1.jpg
│   ├── demo2.png
│   ├── moving_sequence1.jpg
│   ├── moving_sequence2.jpg
│   ├── pipeline.jpg
│   ├── result1.png
│   ├── test_1.png
│   └── test_2.png
├── train.py
└── utils
    ├── Batcher.py
    ├── __init__.py
    ├── __pycache__
    │   ├── Batcher.cpython-36.pyc
    │   ├── __init__.cpython-36.pyc
    │   ├── batcher.cpython-36.pyc
    │   ├── logger.cpython-36.pyc
    │   └── util.cpython-36.pyc
    ├── batcher.py
    ├── logger.py
    └── util.py
/.idea/:
--------------------------------------------------------------------------------
(IDE configuration files; contents not captured in this dump)
--------------------------------------------------------------------------------
/Inference.py:
--------------------------------------------------------------------------------
1 | import random
2 | from model.verifier_base import VerifierBase
3 | import torch
4 | from config import set_args
5 | from torch.autograd import Variable
6 | from utils.util import *
7 | import time
8 | import argparse
9 |
10 | """
11 | Description:
12 |
13 | This script implements the gradient-based foreground adjustment algorithm.
14 | The (x, y, scale) of the foreground object is adjusted iteratively under the guidance of the model's scores.
15 | """
16 |
17 | # ========================== Constants =====================
18 | parser = argparse.ArgumentParser(description='Inference Phase')
19 | time = time.gmtime()  # note: shadows the time module with a run timestamp
20 | time = "-".join([str(p) for p in list(time)[:5]])
21 | config = set_args()
22 | test_fg = []
23 |
24 | SAMPLE_NUM = config['sample_num']
25 | ROUND = config['update_rd']
26 | TOPK = config['top_k']
27 |
28 | start_x = 0
29 | start_y = 0
30 | # 12 candidate updates: four (dx, dy) directions, each at scale factors 1.0, 0.95 and 1.05
31 | fx = [[-1, 0, 1], [1, 0, 1], [0, -1, 1], [0, 1, 1],
32 |       [-1, 0, 0.95], [1, 0, 0.95], [0, -1, 0.95], [0, 1, 0.95],
33 |       [-1, 0, 1.05], [1, 0, 1.05], [0, -1, 1.05], [0, 1, 1.05]]
34 |
35 | # ======================== loading ckpt ================== #
36 | ckpt = os.path.join("checkpoints", "ckpt_2_epoch_1:2:1_Regression_sigmoid_shuffle_score_debug.pth")
37 | scene_parsing_folder_name = 'background_gallery_sp'
38 | model_pred = VerifierBase(config)
39 | # model_pred = Verifier(config)
40 | model_pred.cuda()
41 | model_pred.load_state_dict(torch.load(ckpt))
42 | model_pred.eval()
43 |
44 |
45 | def patch(v):
46 |     v = Variable(v.cuda())
47 |     return v
48 |
49 |
50 | def f(background, foreground, scene_parsing):
51 |     """Score one (background, foreground, scene parsing) triple with the verifier."""
52 |     # -- TODO -- #
53 |     colors = loadmat('resource/color150.mat')['colors']
54 |     scene_parsing = colorEncode(scene_parsing, colors)
55 |     batch = dict()
56 |     batch['BGD'] = patch(torch.FloatTensor(background[:, :, :3].copy().transpose(2, 0, 1)).unsqueeze(0))
57 |     batch['FGD'] = patch(torch.FloatTensor(foreground[:, :, :3].copy().transpose(2, 0, 1)).unsqueeze(0))
58 |     batch['SPS'] = patch(torch.FloatTensor(scene_parsing[:, :, :3].copy().transpose(2, 0, 1)).unsqueeze(0))
59 |
60 |     y1_pred, y2_pred = model_pred(batch)
61 |     picture_match_score = y1_pred.detach().cpu().numpy()[..., 0]
62 |     location_match_score = y2_pred.detach().cpu().numpy()[..., 0]
63 |     print(picture_match_score[0], location_match_score[0])
64 |     return [picture_match_score[0], location_match_score[0]]
65 |
66 |
67 | def cvt2RGBA(img):
68 |     _, _, channel = img.shape
69 |     if channel == 4:
70 |         return img
71 |     if channel == 1:
72 |         return cv2.cvtColor(img, cv2.COLOR_GRAY2RGBA)
73 |     else:
74 |         return cv2.cvtColor(img, cv2.COLOR_RGB2RGBA)
75 |
76 |
77 | black_canvas = np.zeros((256, 256, 3), np.uint8)
78 | black_canvas = np.concatenate([black_canvas, np.ones((256, 256, 1)) * 255.0], axis=2)
79 |
80 |
81 | def paste(target, source, pos=(0, 0)):
82 |     """Alpha-composite `source` onto `target` with its top-left corner at `pos`, clipped to the canvas."""
83 |     left_up_x, left_up_y = pos
84 |     bg_height, bg_width = target[:, :, 0].shape
85 |     fg_height, fg_width = source[:, :, 0].shape
86 |     result = target.copy()
87 |     target_x_start = max(left_up_x, 0)
88 |     target_x_end = min(left_up_x + fg_height, bg_height)
89 |     target_y_start = max(left_up_y, 0)
90 |     target_y_end = min(left_up_y + fg_width, bg_width)
91 |     source_x_start = max(0, -left_up_x)
92 |     source_x_end = min(bg_height - left_up_x, fg_height)
93 |     source_y_start = max(0, -left_up_y)
94 |     source_y_end = min(bg_width - left_up_y, fg_width)
95 |     fg = source[source_x_start:source_x_end, source_y_start:source_y_end, :]
96 |     bg = result[target_x_start:target_x_end, target_y_start:target_y_end, :]
97 |     mask = fg[:, :, 3]
98 |     mask_inv = cv2.bitwise_not(mask)
99 |     bg = cv2.bitwise_and(bg, bg, mask=mask_inv)
100 |     fg = cv2.bitwise_and(fg, fg, mask=mask)
101 |     result[target_x_start:target_x_end,
102 |            target_y_start:target_y_end, :] = cv2.add(fg, bg)
103 |     return result
104 |
105 |
106 | def change(source, delta_x, delta_y, slope):
107 |     """Crop the foreground to its alpha bounding box, rescale by `slope`, and re-place it at (delta_x, delta_y)."""
108 |     alpha = source[:, :, 3]
109 |     fg_height, fg_width = alpha.shape
110 |     x0 = 0
111 |     y0 = 0
112 |     x1 = fg_height
113 |     y1 = fg_width
114 |     # locate the bounding box of the non-transparent region
115 |     for i in range(fg_height):
116 |         if np.sum(alpha[i, :]) != 0:
117 |             x0 = i
118 |             break
119 |     for i in range(fg_height, x0, -1):
120 |         if np.sum(alpha[i - 1, :]) != 0:
121 |             x1 = i
122 |             break
123 |     for i in range(fg_width):
124 |         if np.sum(alpha[:, i]) != 0:
125 |             y0 = i
126 |             break
127 |     for i in range(fg_width, y0, -1):
128 |         if np.sum(alpha[:, i - 1]) != 0:
129 |             y1 = i
130 |             break
131 |     fg = source[x0:x1, y0:y1, :]
132 |     new_fg = cv2.resize(fg, None, fx=slope, fy=slope)
133 |     result = np.zeros(source.shape, np.uint8)
134 |     result = paste(result, new_fg, (int(delta_x), int(delta_y)))
135 |     return result
136 |
137 |
138 | fg = cv2.imread(config['test_img'], -1)
139 | fg = cvt2RGBA(fg)
140 |
141 | rootpath = "/newNAS/Share/ykli"  # os.getcwd()
142 | gallery_dir = "background_gallery"
143 | os.makedirs(f'result/{time}', exist_ok=True)
144 | picture_list = os.listdir(f'{rootpath}/{gallery_dir}/')
145 | chosen_pictures = random.sample(picture_list, SAMPLE_NUM)
146 |
147 | pic_scores = []
148 | for picture_name in chosen_pictures:
149 |     bg = cv2.imread(f'{rootpath}/{gallery_dir}/{picture_name}', -1)
150 |     bg = cvt2RGBA(bg)
151 |     with open(f'{rootpath}/{scene_parsing_folder_name}/{picture_name[0:-4]}.sg.pkl', 'rb') as fr:
152 |         sp = pickle.load(fr)
153 |     # try_pic = paste(bg, fg, start_x, start_y)
154 |     # pic_scores.append(f(bg, fg, sp)[0])
155 |     sc = f(bg, fg, sp)
156 |     pic_scores.append(sc[0])
157 |
158 | sorted_pic_scores = sorted(pic_scores)
159 | # print(sorted_pic_scores)
160 | threshold_score = sorted_pic_scores[TOPK - 1]             # upper bound of the TOPK lowest scores
161 | threshold_score_2 = sorted_pic_scores[SAMPLE_NUM - TOPK]  # lower bound of the TOPK highest scores
162 | to_test_pictures = []
163 | to_diss_pictures = []
164 | for i in range(SAMPLE_NUM):
165 |     if pic_scores[i] <= threshold_score:
166 |         to_test_pictures.append(i)
167 |     if pic_scores[i] >= threshold_score_2:
168 |         to_diss_pictures.append(i)
169 |
170 |
171 | # BAD CASES: the TOPK lowest-scoring backgrounds
172 | for i_pic in range(TOPK):
173 |     print("ipc", i_pic)
174 |     picture_name = chosen_pictures[to_test_pictures[i_pic]]
175 |     picture_score = pic_scores[to_test_pictures[i_pic]]
176 |
177 |     os.mkdir(f'result/{time}/{picture_score}_{picture_name[0:-4]}')
178 |     bg = cv2.imread(f'{rootpath}/{gallery_dir}/{picture_name}', -1)
179 |     bg = cvt2RGBA(bg)
180 |     bg_height, bg_width = bg[:, :, 0].shape
181 |     mv_height = bg_height / 20
182 |     mv_width = bg_width / 20
183 |     with open(f'{rootpath}/{scene_parsing_folder_name}/{picture_name[0:-4]}.sg.pkl', 'rb') as fr:
184 |         sp = pickle.load(fr)
185 |     current_x = start_x
186 |     current_y = start_y
187 |     current_s = 1
188 |     for iter_g in range(ROUND):
189 |         tmp_pic_scores = []
190 |         for i_fx in range(12):
191 |             tmp_fg = change(fg, current_x + fx[i_fx][0] * mv_height, current_y + fx[i_fx][1] * mv_width, current_s * fx[i_fx][2])
192 |             # try_pics = paste(bg, tmp_fg, start_x, start_y)
193 |             tmp_pic_scores.append(f(bg, tmp_fg, sp)[1])
194 |         # greedily take the candidate move with the best location score
195 |         max_index = tmp_pic_scores.index(max(tmp_pic_scores))
196 |         current_x += fx[max_index][0] * mv_height
197 |         current_y += fx[max_index][1] * mv_width
198 |         current_s *= fx[max_index][2]
199 |         mid_fg = change(fg, current_x, current_y, current_s)
200 |         mid_result = paste(bg, mid_fg, (start_x, start_y))
201 |         max_score = max(tmp_pic_scores)
202 |         cv2.imwrite(f'./result/{time}/{picture_score}_{picture_name[0:-4]}/{iter_g}_{max_score}_{fx[max_index][0]}_{fx[max_index][1]}_{fx[max_index][2]}.png', mid_result)
203 |     # final_fg = change(fg, current_x, current_y, current_s)
204 |     # result = paste(bg, final_fg, start_x, start_y)
205 |     # cv2.imwrite(f'./{i_pic}_{max_score}.png', result)
206 |
207 |
208 | # the TOPK highest-scoring backgrounds
209 | for i_pic in range(TOPK):
210 |     print("ipc", i_pic)
211 |     picture_name = chosen_pictures[to_diss_pictures[i_pic]]
212 |     picture_score = pic_scores[to_diss_pictures[i_pic]]
213 |     os.mkdir(f'result/{time}/{picture_score}_{picture_name[0:-4]}')
214 |     bg = cv2.imread(f'{rootpath}/{gallery_dir}/{picture_name}', -1)
215 |     bg = cvt2RGBA(bg)
216 |     bg_height, bg_width = bg[:, :, 0].shape
217 |     mv_height = bg_height / 20
218 |     mv_width = bg_width / 20
219 |     with open(f'{rootpath}/{scene_parsing_folder_name}/{picture_name[0:-4]}.sg.pkl', 'rb') as fr:
220 |         sp = pickle.load(fr)
221 |     current_x = start_x
222 |     current_y = start_y
223 |     current_s = 1
224 |     for iter_g in range(ROUND):
225 |         tmp_pic_scores = []
226 |         for i_fx in range(12):
227 |             tmp_fg = change(fg, current_x + fx[i_fx][0] * mv_height, current_y + fx[i_fx][1] * mv_width, current_s * fx[i_fx][2])
228 |             # try_pics = paste(bg, tmp_fg, start_x, start_y)
229 |             tmp_pic_scores.append(f(bg, tmp_fg, sp)[1])
230 |         max_index = tmp_pic_scores.index(max(tmp_pic_scores))
231 |         current_x += fx[max_index][0] * mv_height
232 |         current_y += fx[max_index][1] * mv_width
233 |         current_s *= fx[max_index][2]
234 |         mid_fg = change(fg, current_x, current_y, current_s)
235 |         mid_result = paste(bg, mid_fg, (start_x, start_y))
236 |         max_score = max(tmp_pic_scores)
237 |         cv2.imwrite(f'./result/{time}/{picture_score}_{picture_name[0:-4]}/{iter_g}_{max_score}_{fx[max_index][0]}_{fx[max_index][1]}_{fx[max_index][2]}.png', mid_result)
238 |
--------------------------------------------------------------------------------
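The search that Inference.py performs is easiest to see in isolation. Below is a minimal, hedged sketch of the same greedy loop: `score_fn` is a hypothetical stand-in for the verifier's location score (`f(bg, fg, sp)[1]` above), and the move table and step size of 1/20 of the image mirror the script.

```python
# Sketch of the greedy (x, y, scale) adjustment loop used in Inference.py.
# `score_fn` is a hypothetical stand-in for the verifier's location score.

def adjust_foreground(score_fn, rounds=10, step=256 / 20):
    # 12 candidate updates: four (dx, dy) directions x three scale factors
    fx = [[-1, 0, 1], [1, 0, 1], [0, -1, 1], [0, 1, 1],
          [-1, 0, 0.95], [1, 0, 0.95], [0, -1, 0.95], [0, 1, 0.95],
          [-1, 0, 1.05], [1, 0, 1.05], [0, -1, 1.05], [0, 1, 1.05]]
    x, y, s = 0.0, 0.0, 1.0
    for _ in range(rounds):
        scores = [score_fn(x + dx * step, y + dy * step, s * ds)
                  for dx, dy, ds in fx]
        dx, dy, ds = fx[scores.index(max(scores))]   # greedily take the best move
        x, y, s = x + dx * step, y + dy * step, s * ds
    return x, y, s


# Toy score peaked at (64, -38, 1.2); the loop walks toward it.
print(adjust_foreground(lambda x, y, s: -abs(x - 64) - abs(y + 38) - 100 * abs(s - 1.2)))
```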
/README.md:
--------------------------------------------------------------------------------
1 | # Auto-Retoucher (ART): A Framework for Background Replacement and Foreground Adjustment
2 | Given a person's photo, ART generates a new image with the best-matching background and finds the best spatial location and scale for the foreground.
3 |
4 | A PyTorch implementation of the ART framework. Our preprint has been uploaded to arXiv: http://arxiv.org/abs/1901.03954
5 |
6 | For more information, please visit our [website](https://suyang98.github.io/Auto-Retoucher/). There is also a video [demo](https://v.qq.com/x/page/o082798ic8d.html).
7 |
8 | ## Abstract
9 | Replacing the background of an image while simultaneously adjusting its foreground objects is a challenging task in image editing. Current techniques for generating such images rely heavily on user interaction with image editing software, which is a tedious job for professional retouchers. Some exciting progress in image editing has been made to ease their workload. However, few models focus on guaranteeing semantic consistency between the foreground and background. To solve this problem, we propose ART (Auto-Retoucher), a framework that generates images with sufficient semantic and spatial consistency from a given image. Inputs are first processed by semantic matting and scene parsing modules; then a multi-task verifier model gives two confidence scores for the current content matching and foreground location. We demonstrate that our jointly optimized verifier model successfully guides the foreground adjustment and improves the global visual consistency.
10 |
11 | ### Example foreground images:
12 |
13 | 
14 | 
15 |
16 | ### Output Images:
17 |
18 | The backgrounds are selected from the gallery with the best content-level consistency.
19 |
20 | 
21 | 
22 |
23 | ### Moving sequence
24 | The adjustment procedure is guided by the model's gradient.
25 |
26 | The foreground moves from a random initial location to a plausible position.
27 |
28 | 
29 | 
30 |
31 | ## Requirements
32 | ```
33 | pytorch=0.4.1
34 | tensorboardX
35 | tqdm
36 | cv2
37 | ```
38 | ## Training:
39 |
40 | ```
41 | python train.py --train_path=YOUR_DATA_PATH \
42 |                 --test_path=YOUR_TEST_DATA_PATH \
43 |                 --submit_dir=./submission \
44 |                 --batch_size=20 \
45 |                 --epochs=10 \
46 |                 --Attention
47 | ```
48 |
49 | ## Inference:
50 |
51 | ```
52 | python Inference.py --test_img=YOUR_TEST_IMAGE \
53 |                     --top_k=5 \
54 |                     --sample_num=100
55 | ```
56 |
--------------------------------------------------------------------------------
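To make the verifier interface in the README's abstract concrete: scoring one candidate composition is a single forward pass over a (background, foreground, scene-parsing) triple, as in `f()` in Inference.py. A minimal sketch, assuming a CUDA device, 256x256 inputs, and a trained checkpoint (the path here is hypothetical); the zero arrays stand in for real images:

```python
import numpy as np
import torch
from config import set_args
from model.verifier_base import VerifierBase

config = set_args()
model = VerifierBase(config).cuda().eval()
model.load_state_dict(torch.load("checkpoints/your_ckpt.pth"))  # hypothetical path


def to_tensor(img):
    # HxWxC uint8 image -> 1x3xHxW float tensor on the GPU
    return torch.FloatTensor(img[:, :, :3].transpose(2, 0, 1)).unsqueeze(0).cuda()


bg = np.zeros((256, 256, 3), np.uint8)  # background image
fg = np.zeros((256, 256, 3), np.uint8)  # matted foreground pasted on a canvas
sp = np.zeros((256, 256, 3), np.uint8)  # color-encoded scene parsing of the background

with torch.no_grad():
    y1, y2 = model({'BGD': to_tensor(bg), 'FGD': to_tensor(fg), 'SPS': to_tensor(sp)})
print(float(y1[0, 0]), float(y2[0, 0]))  # content-matching score, location score
```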
/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os.path import expanduser, join
3 | import os
4 |
5 |
6 | def set_args():
7 |     parser = argparse.ArgumentParser()
8 |     source_path = join(expanduser("~"), "xyx", "data")
9 |     data_path = os.path.join(expanduser("~"), "cvdata", "desert")
10 |     # Basic Information
11 |     # parser.add_argument('--name', type=str, default="test_sf")
12 |     # parser.add_argument('--name', type=str, default="1:2:1_Regression_sigmoid_shuffle_score_debug_dropout_lmd1.5")
13 |     parser.add_argument('--name', type=str, default="attention")
14 |     parser.add_argument('--train_path', type=str, default=data_path)
15 |     parser.add_argument('--test_path', type=str, default=" ")
16 |     parser.add_argument('--submit_dir', type=str, default="submission")
17 |
18 |     # Training Settings
19 |     parser.add_argument('--batch_size', type=int, default=20)
20 |     parser.add_argument('--sample_size', type=int, default=4)
21 |     parser.add_argument('--epochs', type=int, default=12)
22 |     parser.add_argument('--n_eval', type=int, default=512)
23 |     parser.add_argument('--lambda', type=float, default=1.5)
24 |     parser.add_argument('--dropout', type=float, default=0.3)
25 |     parser.add_argument('--cuda', action='store_true', default=True)  # note: default=True makes this store_true flag always on
26 |     parser.add_argument('--Regression', action='store_true', default=True)
27 |     parser.add_argument('--Attention', action='store_true', default=True)
28 |     parser.add_argument('--debug', action='store_true', default=False)  # if True, use the small 'desert' subset
29 |
30 |     # Optimizer Settings
31 |     parser.add_argument('--lr', type=float, default=1e-5)  # 8e-4
32 |     parser.add_argument('--b1', type=float, default=0.9)
33 |     parser.add_argument('--b2', type=float, default=0.999)
34 |     parser.add_argument('--e', type=float, default=1e-5)
35 |     parser.add_argument('--decay', type=float, default=0)
36 |     parser.add_argument('--grad_clipping', type=float, default=5)
37 |
38 |     # Inference Settings
39 |     parser.add_argument('--sample_num', type=int, default=100)
40 |     parser.add_argument('--top_k', type=int, default=5)
41 |     parser.add_argument('--update_rd', type=int, default=10)
42 |     parser.add_argument('--test_img', type=str, default='resource/test_1.png')
43 |     config = parser.parse_args().__dict__
44 |     return config
45 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/model/__init__.py
--------------------------------------------------------------------------------
/model/verifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models.resnet import Bottleneck, ResNet, model_urls
4 | import torch.utils.model_zoo as model_zoo
5 | import torch.nn.functional as F
6 | from torch.nn.parameter import Parameter
7 | import math
8 |
9 |
10 | class ResNetWrapper(ResNet):
11 |     def __init__(self):
12 |         super(ResNetWrapper, self).__init__(Bottleneck, [3, 4, 6, 3])
13 |         self.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
14 |         self.output_size = 2048
15 |
16 |     def forward(self, x):
17 |         x = self.conv1(x)
18 |         x = self.bn1(x)
19 |         x = self.relu(x)
20 |         x = self.maxpool(x)
21 |
22 |         x = self.layer1(x)
23 |         x = self.layer2(x)
24 |         x = self.layer3(x)
25 |         x = self.layer4(x)
26 |
27 |         x = self.avgpool(x)
28 |         x = x.view(x.size(0), -1)
29 |         return x
30 |
31 |
32 | def gelu(x):
33 |     # GELU activation (tanh approximation)
34 |     return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
35 |
36 |
37 | class Attention(nn.Module):
38 |     def __init__(self, m, n):
39 |         super(Attention, self).__init__()
40 |         self.m = m
41 |         self.n = n
42 |         # both inputs are projected to 30 dims; the flattened 30x30 outer product gives a 900-dim feature
43 |         self.proj_1 = Parameter(torch.Tensor(30, m))
44 |         self.proj_2 = Parameter(torch.Tensor(30, n))
45 |         self.reset_parameters()
46 |
47 |     def reset_parameters(self):
48 |         stdv1 = 1. / math.sqrt(self.proj_1.size(1))
49 |         self.proj_1.data.uniform_(-stdv1, stdv1)
50 |         stdv2 = 1. / math.sqrt(self.proj_2.size(1))
51 |         self.proj_2.data.uniform_(-stdv2, stdv2)
52 |
53 |     def forward(self, input1, input2, input3):
54 |         N = input1.shape[0]
55 |         h = torch.cat([input1, input2], dim=-1)
56 |         proj1 = F.linear(h, self.proj_1).unsqueeze(-1)
57 |         proj2 = F.linear(input3, self.proj_2).unsqueeze(-2)
58 |         return gelu(torch.matmul(proj1, proj2).view(N, -1))
59 |
60 |
61 | class Verifier(nn.Module):
62 |     def __init__(self, config):
63 |         super(Verifier, self).__init__()
64 |         self.background_reader = ResNetWrapper()
65 |         self.portrait_reader = ResNetWrapper()
66 |         self.scene_reader = ResNetWrapper()
67 |         self.config = config
68 |         logit_size = 8192  # 2048-channel ResNet-50 features over a 2x2 grid (256x256 inputs)
69 |         print("logits:", logit_size)
70 |         self.maxpool = torch.nn.MaxPool1d(3)
71 |         self.context_attn = Attention(2 * logit_size, logit_size)
72 |         self.spatial_attn = Attention(2 * logit_size, logit_size)
73 |         self.linear1 = nn.Linear(3 * logit_size + 900, 2)
74 |         self.linear2 = nn.Linear(3 * logit_size + 900, 2)
75 |
76 |     def forward(self, batch):
77 |         xb = self.background_reader(batch['BGD'])
78 |         xf = self.portrait_reader(batch['FGD'])
79 |         xs = self.scene_reader(batch['SPS'])
80 |
81 |         xb = F.dropout(xb, p=self.config['dropout'], training=self.training)
82 |         xf = F.dropout(xf, p=self.config['dropout'], training=self.training)
83 |         xs = F.dropout(xs, p=self.config['dropout'], training=self.training)
84 |
85 |         xbn = torch.unsqueeze(xb, dim=-2).transpose(-1, -2)
86 |         xfn = torch.unsqueeze(xf, dim=-2).transpose(-1, -2)
87 |         xsn = torch.unsqueeze(xs, dim=-2).transpose(-1, -2)
88 |         # print("xb", xb.shape)
89 |         # print("xf", xf.shape)
90 |         # print("xs", xs.shape)
91 |         # print("xbn", xbn.shape)
92 |         # print("xfn", xfn.shape)
93 |         # print("xsn", xsn.shape)
94 |
95 |         # element-wise max over the three feature vectors
96 |         xn = self.maxpool(torch.cat([xbn, xfn, xsn], dim=-1)).squeeze(-1)
97 |         x_cattn = self.context_attn(xb, xf, xn)
98 |         x_sattn = self.spatial_attn(xs, xf, xn)
99 |         # print(x_sattn.shape)
100 |
101 |         input_1 = torch.cat([xb, xf, xn, x_cattn], dim=-1)
102 |         input_2 = torch.cat([xs, xf, xn, x_sattn], dim=-1)
103 |
104 |         if self.config['Regression']:
105 |             logit_1 = torch.sigmoid(self.linear1(input_1))
106 |             logit_2 = torch.sigmoid(self.linear2(input_2))
107 |         else:
108 |             logit_1 = torch.softmax(self.linear1(input_1), dim=-1)
109 |             logit_2 = torch.softmax(self.linear2(input_2), dim=-1)
110 |
111 |         return logit_1, logit_2
112 |
--------------------------------------------------------------------------------
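A quick, toy-sized shape check of the Attention module above: both projections map to 30 dimensions, and the flattened 30x30 outer product yields a 900-dim feature, which is why the linear heads take `3 * logit_size + 900` inputs. The sizes here (m=8, n=6) are arbitrary:

```python
import torch
from model.verifier import Attention

attn = Attention(m=8, n=6)   # m = dim of cat(input1, input2); n = dim of input3
a = torch.randn(4, 5)        # batch of 4
b = torch.randn(4, 3)        # 5 + 3 = m
c = torch.randn(4, 6)        # n
out = attn(a, b, c)
print(out.shape)             # torch.Size([4, 900]) -- flattened 30 x 30 outer product
```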
/model/verifier_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision.models.resnet import Bottleneck, ResNet, model_urls
4 | import torch.utils.model_zoo as model_zoo
5 | import torch.nn.functional as F
6 |
7 |
8 | class ResNetWrapper(ResNet):
9 |     def __init__(self):
10 |         super(ResNetWrapper, self).__init__(Bottleneck, [3, 4, 6, 3])
11 |         self.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
12 |         self.output_size = 2048
13 |
14 |     def forward(self, x):
15 |         x = self.conv1(x)
16 |         x = self.bn1(x)
17 |         x = self.relu(x)
18 |         x = self.maxpool(x)
19 |
20 |         x = self.layer1(x)
21 |         x = self.layer2(x)
22 |         x = self.layer3(x)
23 |         x = self.layer4(x)
24 |
25 |         x = self.avgpool(x)
26 |         x = x.view(x.size(0), -1)
27 |         return x
28 |
29 |
30 | class VerifierBase(nn.Module):
31 |     def __init__(self, config):
32 |         super(VerifierBase, self).__init__()
33 |         self.background_reader = ResNetWrapper()
34 |         self.portrait_reader = ResNetWrapper()
35 |         self.scene_reader = ResNetWrapper()
36 |         self.config = config
37 |         logit_size = 8192  # 2048-channel ResNet-50 features over a 2x2 grid (256x256 inputs)
38 |         print("logits:", logit_size)
39 |         self.maxpool = torch.nn.MaxPool1d(3)
40 |         self.linear1 = nn.Linear(3 * logit_size, 2)
41 |         self.linear2 = nn.Linear(3 * logit_size, 2)
42 |
43 |     def forward(self, batch):
44 |         xb = self.background_reader(batch['BGD'])
45 |         xf = self.portrait_reader(batch['FGD'])
46 |         xs = self.scene_reader(batch['SPS'])
47 |
48 |         xb = F.dropout(xb, p=self.config['dropout'], training=self.training)
49 |         xf = F.dropout(xf, p=self.config['dropout'], training=self.training)
50 |         xs = F.dropout(xs, p=self.config['dropout'], training=self.training)
51 |
52 |         xbn = torch.unsqueeze(xb, dim=-2).transpose(-1, -2)
53 |         xfn = torch.unsqueeze(xf, dim=-2).transpose(-1, -2)
54 |         xsn = torch.unsqueeze(xs, dim=-2).transpose(-1, -2)
55 |         # print("xb", xb.shape)
56 |         # print("xf", xf.shape)
57 |         # print("xs", xs.shape)
58 |         # print("xbn", xbn.shape)
59 |         # print("xfn", xfn.shape)
60 |         # print("xsn", xsn.shape)
61 |
62 |         # element-wise max over the three feature vectors
63 |         xn = self.maxpool(torch.cat([xbn, xfn, xsn], dim=-1)).squeeze(-1)
64 |
65 |         input_1 = torch.cat([xb, xf, xn], dim=-1)
66 |         input_2 = torch.cat([xs, xf, xn], dim=-1)
67 |
68 |         if self.config['Regression']:
69 |             logit_1 = torch.sigmoid(self.linear1(input_1))
70 |             logit_2 = torch.sigmoid(self.linear2(input_2))
71 |         else:
72 |             logit_1 = torch.softmax(self.linear1(input_1), dim=-1)
73 |             logit_2 = torch.softmax(self.linear2(input_2), dim=-1)
74 |
75 |         return logit_1, logit_2
76 |
--------------------------------------------------------------------------------
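The `MaxPool1d(3)` step in both verifiers is an element-wise max over the three feature vectors: each vector becomes a column, the columns are laid side by side, and a kernel-3 pool collapses them. A small check of that equivalence:

```python
import torch

pool = torch.nn.MaxPool1d(3)
xb, xf, xs = torch.randn(2, 8), torch.randn(2, 8), torch.randn(2, 8)
# stack as (batch, features, 3) and max-pool across the last axis
cols = [x.unsqueeze(-2).transpose(-1, -2) for x in (xb, xf, xs)]
xn = pool(torch.cat(cols, dim=-1)).squeeze(-1)
assert torch.equal(xn, torch.max(torch.max(xb, xf), xs))  # element-wise max fusion
```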
/resource/CV_paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/CV_paper.pdf
--------------------------------------------------------------------------------
/resource/Places365_val_00034821 change(1).jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/Places365_val_00034821 change(1).jpg
--------------------------------------------------------------------------------
/resource/color150.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/color150.mat
--------------------------------------------------------------------------------
/resource/demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/demo1.jpg
--------------------------------------------------------------------------------
/resource/demo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/demo2.png
--------------------------------------------------------------------------------
/resource/moving_sequence1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/moving_sequence1.jpg
--------------------------------------------------------------------------------
/resource/moving_sequence2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/moving_sequence2.jpg
--------------------------------------------------------------------------------
/resource/pipeline.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/pipeline.jpg
--------------------------------------------------------------------------------
/resource/result1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/result1.png
--------------------------------------------------------------------------------
/resource/test_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/test_1.png
--------------------------------------------------------------------------------
/resource/test_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/resource/test_2.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from utils.batcher import BatchGenerator
4 | from utils.util import *
5 | from tqdm import tqdm
6 | from config import set_args
7 | from utils.logger import create_logger
8 | from tensorboardX import SummaryWriter
9 | from model.verifier import Verifier
10 | from model.verifier_base import VerifierBase
11 |
12 |
13 | def predict(batcher, model):
14 |     test_data = batcher.test_batches
15 |     model.eval()
16 |     acc = 0.0
17 |     rmse = []
18 |     for i, batch in tqdm(enumerate(test_data)):
19 |         batch = batcher.batch2cuda(batch)
20 |         y1 = batch['y1'].detach().cpu().numpy()
21 |         y2 = batch['y2'].detach().cpu().numpy()
22 |         y1_pred, y2_pred = model(batch)
23 |         y1_pred, y2_pred = y1_pred.detach().cpu().numpy()[:, 0], y2_pred.detach().cpu().numpy()[:, 0]
24 |         y1_pred[y1_pred >= 0.5] = 1
25 |         y1_pred[y1_pred < 0.5] = 0
26 |         acc += np.sum(y1 == y1_pred) / y1_pred.shape[0]
27 |         rmse.append((y2_pred - y2) ** 2)
28 |     acc /= len(test_data)
29 |     rmse = np.sqrt(np.mean(np.concatenate(rmse)))
30 |     model.train()
31 |     return float(acc), float(rmse)
32 |
33 |
34 | def print_acc(batch, y1_pred, y2_pred, step):
35 |     global total_acc1, total_acc2
36 |
37 |     acc1, acc2 = accuracy(batch, y1_pred, y2_pred)
38 |     total_acc1 += acc1
39 |     total_acc2 += acc2
40 |     print("batch_acc:", acc1, acc2)
41 |     print("total_acc:", total_acc1 / step, total_acc2 / step)
42 |
43 |
44 | if __name__ == "__main__":
45 |     logger = create_logger(__name__)
46 |     config = set_args()
47 |     global_step = 0
48 |     writer = SummaryWriter(log_dir="figures")
49 |
50 |     bg_data, fg_data, sp_data, sf_data, score = load_data(config['debug'])
51 |     # test_data = load_data(config['test_path'])
52 |
53 |     Batcher = BatchGenerator(config, bg_data, fg_data, sp_data, sf_data, score)
54 |     print("Batch number: ", Batcher.total)
55 |
56 |     if config['Attention']:
57 |         print("Use Attention")
58 |         model = Verifier(config)
59 |     else:
60 |         model = VerifierBase(config)
61 |
62 |     optimizer = torch.optim.Adam(model.parameters(),
63 |                                  lr=config['lr'],
64 |                                  betas=(config['b1'], config['b2']),
65 |                                  eps=config['e'],
66 |                                  weight_decay=config['decay'])
67 |     if config['Regression']:
68 |         criterion = nn.MSELoss()
69 |     else:
70 |         criterion = nn.CrossEntropyLoss(reduce=False)  # per-sample losses (PyTorch 0.4-era API)
71 |
72 |     if config['cuda']:
73 |         model.cuda()
74 |     model.train()
75 |
76 |     for epc in range(config['epochs']):
77 |         Batcher.reset()
78 |         total_acc1 = 0
79 |         total_acc2 = 0
80 |         for i, batch in tqdm(enumerate(Batcher), total=len(Batcher)):
81 |             global_step += 1
82 |             # Just for test
83 |             y1_pred, y2_pred = model(batch)
84 |             if config['Regression']:
85 |                 loss1 = torch.sum(criterion(y1_pred[:, 0], batch['y1'].float()))
86 |                 loss2 = torch.sum(criterion(y2_pred[:, 0], batch['y2'].float()))
87 |             else:
88 |                 loss1 = torch.sum(criterion(y1_pred, batch['y1']))
89 |                 loss2 = torch.sum(criterion(y2_pred, batch['y2']))
90 |             loss = loss1 + config['lambda'] * loss2
91 |             loss.backward()
92 |
93 |             y1_pred, y2_pred = y1_pred.detach().cpu().numpy(), y2_pred.detach().cpu().numpy()
94 |
95 |             print(loss1.detach().cpu().numpy(), loss2.detach().cpu().numpy(), loss.detach().cpu().numpy())
96 |             if i % 100 == 0:
97 |                 print(y1_pred, y2_pred)
98 |                 add_figure(config['name'], writer, global_step, loss1, loss2, loss)
99 |                 print(loss.detach().cpu().numpy())
100 |             if global_step % 1000 == 0:
101 |                 acc, rmse = predict(Batcher, model)
102 |                 add_result(config['name'], writer, global_step, acc, rmse)
103 |                 print("acc: ", acc, "rmse: ", rmse)
104 |             optimizer.step()
105 |             optimizer.zero_grad()
106 |         torch.save(model.state_dict(), "checkpoints/ckpt_{}_epoch_{}.pth".format(epc, config['name']))
107 |     writer.close()
108 |
--------------------------------------------------------------------------------
/utils/Batcher.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from os.path import join
3 | import numpy as np
4 | import torch
5 | from torch.autograd import Variable
6 | import random
7 |
8 |
9 | class BatchGenerator(object):
10 |     def __init__(self, config, bg_data, fg_data, sp_data, sf_data, is_training=True, ratio=[14, 6, 2]):
11 |         self.config = config
12 |         self.sample_size = ratio[0]  # config['sample_size']
13 |         self.batch_size = ratio[0]  # config['sample_size'] * 6
14 |         self.bg_data = bg_data
15 |         self.fg_data = fg_data
16 |         self.sp_data = sp_data
17 |         self.sf_data = sf_data
18 |         self.data = list(bg_data.keys())
19 |         self.is_training = is_training
20 |         self.ratio = ratio
21 |         print("load data ok")
22 |         self.total = len(self.data)
23 |
24 |         self.offset = 0
25 |         # self.set_number()
26 |         print(self.total)
27 |         indices = list(range(self.total))
28 |         np.random.shuffle(indices)
29 |         self.neg_data = [self.data[idx] for idx in indices]
30 |
31 |         self.clip_tail()
32 |         self.pos_batches = [self.data[i:i + self.sample_size] for i in range(0, self.total, self.sample_size)]
33 |         self.neg_batches = [self.neg_data[i:i + self.sample_size] for i in range(0, self.total, self.sample_size)]
34 |         print("cut batch ok")
35 |
36 |     def set_number(self):
37 |         self.color_neg_num = int(self.batch_size * (self.ratio[1] / sum(self.ratio)))
38 |         self.position_neg_num = int(self.batch_size * (self.ratio[2] / sum(self.ratio)))
39 |         self.pos_num = self.batch_size - self.color_neg_num - self.position_neg_num
40 |
41 |     def clip_tail(self):
42 |         self.data = self.data[:self.total - (self.total % self.batch_size)]
43 |         self.total = len(self.data)
44 |
45 |     def reset(self):
46 |         self.offset = 0
47 |         if self.is_training:
48 |             indices = list(range(self.total))
49 |             np.random.shuffle(indices)
50 |             self.data = [self.data[idx] for idx in indices]
51 |         return
52 |
53 |     def patch(self, v):
54 |         if self.config['cuda']:
55 |             v = Variable(v.cuda())
56 |         else:
57 |             v = Variable(v)
58 |         return v
59 |
60 |     def __len__(self):
61 |         return len(self.pos_batches)
62 |
63 |     def __iter__(self):
64 |         while self.offset < len(self):
65 |             pos_batch = self.pos_batches[self.offset]
66 |             neg_batch = self.neg_batches[self.offset]
67 |             self.offset += 1
68 |
69 |             # True Backgrounds <--> True Foregrounds [N]
70 |             pos_backgrounds = [self.bg_data[idx] for idx in pos_batch]
71 |             pos_foregrounds = [self.fg_data[idx] for idx in pos_batch]
72 |             pos_sceneparsing = [self.sp_data[idx] for idx in pos_batch]
73 |             pos_y1 = [True for _ in range(self.sample_size)]
74 |             pos_y2 = [True for _ in range(self.sample_size)]
75 |
76 |             # True Backgrounds <--> False Foregrounds [N]
77 |             neg_col_backgrounds = []
78 |             neg_col_foregrounds = []
79 |             neg_col_sceneparsing = []
80 |             for i in range(self.ratio[1]):
81 |                 neg_col_backgrounds.append(self.bg_data[pos_batch[i]])
82 |                 neg_col_foregrounds.append(self.fg_data[neg_batch[i]])
83 |                 neg_col_sceneparsing.append(self.sp_data[pos_batch[i]])
84 |             neg_col_y1 = [False for _ in range(self.ratio[1])]
85 |             neg_col_y2 = [False for _ in range(self.ratio[1])]
86 |
87 |             # True Backgrounds <--> True Foregrounds & False Position [4N]
88 |             neg_pos_backgrounds = []
89 |             neg_pos_foregrounds = []
90 |             neg_pos_sceneparsing = []
91 |             neg_pos_y1 = [True for _ in range(self.ratio[2] * 4)]
92 |             neg_pos_y2 = [False for _ in range(self.ratio[2] * 4)]
93 |             for i in range(self.ratio[2]):
94 |                 idx = pos_batch[i]
95 |                 for k in range(4):
96 |                     neg_pos_backgrounds.append(self.bg_data[idx])
97 |                     neg_pos_foregrounds.append(self.sf_data[idx][k])
98 |                     neg_pos_sceneparsing.append(self.sp_data[idx])
99 |
100 |             BGD = pos_backgrounds + neg_col_backgrounds + neg_pos_backgrounds
101 |             FGD = pos_foregrounds + neg_col_foregrounds + neg_pos_foregrounds
102 |             SPS = pos_sceneparsing + neg_col_sceneparsing + neg_pos_sceneparsing
103 |             y1 = pos_y1 + neg_col_y1 + neg_pos_y1
104 |             y2 = pos_y2 + neg_col_y2 + neg_pos_y2
105 |
106 |             # Shuffle Data
107 |             N = len(y1)
108 |             indices = list(range(N))
109 |             random.shuffle(indices)
110 |             BGD = [BGD[idx].tolist() for idx in indices]
111 |             FGD = [FGD[idx].tolist() for idx in indices]
112 |             SPS = [SPS[idx].tolist() for idx in indices]
113 |             y1 = [y1[idx] for idx in indices]
114 |             y2 = [y2[idx] for idx in indices]
115 |
116 |             batch_dict = dict()
117 |             batch_dict['BGD'] = self.patch(torch.FloatTensor(BGD))
118 |             batch_dict['FGD'] = self.patch(torch.FloatTensor(FGD))
119 |             batch_dict['SPS'] = self.patch(torch.FloatTensor(SPS))
120 |             batch_dict['y1'] = self.patch(torch.LongTensor(y1))
121 |             batch_dict['y2'] = self.patch(torch.LongTensor(y2))
122 |             print(batch_dict['BGD'].shape)
123 |             print(batch_dict['FGD'].shape)
124 |             print(batch_dict['SPS'].shape)
125 |             print(batch_dict['y1'].shape)
126 |             print(batch_dict['y2'].shape)
127 |
128 |             yield batch_dict
129 |
130 |         return
131 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/Batcher.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/utils/__pycache__/Batcher.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/batcher.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/utils/__pycache__/batcher.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/utils/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/woshiyyya/Auto-Retoucher-pytorch/f1ac09c981f2194ded330217da5e6e9c9e0056f3/utils/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/batcher.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from os.path import join
3 | import numpy as np
4 | import torch
5 | from torch.autograd import Variable
6 | from utils.util import shuffle_fn
7 | import random
8 |
9 |
10 | class BatchGenerator(object):
11 |     def __init__(self, config, bg_data, fg_data, sp_data, sf_data, score, is_training=True, ratio=[1, 2, 1]):
12 |         self.config = config
13 |         self.batch_size = config['batch_size']
14 |         self.bg_data = bg_data
15 |         self.fg_data = fg_data
16 |         self.sp_data = sp_data
17 |         self.sf_data = sf_data
18 |         self.score = score
19 |         self.keys = list(bg_data.keys())
20 |         self.is_training = is_training
21 |         self.ratio = ratio
22 |         print("load data ok")
23 |
24 |         self.offset = 0
25 |         self.data = self.construct_dataset()
26 |         self.total = len(self.data)
27 |         self.clip_tail()
28 |         self.batches = [self.data[i:i + self.batch_size] for i in range(0, self.total, self.batch_size)]
29 |         if is_training:
30 |             self.test_batches = self.batches[-int(0.2 * len(self.batches)):]
31 |             self.batches = self.batches[:-int(0.2 * len(self.batches))]
32 |         print("cut batch ok")
33 |
34 |     @staticmethod
35 |     def shuffle_key(data):
36 |         indices = list(range(len(data)))
37 |         np.random.shuffle(indices)
38 |         return [data[idx] for idx in indices]
39 |
40 |     def construct_dataset(self):
41 |         pos_bg = self.keys
42 |         pos_fg = self.keys
43 |         pos_fgid = [-1 for _ in range(len(pos_fg))]
44 |         pos_y1 = [True for _ in range(len(pos_fg))]
45 |         pos_y2 = [True for _ in range(len(pos_fg))]
46 |
47 |         neg_sem_bg = self.keys
48 |         neg_sem_fg = self.shuffle_key(self.keys)
49 |         neg_sem_fgid = [-1 for _ in range(len(neg_sem_fg))]
50 |         neg_sem_y1 = [False for _ in range(len(neg_sem_fg))]
51 |         neg_sem_y2 = [False for _ in range(len(neg_sem_fg))]
52 |
53 |         neg_spa_bg = []
54 |         neg_spa_fgid = []
55 |         neg_spa_y2 = []
56 |         for key in self.keys:
57 |             num_neg = self.ratio[1]
58 |             neg_spa_bg.extend([key] * num_neg)
59 |             idxs = [np.random.randint(0, 4) for _ in range(num_neg)]
60 |             neg_spa_fgid.extend(idxs)
61 |             neg_spa_y2.extend([self.score[shuffle_fn(key, idx)] for idx in idxs])
62 |         neg_spa_fg = neg_spa_bg.copy()
63 |         neg_spa_y1 = [True for _ in range(len(neg_spa_fg))]
64 |         # neg_spa_y2 = [False for _ in range(len(neg_spa_fg))]
65 |         # print(neg_spa_y2)
66 |
67 |         # input()
68 |         bg_keys = pos_bg * self.ratio[0] + neg_sem_bg * self.ratio[2] + neg_spa_bg
69 |         fg_keys = pos_fg * self.ratio[0] + neg_sem_fg * self.ratio[2] + neg_spa_fg
70 |         fg_ids = pos_fgid * self.ratio[0] + neg_sem_fgid * self.ratio[2] + neg_spa_fgid
71 |         Y1 = pos_y1 * self.ratio[0] + neg_sem_y1 * self.ratio[2] + neg_spa_y1
72 |         Y2 = pos_y2 * self.ratio[0] + neg_sem_y2 * self.ratio[2] + neg_spa_y2
73 |
74 |         return self.shuffle_key([(bg, fg, fgid, y1, y2) for bg, fg, fgid, y1, y2 in zip(bg_keys, fg_keys, fg_ids, Y1, Y2)])
75 |
76 |     def clip_tail(self):
77 |         self.data = self.data[:self.total - (self.total % self.batch_size)]
78 |         self.total = len(self.data)
79 |
80 |     def reset(self):
81 |         self.offset = 0
82 |         if self.is_training:
83 |             indices = list(range(self.total))
84 |             np.random.shuffle(indices)
85 |             self.data = [self.data[idx] for idx in indices]
86 |         return
87 |
88 |     def patch(self, v):
89 |         if self.config['cuda']:
90 |             v = Variable(v.cuda())
91 |         else:
92 |             v = Variable(v)
93 |         return v
94 |
95 |     def __len__(self):
96 |         return len(self.batches)
97 |
98 |     def __iter__(self):
99 |         while self.offset < len(self):
100 |             batch = self.batches[self.offset]
101 |             self.offset += 1
102 |
103 |             # fgid == -1 means the original foreground; otherwise one of the shuffled-position variants
104 |             foregrounds = []
105 |             for case in batch:
106 |                 if case[2] == -1:
107 |                     foregrounds.append(self.fg_data[case[1]])
108 |                 else:
109 |                     foregrounds.append(self.sf_data[case[1]][case[2]])
110 |             backgrounds = [self.bg_data[case[0]] for case in batch]
111 |             sceneparsing = [self.sp_data[case[0]] for case in batch]
112 |             y1 = [case[3] for case in batch]
113 |             y2 = [case[4] for case in batch]
114 |
115 |             batch_dict = dict()
116 |             batch_dict['BGD'] = self.patch(torch.FloatTensor(backgrounds))
117 |             batch_dict['FGD'] = self.patch(torch.FloatTensor(foregrounds))
118 |             batch_dict['SPS'] = self.patch(torch.FloatTensor(sceneparsing))
119 |             batch_dict['y1'] = self.patch(torch.LongTensor(y1))
120 |             batch_dict['y2'] = self.patch(torch.FloatTensor(y2))
121 |             # print(batch_dict['BGD'].shape)
122 |             # print(batch_dict['FGD'].shape)
123 |             # print(batch_dict['SPS'].shape)
124 |             # print(batch_dict['y1'].shape)
125 |             # print(batch_dict['y2'].shape)
126 |
127 |             yield batch_dict
128 |         return
129 |
130 |     def batch2cuda(self, batch):
131 |         foregrounds = []
132 |         for case in batch:
133 |             if case[2] == -1:
134 |                 foregrounds.append(self.fg_data[case[1]])
135 |             else:
136 |                 foregrounds.append(self.sf_data[case[1]][case[2]])
137 |         backgrounds = [self.bg_data[case[0]] for case in batch]
138 |         sceneparsing = [self.sp_data[case[0]] for case in batch]
139 |         y1 = [case[3] for case in batch]
140 |         y2 = [case[4] for case in batch]
141 |
142 |         batch_dict = dict()
143 |         batch_dict['BGD'] = self.patch(torch.FloatTensor(backgrounds))
144 |         batch_dict['FGD'] = self.patch(torch.FloatTensor(foregrounds))
145 |         batch_dict['SPS'] = self.patch(torch.FloatTensor(sceneparsing))
146 |         batch_dict['y1'] = self.patch(torch.LongTensor(y1))
147 |         batch_dict['y2'] = self.patch(torch.FloatTensor(y2))
148 |         return batch_dict
--------------------------------------------------------------------------------
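A minimal smoke test of `BatchGenerator` with dummy data may help when wiring it up; the shapes follow `load_data` (CxHxW arrays) and the score-dict keys follow `shuffle_fn`. Everything here is synthetic, and `cuda` is disabled so it runs anywhere:

```python
import numpy as np
from utils.batcher import BatchGenerator
from utils.util import shuffle_fn

config = {'batch_size': 4, 'cuda': False}
keys = ['img%d' % i for i in range(8)]                 # hypothetical image ids
blank = lambda: np.zeros((3, 256, 256), np.uint8)
bg = {k: blank() for k in keys}
fg = {k: blank() for k in keys}
sp = {k: blank() for k in keys}
sf = {k: [blank() for _ in range(4)] for k in keys}    # 4 shuffled-position variants each
score = {shuffle_fn(k, i): 0.5 for k in keys for i in range(4)}

batcher = BatchGenerator(config, bg, fg, sp, sf, score)
for batch in batcher:
    print(batch['BGD'].shape, batch['y1'].shape)       # (4, 3, 256, 256), (4,)
    break
```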
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 | from time import gmtime, strftime
4 | from colorlog import ColoredFormatter
5 |
6 |
7 | def create_logger(name, silent=False, to_disk=False, log_file=None, prefix=None):
8 |     """Logger wrapper
9 |     by xiaodl
10 |     """
11 |     # setup logger
12 |     log = logging.getLogger(name)
13 |     log.setLevel(logging.DEBUG)
14 |     formatter = ColoredFormatter(
15 |         "%(asctime)s %(log_color)s%(levelname)-8s%(reset)s [%(blue)s%(message)s%(reset)s]",
16 |         datefmt='%Y-%m-%d %I:%M:%S',
17 |         reset=True,
18 |         log_colors={
19 |             'DEBUG': 'cyan',
20 |             'INFO': 'green',
21 |             'WARNING': 'yellow',
22 |             'ERROR': 'red',
23 |             'CRITICAL': 'red,bg_white',
24 |         },
25 |         secondary_log_colors={},
26 |         style='%'
27 |     )
28 |     fformatter = logging.Formatter(
29 |         "%(asctime)s [%(funcName)-12s] %(levelname)-8s [%(message)s]",
30 |         datefmt='%Y-%m-%d %I:%M:%S',
31 |         style='%'
32 |     )
33 |     if not silent:
34 |         ch = logging.StreamHandler(sys.stdout)
35 |         ch.setLevel(logging.INFO)
36 |         ch.setFormatter(formatter)
37 |         log.addHandler(ch)
38 |     if to_disk:
39 |         prefix = prefix if prefix is not None else 'my_log'
40 |         log_file = log_file if log_file is not None else strftime('{}-%Y-%m-%d-%H-%M-%S.log'.format(prefix), gmtime())
41 |         fh = logging.FileHandler(log_file)
42 |         fh.setLevel(logging.DEBUG)
43 |         fh.setFormatter(fformatter)
44 |         log.addHandler(fh)
45 |     # disable elmo info
46 |     log.propagate = False
47 |     return log
48 |
--------------------------------------------------------------------------------
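Typical usage of `create_logger` (console output by default; `to_disk=True` adds a timestamped log file; note the colorlog dependency):

```python
from utils.logger import create_logger

log = create_logger(__name__)   # colored INFO+ output on stdout
log.info("training started")
log.warning("checkpoint directory not found")
```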
/utils/util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | from os.path import expanduser
4 | from os.path import join
5 | from tqdm import tqdm
6 | import pickle
7 | import numpy as np
8 | from scipy.io import loadmat
9 |
10 |
11 | def shuffle_fn(idx, num):
12 |     return '{}.shuffle.{}.jpg'.format(idx, num)
13 |
14 |
15 | def accuracy(batch, y1_pred, y2_pred):
16 |     y1 = batch['y1'].detach().cpu().numpy()
17 |     y2 = batch['y2'].detach().cpu().numpy()
18 |     # print(np.vstack([y1, y1_pred[:, 0]]))
19 |     # print(np.vstack([y2, y2_pred[:, 0]]))
20 |     y1_pred = np.argmax(y1_pred, axis=1)
21 |     y2_pred = np.argmax(y2_pred, axis=1)
22 |     acc1 = np.sum(y1 == y1_pred) / y1_pred.shape[0]
23 |     acc2 = np.sum(y2 == y2_pred) / y2_pred.shape[0]
24 |     # print(np.vstack([y1, y1_pred]))
25 |     # print(np.vstack([y2, y2_pred]))
26 |     return acc1, acc2
27 |
28 |
29 | def add_figure(name, writer, global_step, loss1, loss2, loss):
30 |     writer.add_scalar(name + ' data/train_loss', loss, global_step)
31 |     writer.add_scalars(name + ' data/loss_group', {'loss1': loss1, 'loss2': loss2}, global_step)
32 |     return
33 |
34 |
35 | def add_result(name, writer, global_step, acc, rmse):
36 |     writer.add_scalars(name + ' data/result_group', {'acc': acc, 'rmse': rmse}, global_step)
37 |     return
38 |
39 |
40 | def unique(ar, return_index=False, return_inverse=False, return_counts=False):
41 |     # vendored copy of np.unique
42 |     ar = np.asanyarray(ar).flatten()
43 |
44 |     optional_indices = return_index or return_inverse
45 |     optional_returns = optional_indices or return_counts
46 |
47 |     if ar.size == 0:
48 |         if not optional_returns:
49 |             ret = ar
50 |         else:
51 |             ret = (ar,)
52 |             if return_index:
53 |                 ret += (np.empty(0, np.bool),)
54 |             if return_inverse:
55 |                 ret += (np.empty(0, np.bool),)
56 |             if return_counts:
57 |                 ret += (np.empty(0, np.intp),)
58 |         return ret
59 |
60 |     if optional_indices:
61 |         perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
62 |         aux = ar[perm]
63 |     else:
64 |         ar.sort()
65 |         aux = ar
66 |     flag = np.concatenate(([True], aux[1:] != aux[:-1]))
67 |
68 |     if not optional_returns:
69 |         ret = aux[flag]
70 |     else:
71 |         ret = (aux[flag],)
72 |         if return_index:
73 |             ret += (perm[flag],)
74 |         if return_inverse:
75 |             iflag = np.cumsum(flag) - 1
76 |             inv_idx = np.empty(ar.shape, dtype=np.intp)
77 |             inv_idx[perm] = iflag
78 |             ret += (inv_idx,)
79 |         if return_counts:
80 |             idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
81 |             ret += (np.diff(idx),)
82 |     return ret
83 |
84 |
85 | def colorEncode(labelmap, colors, mode='BGR'):
86 |     labelmap = labelmap.astype('int')
87 |     labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
88 |                             dtype=np.uint8)
89 |     for label in unique(labelmap):
90 |         if label < 0:
91 |             continue
92 |         labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
93 |             np.tile(colors[label],
94 |                     (labelmap.shape[0], labelmap.shape[1], 1))
95 |
96 |     if mode == 'BGR':
97 |         return labelmap_rgb[:, :, ::-1]
98 |     else:
99 |         return labelmap_rgb
100 |
101 |
102 | def load_data(small=False):
103 |     source_path = os.path.join(expanduser("~"), "cvdata", "data_jpg")
104 |     shuffle_path = os.path.join(expanduser("~"), "cvdata", "sigmoid_shuffle")
105 |     scene_path = os.path.join(expanduser("~"), "cvdata", "scene_parsing_np")
106 |     folder_names = ['golf', 'kitchen', 'office', 'airport_terminal', 'banquet',
107 |                     'beach', 'boat', 'coffee_shop', 'conference_room', 'desert',
108 |                     'football', 'hospital', 'ice_skating', 'stage', 'staircase',
109 |                     'supermarket']
110 |     # note: this second, shorter list overrides the one above
111 |     folder_names = ['golf', 'kitchen', 'office', 'airport_terminal', 'banquet',
112 |                     'beach', 'boat', 'coffee_shop', 'conference_room', 'desert',
113 |                     'hospital', 'ice_skating', 'staircase',
114 |                     'supermarket']
115 |     if small:
116 |         folder_names = ['desert']
117 |     bg_data = dict()
118 |     fg_data = dict()
119 |     sp_data = dict()
120 |     sf_data = dict()
121 |     colors = loadmat('resource/color150.mat')['colors']
122 |
123 |     for folder in folder_names:
124 |         for img in tqdm(os.listdir(join(source_path, folder))):
125 |             fn = join(source_path, folder, img)
126 |             idx = img.split(".")[0]
127 |             if "bg.jpg" in fn:
128 |                 bg_data[idx] = cv2.imread(fn).transpose(2, 0, 1)
129 |             else:
130 |                 fg_data[idx] = cv2.imread(fn).transpose(2, 0, 1)
131 |
132 |     for folder in folder_names:
133 |         for img in tqdm(os.listdir(join(scene_path, folder))):
134 |             fn = join(scene_path, folder, img)
135 |             idx = img.split(".")[0]
136 |             labels = pickle.load(open(fn, 'rb'))
137 |             sp_data[idx] = colorEncode(labels, colors).transpose(2, 0, 1)
138 |
139 |     for folder in folder_names:
140 |         for img in tqdm(os.listdir(join(shuffle_path, folder))):
141 |             fn = join(shuffle_path, folder, img)
142 |             idx = img.split(".")[0]
143 |             if idx in sf_data.keys():
144 |                 sf_data[idx].append(cv2.imread(fn).transpose(2, 0, 1))
145 |             else:
146 |                 sf_data[idx] = [cv2.imread(fn).transpose(2, 0, 1)]
147 |
148 |     score = pickle.load(open("/newNAS/Share/ykli/sigmoid_shuffle_score_dict.pkl", 'rb'))
149 |
150 |     print(len(bg_data), list(bg_data.values())[0].shape)
151 |     print(len(fg_data))
152 |     print(len(sp_data), list(sp_data.values())[0].shape)
153 |     print(len(sf_data))
154 |     return bg_data, fg_data, sp_data, sf_data, score
155 |
--------------------------------------------------------------------------------
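Finally, a small sketch of `colorEncode`; a random palette stands in for `resource/color150.mat` so it runs without the repo data:

```python
import numpy as np
from utils.util import colorEncode

labels = np.random.randint(0, 150, size=(64, 64))                    # fake labelmap
palette = np.random.randint(0, 256, size=(150, 3), dtype=np.uint8)   # stand-in for color150.mat
encoded = colorEncode(labels, palette)                               # BGR by default
print(encoded.shape, encoded.dtype)                                  # (64, 64, 3) uint8
```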