├── README.md
├── backup
├── dataset
│   ├── dataloader.py
│   ├── reds.py
│   └── vimeo7.py
├── demo
│   ├── sigma10.gif
│   ├── sigma100.gif
│   └── sigma50.gif
├── eval.sh
├── gen_img.py
├── gen_video.py
├── gif_combine.py
├── loss
│   └── loss.py
├── main.py
├── model
│   ├── CRFP.py
│   ├── CRFP_runtime.py
│   ├── CRFP_test.py
│   └── LTE.py
├── option.py
├── overview.png
├── png2mp4.py
├── pretrained_models
│   ├── EGVSR_iter420000.pth
│   ├── fnet.pth
│   └── spynet_20210409-c6c1bd09.pth
├── pytorch_ssim
│   └── __init__.py
├── test.sh
├── test_img_coor.py
├── test_runtime.py
├── test_video.py
├── test_video_quality.sh
├── train.sh
├── trainer.py
├── untar_models.sh
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-Resolution Flow Propagation for Foveated Video Super-Resolution
2 |
3 | Official implementation of Cross-Resolution Flow Propagation for Foveated Video Super-Resolution (CRFP), accepted to WACV 2023.
4 |
5 | ![overview](overview.png)
6 |
7 | ## Demo
8 |
9 | Demonstration of how CRFP deals with various values of $\sigma^T$, which represents the noise induced by eye-tracker movement during pupil detection. Note that regions beyond the foveated region are resolved at $8\times$ super-resolution.
10 |
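A minimal sketch of how this eye-tracker noise can be simulated (assuming `numpy`; the helper name is illustrative, and the sampling mirrors the `Rscan` branch of `fovea_generator` in `dataset/reds.py`):

```python
import numpy as np

def noisy_gaze_centers(n_frames, gt_h, gt_w, sigma=0.05):
    # Sample per-frame foveal centers around the frame midpoint;
    # sigma plays the role of the eye-tracker jitter sigma^T below.
    ys = np.random.normal(0.5, sigma, n_frames).clip(0, 1)
    xs = np.random.normal(0.5, sigma, n_frames).clip(0, 1)
    return [(int(y * gt_h), int(x * gt_w)) for y, x in zip(ys, xs)]
```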
11 | $\sigma^T=10$
12 |
13 | ![sigma10](demo/sigma10.gif)
14 |
15 | $\sigma^T=50$
16 |
17 | ![sigma50](demo/sigma50.gif)
18 |
19 | $\sigma^T=100$
20 |
21 | ![sigma100](demo/sigma100.gif)
22 |
23 |
24 |
25 | ## Training and evaluation
26 | To train the model, first install DCNv2 from https://github.com/jinfagang/DCNv2_latest
27 |
28 | Run the following to start training:
29 | ```
30 | bash train.sh
31 | ```
32 |
33 | To evaluate, run
34 | ```
35 | bash eval.sh
36 | ```
37 |
38 | To test, run
39 | ```
40 | bash test.sh
41 | ```
42 |
43 | ## References
44 | Most of the code is adapted from
45 |
46 | 1. TTSR: https://github.com/researchmm/TTSR
47 | 2. BasicVSR: https://github.com/open-mmlab/mmediting
48 |
49 |
50 |
--------------------------------------------------------------------------------
/backup:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | tar --exclude='train/*' --exclude='env/*' --exclude='eval/*' --exclude='test/*' --exclude='test_png/*' --exclude='test_video/*' --exclude='test_gif/*' --exclude='old_tree_x1/*' --exclude='old_tree_x8/*' -cvf CRFP_x8.tar .
4 |
--------------------------------------------------------------------------------
/dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from importlib import import_module
3 |
4 |
5 | def get_dataloader(args):
6 | ### import module
7 | m = import_module('dataset.' + args.dataset.lower())
8 |
9 | if (args.dataset == 'Vimeo7'):
10 | print('Processing Vimeo7 dataset...')
11 | data_train = getattr(m, 'TrainSet')(args)
12 | dataloader_train = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
13 | data_eval = getattr(m, 'EvalSet')(args)
14 | dataloader_eval = DataLoader(data_eval, batch_size=1, shuffle=True, num_workers=1)
15 | data_test = getattr(m, 'TestSet')(args)
16 | dataloader_test = DataLoader(data_test, batch_size=1, shuffle=False, num_workers=1)
17 | dataloader = {'train': dataloader_train, 'eval': dataloader_eval, 'test': dataloader_test}
18 | elif (args.dataset == 'Reds'):
19 | print('Processing Reds dataset...')
20 | data_train = getattr(m, 'TrainSet')(args)
21 | dataloader_train = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
22 | data_eval = getattr(m, 'EvalSet')(args)
23 | dataloader_eval = DataLoader(data_eval, batch_size=1, shuffle=False, num_workers=1)
24 | data_test = getattr(m, 'TestSet')(args)
25 | dataloader_test = DataLoader(data_test, batch_size=1, shuffle=False, num_workers=1)
26 | dataloader = {'train': dataloader_train, 'eval': dataloader_eval, 'test': dataloader_test}
27 | else:
28 | raise SystemExit('Error: no such type of dataset!')
29 |
30 | return dataloader
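# Minimal usage sketch (illustrative): an argparse-style namespace with the
# fields referenced above is enough; see option.py for the real parser.
# scale/FV_size/N_frames below follow eval.sh; the rest are placeholders.
#
# from types import SimpleNamespace
# args = SimpleNamespace(dataset='Reds', dataset_dir='/DATA/REDS_sharp/',
#                        batch_size=4, num_workers=4, scale=8,
#                        GT_size=256, FV_size=96, N_frames=15)
# loaders = get_dataloader(args)
# batch = next(iter(loaders['train']))  # keys: 'LR', 'LR_sr', 'HR', 'Ref', 'Ref_sp'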
--------------------------------------------------------------------------------
/dataset/reds.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import pickle
5 | import logging
6 | import numpy as np
7 | import PIL
8 | import pdb
9 | import math
10 |
11 | import torch
12 | import torch.utils.data as data
13 | import torchvision.transforms.functional as F
14 |
15 | logger = logging.getLogger('base')
16 |
17 | def fovea_generator(GT_imgs, method='Rscan', step=0.1, FV_HW=(32, 32)):
18 |
19 | len_sp = len(GT_imgs)
20 | if torch.is_tensor(GT_imgs):
21 | _, GT_H, GT_W = GT_imgs[0].size()
22 | else:
23 | GT_H, GT_W, _ = GT_imgs[0].shape
24 |
25 | FV_H, FV_W = FV_HW
26 | if method == 'Cscan' or method == 'Zscan':
27 | SP = 0.1
28 | CP = 0.5
29 | EP = 0.9
30 | else:
31 | SP = 0.1
32 | CP = 0.5
33 | EP = 0.9
34 |
35 | #### shift according to center point
36 | CP_H = (GT_H*CP - FV_H//2)/GT_H
37 | CP_W = (GT_W*CP - FV_W//2)/GT_W
38 | EP_H = (GT_H*EP - FV_H)/GT_H
39 | EP_W = (GT_W*EP - FV_W)/GT_W
40 |
41 | #### finetune step size
42 | if method == 'Cscan' or method == 'Zscan':
43 | if SP + math.ceil(math.sqrt(len_sp)) * step > EP_H or SP + math.ceil(math.sqrt(len_sp)) * step > EP_W:
44 | step = min((EP_H - SP) / math.ceil(math.sqrt(len_sp)), (EP_W - SP) / math.ceil(math.sqrt(len_sp)))
45 | SP = int(SP * 100)
46 | step = int(step * 100)
47 | EP = int(SP + math.ceil(math.sqrt(len_sp) - 1) * step)
48 | elif method == 'Hscan':
49 | if SP + len_sp * step > EP_W:
50 | step = (EP_W - SP) / len_sp
51 | SP = int(SP * 100)
52 | step = int(step * 100)
53 | EP = int((SP + len_sp * step))
54 | elif method == 'Vscan':
55 | if SP + len_sp * step > EP_H:
56 | step = (EP_H - SP) / len_sp
57 | SP = int(SP * 100)
58 | step = int(step * 100)
59 | EP = int((SP + len_sp * step))
60 | else:
61 | if SP + len_sp * step > EP_H or SP + len_sp * step > EP_W:
62 | step = min((EP_H - SP) / len_sp, (EP_W - SP) / len_sp)
63 | SP = int(SP * 100)
64 | step = int(step * 100)
65 | EP = int((SP + len_sp * step))
66 |
67 | #### fovea scan simulation
68 | if method == 'Hscan':
69 | fv_sp = [[int(CP_H * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]]
70 | elif method == 'Vscan':
71 | fv_sp = [[int((v / 100) * GT_H), int(CP_W * GT_W)] for v in [*range(SP, EP, step)]]
72 | elif method == 'DRscan': # Not done
73 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]]
74 | elif method == 'DLscan': # Not done
75 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]]
76 | elif method == 'Cscan':
77 | fv_sp = []
78 | v, h = (SP, SP)
79 | v_step, h_step = (step, step)
80 | for t in range(len_sp):
81 | fv_sp.append([int((v / 100) * GT_H), int((h / 100) * GT_W)])
82 | if h == EP and h_step > 0:
83 | h_step = -h_step
84 | v += v_step
85 | elif h == SP and h_step < 0:
86 | h_step = -h_step
87 | v += v_step
88 | else:
89 | h += h_step
90 | elif method == 'Zscan':
91 | fv_sp = []
92 | v, h = (SP, SP)
93 | v_step, h_step = (step, step)
94 | for t in range(len_sp):
95 | fv_sp.append([int((v / 100) * GT_H), int((h / 100) * GT_W)])
96 | if h == EP and v_step < 0:
97 | v_step = -v_step
98 | v += v_step
99 | h_step = -abs(h_step)
100 | elif v == SP and h_step > 0:
101 | h += h_step
102 | h_step = -h_step
103 | v_step = abs(v_step)
104 | elif v == EP and h_step < 0:
105 | h_step = -h_step
106 | h += h_step
107 | v_step = -abs(v_step)
108 | elif h == SP and v_step > 0:
109 | v += v_step
110 | v_step = -v_step
111 | h_step = abs(h_step)
112 | else:
113 | h += h_step
114 | v += v_step
115 | elif method == 'Rscan':
116 | sigma = 0.05
117 | rand_h = np.random.normal(CP_H, sigma, len_sp).clip(0, EP_H)
118 | rand_w = np.random.normal(CP_W, sigma, len_sp).clip(0, EP_W)
119 | fv_sp = [[int(rh * GT_H), int(rw * GT_W)] for rh, rw in zip(rand_h, rand_w)]
120 | elif method == 'Nanascan': # Not done
121 | # SP_H = 0
122 | # EP_H = (GT_H - FV_H - 1) / GT_H
123 | # Q1_H = (0.25 - (FV_H / GT_H)/2) if (0.25 - (FV_H / GT_H)/2) > 0 else SP_H
124 | # Q2_H = (0.50 - (FV_H / GT_H)/2)
125 | # Q3_H = (0.75 - (FV_H / GT_H)/2) if (0.75 + (FV_H / GT_H)/2) < 1 else EP_H
126 | # T1_H = (0.33 - (FV_H / GT_H)/2) if (0.33 - (FV_H / GT_H)/2) > 0 else SP_H
127 | # T2_H = (0.66 - (FV_H / GT_H)/2) if (0.66 + (FV_H / GT_H)/2) < 1 else EP_H
128 |
129 | ratio_H = FV_H / GT_H
130 | SP_H = 0 + (ratio_H / 2)
131 | EP_H = 1 - (ratio_H / 2)
132 | Q1_H = SP_H + ((EP_H - SP_H) * 0.25)
133 | Q2_H = SP_H + ((EP_H - SP_H) * 0.50)
134 | Q3_H = SP_H + ((EP_H - SP_H) * 0.75)
135 | T1_H = SP_H + ((EP_H - SP_H) * 0.33)
136 | T2_H = SP_H + ((EP_H - SP_H) * 0.66)
137 |
138 | ratio_W = FV_W / GT_W
139 | SP_W = 0 + (ratio_W / 2)
140 | EP_W = 1 - (ratio_W / 2)
141 | Q1_W = SP_W + ((EP_W - SP_W) * 0.25)
142 | Q2_W = SP_W + ((EP_W - SP_W) * 0.50)
143 | Q3_W = SP_W + ((EP_W - SP_W) * 0.75)
144 | T1_W = SP_W + ((EP_W - SP_W) * 0.33)
145 | T2_W = SP_W + ((EP_W - SP_W) * 0.66)
146 |
147 | locs = [[SP_H, SP_W], [SP_H, T1_W], [SP_H, T2_W], [SP_H, EP_W],
148 | [T1_H, SP_W], [T1_H, T1_W], [T1_H, T2_W], [T1_H, EP_W],
149 | [T2_H, SP_W], [T2_H, T1_W], [T2_H, T2_W], [T2_H, EP_W],
150 | [EP_H, SP_W], [EP_H, T1_W], [EP_H, T2_W], [EP_H, EP_W]]
151 |         locs = [(y - (ratio_H / 2), x - (ratio_W / 2)) for y, x in locs]  # width ratio for the x offset
152 | # locs = [[Q1_H, T1_W], [Q1_H, T2_W], [Q2_H, Q1_W], [Q2_H, Q2_W], [Q2_H, Q3_W], [Q3_H, T1_W], [Q3_H, T2_W]]
153 | # locs = [[Q2_H, Q2_W]]
154 |
155 | fv_sp = random.choices(locs, k=len_sp)
156 | fv_sp = [[min(int(v[0] * GT_H), GT_H-FV_H), min(int(v[1] * GT_W), GT_W-FV_W)] for v in fv_sp]
157 | random.shuffle(fv_sp)
158 | elif method == 'Evenscan': # Not done
159 | idx = 20
160 | N_H = GT_H // FV_H
161 | N_W = GT_W // FV_W
162 | SP_H = GT_H / N_H
163 | SP_W = GT_W / N_W
164 | fv_sp = []
165 | for i in range(idx, idx + len_sp):
166 | x_i = i % N_W
167 | y_i = (i // N_W) % N_H
168 | fv_sp.append([int((1+y_i)*SP_H - (SP_H + FV_H)/2), int((1+x_i)*SP_W - (SP_W + FV_W)/2)])
169 | elif method == 'DemoHscan': # Not done
170 | SP_H = 0
171 | EP_H = 1
172 | SP_W = 0
173 | EP_W = 1
174 | fv_sp = []
175 | direction = -1
176 | scan_step = 8
177 | accm_step = GT_W - scan_step
178 | for _ in range(len_sp):
179 | fv_sp.append([0, accm_step])
180 | accm_step += direction * scan_step
181 | if accm_step < 0:
182 | direction *= -1
183 | accm_step += direction * scan_step
184 | elif accm_step >= GT_W:
185 | direction *= -1
186 | accm_step += direction * scan_step
187 | else:
188 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]]
189 |
190 | fv_sp = torch.tensor(fv_sp)
191 | if torch.is_tensor(GT_imgs):
192 | FV_imgs = []
193 | Ref_sps = []
194 | for t in range(len(GT_imgs)):
195 | #### With padding ####
196 | Ref_sp = torch.zeros_like(GT_imgs[t])
197 | if method == 'DemoHscan':
198 | Ref_sp[:, fv_sp[t][0]:, fv_sp[t][1]:] = 1
199 | else:
200 | Ref_sp[:, fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W] = 1
201 | Ref = GT_imgs[t] * Ref_sp
202 | FV_imgs.append(Ref)
203 | Ref_sps.append(Ref_sp)
204 | #### Without padding ####
205 | # FV_imgs.append(GT_imgs[t][:, fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W].clone())
206 | # FV_imgs = torch.stack(FV_imgs, dim=0)
207 | else:
208 | FV_imgs = []
209 | Ref_sps = []
210 | for t in range(len(GT_imgs)):
211 | #### With padding ####
212 | # Ref_sp = np.zeros_like(GT_imgs[t])
213 | H, W, C = GT_imgs[t].shape
214 | Ref_sp = np.zeros((H, W, 1))
215 | if method == 'DemoHscan':
216 | Ref_sp[fv_sp[t][0]:, fv_sp[t][1]:, :] = 1
217 | else:
218 | Ref_sp[fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W, :] = 1
219 | Ref = GT_imgs[t] * Ref_sp
220 | FV_imgs.append(Ref)
221 | Ref_sps.append(Ref_sp)
222 | #### Without padding ####
223 | # FV_imgs.append(GT_imgs[t][fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W, :].copy())
224 | # FV_imgs = np.stack(FV_imgs, dim=0)
225 |
226 | return FV_imgs, Ref_sps, fv_sp
227 |
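# Illustrative example (hypothetical sizes; REDS frames are 720x1280):
# frames = [np.zeros((720, 1280, 3), dtype=np.uint8) for _ in range(15)]
# fv_imgs, masks, corners = fovea_generator(frames, method='Rscan', FV_HW=(96, 96))
# -> fv_imgs: the frames zeroed outside each foveal window,
#    masks:   per-frame binary masks of shape (720, 1280, 1),
#    corners: a (15, 2) tensor of top-left [h, w] window positions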
228 | class TrainSet(data.Dataset):
229 |     '''
230 |     Reading the REDS training dataset
231 |     key example: train/train/train_sharp/000/00000000.png
232 |     GT: Ground-Truth;
233 |     LQ: Low-Quality, e.g., low-resolution frames
234 |     supports reading N consecutive HR frames (N = args.N_frames)
235 |     '''
236 | def __init__(self, args):
237 | super(TrainSet, self).__init__()
238 | self.args = args
239 |
240 | scale = self.args.scale
241 | if scale == 8:
242 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI_x8')
243 | elif scale == 4:
244 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI')
245 | self.GT_dir_list = sorted([os.path.join(args.dataset_dir, 'train/train/train_sharp', name) for name in
246 | os.listdir(os.path.join(args.dataset_dir, 'train/train/train_sharp')) if name not in ['000', '011', '015', '020']]) + \
247 | sorted([os.path.join(args.dataset_dir, 'val/val/val_sharp', name) for name in
248 | os.listdir(os.path.join(args.dataset_dir, 'val/val/val_sharp')) if name not in ['000', '001', '006', '017']])
249 | self.LR_dir_list = sorted([os.path.join(LR_root, 'train/train/train_sharp', name) for name in
250 | os.listdir(os.path.join(LR_root, 'train/train/train_sharp')) if name not in ['000', '011', '015', '020']]) + \
251 | sorted([os.path.join(LR_root, 'val/val/val_sharp', name) for name in
252 | os.listdir(os.path.join(LR_root, 'val/val/val_sharp')) if name not in ['000', '001', '006', '017']])
253 | N_frames = self.args.N_frames
254 | self.GT_imgfiles = []
255 | self.LR_imgfiles = []
256 | for idx in range(len(self.GT_dir_list)):
257 | GT_imgfiles_cur = sorted(os.listdir(self.GT_dir_list[idx]))
258 | for img_idx in range(0, len(GT_imgfiles_cur) - N_frames + 1):
259 | self.GT_imgfiles.append([os.path.join(self.GT_dir_list[idx], img_f) for img_f in GT_imgfiles_cur[img_idx:img_idx + N_frames]])
260 | for idx in range(len(self.LR_dir_list)):
261 | LR_imgfiles_cur = sorted(os.listdir(self.LR_dir_list[idx]))
262 | for img_idx in range(0, len(LR_imgfiles_cur) - N_frames + 1):
263 | self.LR_imgfiles.append([os.path.join(self.LR_dir_list[idx], img_f) for img_f in LR_imgfiles_cur[img_idx:img_idx + N_frames]])
264 |
265 | def __getitem__(self, index):
266 | #### Configs
267 | scale = self.args.scale
268 | GT_size = self.args.GT_size
269 | LR_size = GT_size // scale
270 | FV_size = self.args.FV_size
271 |
272 | ### GT
273 | GT_imgfiles = self.GT_imgfiles[index]
274 | GT_imgs = [np.array(PIL.Image.open(img)) for img in GT_imgfiles]
275 |
276 | #### Bicubic downsampling
277 | ### LR and LR_sr
278 | H_, W_, _ = GT_imgs[0].shape
279 | LR_imgfiles = self.LR_imgfiles[index]
280 | LR_imgs = [np.array(PIL.Image.open(img)) for img in LR_imgfiles]
281 | LR_sr_imgs = [np.array(PIL.Image.fromarray(img).resize((W_, H_), PIL.Image.BICUBIC)) for img in LR_imgs]
282 |
283 | ### Random cropping
284 | H, W, C = LR_imgs[0].shape
285 | rnd_h = random.randint(0, max(0, H - LR_size))
286 | rnd_w = random.randint(0, max(0, W - LR_size))
287 | LR_imgs = [v[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :] for v in LR_imgs]
288 |
289 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
290 | GT_imgs = [v[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] for v in GT_imgs]
291 | LR_sr_imgs = [v[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] for v in LR_sr_imgs]
292 |
293 | ### Ref(FV) and Ref_sr
294 | Ref, Ref_sp, _ = fovea_generator(GT_imgs, method='Nanascan', FV_HW=(FV_size, FV_size))
295 |
296 | #### Stacking
297 | GT_imgs = np.stack(GT_imgs, axis=0)
298 | LR_imgs = np.stack(LR_imgs, axis=0)
299 | LR_sr_imgs = np.stack(LR_sr_imgs, axis=0)
300 | Ref = np.stack(Ref, axis=0)
301 | Ref_sp = np.stack(Ref_sp, axis=0)
302 |
303 | #### Scaling
304 | GT_imgs = GT_imgs.astype(np.float32) / 255.
305 | LR_imgs = LR_imgs.astype(np.float32) / 255.
306 | LR_sr_imgs = LR_sr_imgs.astype(np.float32) / 255.
307 | Ref = Ref.astype(np.float32) / 255.
308 | Ref_sp = Ref_sp.astype(np.bool_)
309 |
310 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float()
311 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float()
312 | LR_sr_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_sr_imgs, (0, 3, 1, 2)))).float()
313 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float()
314 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2))))
315 |
316 | if torch.rand(1) < 0.5:
317 | GT_imgs = F.hflip(GT_imgs)
318 | LR_imgs = F.hflip(LR_imgs)
319 | LR_sr_imgs = F.hflip(LR_sr_imgs)
320 | Ref = F.hflip(Ref)
321 | Ref_sp = F.hflip(Ref_sp)
322 |
323 | if torch.rand(1) < 0.5:
324 | GT_imgs = F.vflip(GT_imgs)
325 | LR_imgs = F.vflip(LR_imgs)
326 | LR_sr_imgs = F.vflip(LR_sr_imgs)
327 | Ref = F.vflip(Ref)
328 | Ref_sp = F.vflip(Ref_sp)
329 |
330 | return {'LR': LR_imgs,
331 | 'LR_sr': LR_sr_imgs,
332 | 'HR': GT_imgs,
333 | 'Ref': Ref,
334 | 'Ref_sp': Ref_sp}
335 |
336 | def __len__(self):
337 | return len(self.GT_imgfiles)
338 |
339 | class EvalSet(data.Dataset):
340 |     '''
341 |     Reading the REDS evaluation dataset
342 |     val sequences 000, 001, 006, 017 (held out from training)
343 |     GT: Ground-Truth;
344 |     LQ: Low-Quality, e.g., low-resolution frames
345 |     supports reading N consecutive HR frames (N = args.N_frames)
346 |     '''
347 | def __init__(self, args):
348 | super(EvalSet, self).__init__()
349 | self.args = args
350 |
351 | scale = self.args.scale
352 | if scale == 8:
353 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI_x8')
354 | elif scale == 4:
355 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI')
356 | self.GT_dir_list = sorted([os.path.join(args.dataset_dir, 'val/val/val_sharp', name) for name in
357 | ['000', '001', '006', '017']])
358 | self.LR_dir_list = sorted([os.path.join(LR_root, 'val/val/val_sharp', name) for name in
359 | ['000', '001', '006', '017']])
360 | # self.GT_dir_list = sorted([os.path.join(args.dataset_dir, 'train/train/train_sharp', name) for name in
361 | # ['000', '011', '015', '020']])
362 | # self.LR_dir_list = sorted([os.path.join(LR_root, 'train/train/train_sharp', name) for name in
363 | # ['000', '011', '015', '020']])
364 | N_frames = self.args.N_frames
365 |
366 | self.GT_imgfiles = []
367 | self.LR_imgfiles = []
368 | for idx in range(len(self.GT_dir_list)):
369 | GT_imgfiles_cur = sorted(os.listdir(self.GT_dir_list[idx]))
370 | for img_idx in range(0, len(GT_imgfiles_cur) - N_frames + 1):
371 | self.GT_imgfiles.append([os.path.join(self.GT_dir_list[idx], img_f) for img_f in GT_imgfiles_cur[img_idx:img_idx + N_frames]])
372 | for idx in range(len(self.LR_dir_list)):
373 | LR_imgfiles_cur = sorted(os.listdir(self.LR_dir_list[idx]))
374 | for img_idx in range(0, len(LR_imgfiles_cur) - N_frames + 1):
375 | self.LR_imgfiles.append([os.path.join(self.LR_dir_list[idx], img_f) for img_f in LR_imgfiles_cur[img_idx:img_idx + N_frames]])
376 |
377 | def __getitem__(self, index):
378 | #### Configs
379 | scale = self.args.scale
380 | GT_size = self.args.GT_size
381 | LR_size = GT_size // scale
382 | FV_size = self.args.FV_size
383 |
384 | ### GT
385 | GT_imgfiles = self.GT_imgfiles[index]
386 | GT_imgs = [np.array(PIL.Image.open(img)) for img in GT_imgfiles]
387 |
388 | #### Bicubic downsampling
389 | ### LR and LR_sr
390 | H_, W_, _ = GT_imgs[0].shape
391 | LR_imgfiles = self.LR_imgfiles[index]
392 | LR_imgs = [np.array(PIL.Image.open(img)) for img in LR_imgfiles]
393 | # LR_imgs = [np.array(PIL.Image.fromarray(img).resize((W_ // 8, H_ // 8), PIL.Image.BILINEAR)) for img in GT_imgs]
394 | LR_sr_imgs = [np.array(PIL.Image.fromarray(img).resize((W_, H_), PIL.Image.BICUBIC)) for img in LR_imgs]
395 |
396 | ### Ref(FV) and Ref_sr
397 | Ref, Ref_sp, fv_sp = fovea_generator(GT_imgs, method='Evenscan', FV_HW=(FV_size, FV_size))
398 |
399 | #### Stacking
400 | GT_imgs = np.stack(GT_imgs, axis=0)
401 | LR_imgs = np.stack(LR_imgs, axis=0)
402 | LR_sr_imgs = np.stack(LR_sr_imgs, axis=0)
403 | Ref = np.stack(Ref, axis=0)
404 | Ref_sp = np.stack(Ref_sp, axis=0)
405 |
406 | #### Scaling
407 | GT_imgs = GT_imgs.astype(np.float32) / 255.
408 | LR_imgs = LR_imgs.astype(np.float32) / 255.
409 | LR_sr_imgs = LR_sr_imgs.astype(np.float32) / 255.
410 | Ref = Ref.astype(np.float32) / 255.
411 | Ref_sp = Ref_sp.astype(np.bool_)
412 |
413 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float()
414 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float()
415 | LR_sr_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_sr_imgs, (0, 3, 1, 2)))).float()
416 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float()
417 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2))))
418 |
419 | return {'LR': LR_imgs,
420 | 'LR_sr': LR_sr_imgs,
421 | 'HR': GT_imgs,
422 | 'Ref': Ref,
423 | 'Ref_sp': Ref_sp,
424 | 'FV_sp': fv_sp}
425 |
426 | def __len__(self):
427 | return len(self.GT_imgfiles)
428 |
429 | class TestSet(data.Dataset):
430 |     '''
431 |     Reading the REDS test dataset
432 |     train sequences 000, 011, 015, 020 (held out from training)
433 |     GT: Ground-Truth;
434 |     LQ: Low-Quality, e.g., low-resolution frames
435 |     supports reading N consecutive HR frames (N = args.N_frames)
436 |     '''
437 | def __init__(self, args):
438 | super(TestSet, self).__init__()
439 | self.args = args
440 | scale = self.args.scale
441 | if scale == 8:
442 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI_x8')
443 | elif scale == 4:
444 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI')
445 |
446 | self.GT_dir_list = sorted([os.path.join(args.dataset_dir, 'train/train/train_sharp', name) for name in
447 | ['000', '011', '015', '020']])
448 | self.LR_dir_list = sorted([os.path.join(LR_root, 'train/train/train_sharp', name) for name in
449 | ['000', '011', '015', '020']])
450 |
451 | N_frames = self.args.N_frames
452 |
453 | self.GT_imgfiles = []
454 | self.LR_imgfiles = []
455 | for idx in range(len(self.GT_dir_list)):
456 | GT_imgfiles_cur = sorted(os.listdir(self.GT_dir_list[idx]))
457 | for img_idx in range(0, len(GT_imgfiles_cur) - N_frames + 1):
458 | self.GT_imgfiles.append([os.path.join(self.GT_dir_list[idx], img_f) for img_f in GT_imgfiles_cur[img_idx:img_idx + N_frames]])
459 | for idx in range(len(self.LR_dir_list)):
460 | LR_imgfiles_cur = sorted(os.listdir(self.LR_dir_list[idx]))
461 | for img_idx in range(0, len(LR_imgfiles_cur) - N_frames + 1):
462 | self.LR_imgfiles.append([os.path.join(self.LR_dir_list[idx], img_f) for img_f in LR_imgfiles_cur[img_idx:img_idx + N_frames]])
463 |
464 | def __getitem__(self, index):
465 | #### Configs
466 | scale = self.args.scale
467 | GT_size = self.args.GT_size
468 | LR_size = GT_size // scale
469 | FV_size = self.args.FV_size
470 |
471 | ### GT
472 | GT_imgfiles = self.GT_imgfiles[index]
473 | GT_imgs = [np.array(PIL.Image.open(img)) for img in GT_imgfiles]
474 |
475 | #### Bicubic downsampling
476 | ### LR and LR_sr
477 | H_, W_, _ = GT_imgs[0].shape
478 | LR_imgfiles = self.LR_imgfiles[index]
479 | LR_imgs = [np.array(PIL.Image.open(img)) for img in LR_imgfiles]
480 | LR_sr_imgs = [np.array(PIL.Image.fromarray(img).resize((W_, H_), PIL.Image.BICUBIC)) for img in LR_imgs]
481 |
482 | ### Ref(FV) and Ref_sr
483 | Ref, Ref_sp, fv_sp = fovea_generator(GT_imgs, method='Evenscan', FV_HW=(FV_size, FV_size))
484 |
485 | #### Stacking
486 | GT_imgs = np.stack(GT_imgs, axis=0)
487 | LR_imgs = np.stack(LR_imgs, axis=0)
488 | LR_sr_imgs = np.stack(LR_sr_imgs, axis=0)
489 | Ref = np.stack(Ref, axis=0)
490 | Ref_sp = np.stack(Ref_sp, axis=0)
491 |
492 | #### Scaling
493 | GT_imgs = GT_imgs.astype(np.float32) / 255.
494 | LR_imgs = LR_imgs.astype(np.float32) / 255.
495 | LR_sr_imgs = LR_sr_imgs.astype(np.float32) / 255.
496 | Ref = Ref.astype(np.float32) / 255.
497 | Ref_sp = Ref_sp.astype(np.bool_)
498 |
499 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float()
500 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float()
501 | LR_sr_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_sr_imgs, (0, 3, 1, 2)))).float()
502 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float()
503 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2))))
504 |
505 | return {'LR': LR_imgs,
506 | 'LR_sr': LR_sr_imgs,
507 | 'HR': GT_imgs,
508 | 'Ref': Ref,
509 | 'Ref_sp': Ref_sp,
510 | 'FV_sp': fv_sp}
511 |
512 | def __len__(self):
513 | return len(self.GT_imgfiles)
--------------------------------------------------------------------------------
/dataset/vimeo7.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import pickle
5 | import logging
6 | import numpy as np
7 | import PIL
8 | import pdb
9 | import math
10 |
11 | import torch
12 | import torch.utils.data as data
13 | import torchvision.transforms.functional as F
14 | import torch.nn.functional as nnF
15 | from torchvision.transforms import Compose, ToTensor
16 |
17 | logger = logging.getLogger('base')
18 |
19 | def gaussian_downsample(x, scale=4):
20 |     """Downsampling with the Gaussian kernel used in the DUF official code
21 | Args:
22 | x (Tensor, [C, T, H, W]): frames to be downsampled.
23 | scale (int): downsampling factor: 2 | 3 | 4.
24 | """
25 |
26 | assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale)
27 |
28 | def gkern(kernlen=13, nsig=1.6):
29 | import scipy.ndimage.filters as fi
30 | inp = np.zeros((kernlen, kernlen))
31 | # set element at the middle to one, a dirac delta
32 | inp[kernlen // 2, kernlen // 2] = 1
33 | # gaussian-smooth the dirac, resulting in a gaussian filter mask
34 | return fi.gaussian_filter(inp, nsig)
35 |
36 | if scale == 2:
37 | h = gkern(13, 0.8) # 13 and 0.8 for x2
38 | elif scale == 3:
39 | h = gkern(13, 1.2) # 13 and 1.2 for x3
40 | elif scale == 4:
41 | h = gkern(13, 1.6) # 13 and 1.6 for x4
42 |     else:
43 |         print('Invalid upscaling factor: {} (Must be one of 2, 3, 4)'.format(scale))
44 |         exit(1)
45 |
46 | C, T, H, W = x.size()
47 | x = x.contiguous().view(-1, 1, H, W) # depth convolution (channel-wise convolution)
48 | pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter
49 | r_h, r_w = 0, 0
50 |
51 | if scale == 3:
52 | r_h = 3 - (H % 3)
53 | r_w = 3 - (W % 3)
54 |
55 | x = nnF.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], mode='reflect')
56 |     gaussian_filter = torch.from_numpy(h).type_as(x).unsqueeze(0).unsqueeze(0)  # reuse the kernel selected above
57 | x = nnF.conv2d(x, gaussian_filter, stride=scale)
58 | # please keep the operation same as training.
59 | # if downsample to 32 on training time, use the below code.
60 | x = x[:, :, 2:-2, 2:-2]
61 | # if downsample to 28 on training time, use the below code.
62 | #x = x[:,:,scale:-scale,scale:-scale]
63 |     x = x.view(C, T, x.size(2), x.size(3)).permute(1, 0, 2, 3)
64 | return x
65 |
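# Shape sketch (illustrative): a [C, T, H, W] = [3, 7, 256, 256] clip with
# scale=4 is padded to 284x284, convolved with the 13-tap kernel at stride 4
# (giving 68x68), cropped by 2 px per side, and returned as
# [T, C, H/4, W/4] = [7, 3, 64, 64].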
66 | def fovea_generator(GT_imgs, method='Rscan', step=0.1, FV_HW=(32, 32)):
67 |
68 | len_sp = len(GT_imgs)
69 | if torch.is_tensor(GT_imgs):
70 | _, GT_H, GT_W = GT_imgs[0].size()
71 | else:
72 | GT_H, GT_W, _ = GT_imgs[0].shape
73 |
74 | FV_H, FV_W = FV_HW
75 | if method == 'Cscan' or method == 'Zscan':
76 | SP = 0.1
77 | CP = 0.5
78 | EP = 0.9
79 | else:
80 | SP = 0.1
81 | CP = 0.5
82 | EP = 0.9
83 |
84 | #### shift according to center point
85 | CP_H = (GT_H*CP - FV_H//2)/GT_H
86 | CP_W = (GT_W*CP - FV_W//2)/GT_W
87 | EP_H = (GT_H*EP - FV_H)/GT_H
88 | EP_W = (GT_W*EP - FV_W)/GT_W
89 |
90 | #### finetune step size
91 | if method == 'Cscan' or method == 'Zscan':
92 | if SP + math.ceil(math.sqrt(len_sp)) * step > EP_H or SP + math.ceil(math.sqrt(len_sp)) * step > EP_W:
93 | step = min((EP_H - SP) / math.ceil(math.sqrt(len_sp)), (EP_W - SP) / math.ceil(math.sqrt(len_sp)))
94 | SP = int(SP * 100)
95 | step = int(step * 100)
96 | EP = int(SP + math.ceil(math.sqrt(len_sp) - 1) * step)
97 | elif method == 'Hscan':
98 | if SP + len_sp * step > EP_W:
99 | step = (EP_W - SP) / len_sp
100 | SP = int(SP * 100)
101 | step = int(step * 100)
102 | EP = int((SP + len_sp * step))
103 | elif method == 'Vscan':
104 | if SP + len_sp * step > EP_H:
105 | step = (EP_H - SP) / len_sp
106 | SP = int(SP * 100)
107 | step = int(step * 100)
108 | EP = int((SP + len_sp * step))
109 | else:
110 | if SP + len_sp * step > EP_H or SP + len_sp * step > EP_W:
111 | step = min((EP_H - SP) / len_sp, (EP_W - SP) / len_sp)
112 | SP = int(SP * 100)
113 | step = int(step * 100)
114 | EP = int((SP + len_sp * step))
115 |
116 | #### fovea scan simulation
117 | if method == 'Hscan':
118 | fv_sp = [[int(CP_H * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]]
119 | elif method == 'Vscan':
120 | fv_sp = [[int((v / 100) * GT_H), int(CP_W * GT_W)] for v in [*range(SP, EP, step)]]
121 | elif method == 'DRscan': # Not done
122 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]]
123 | elif method == 'DLscan': # Not done
124 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]]
125 | elif method == 'Cscan':
126 | fv_sp = []
127 | v, h = (SP, SP)
128 | v_step, h_step = (step, step)
129 | for t in range(len_sp):
130 | fv_sp.append([int((v / 100) * GT_H), int((h / 100) * GT_W)])
131 | if h == EP and h_step > 0:
132 | h_step = -h_step
133 | v += v_step
134 | elif h == SP and h_step < 0:
135 | h_step = -h_step
136 | v += v_step
137 | else:
138 | h += h_step
139 | elif method == 'Zscan':
140 | fv_sp = []
141 | v, h = (SP, SP)
142 | v_step, h_step = (step, step)
143 | for t in range(len_sp):
144 | fv_sp.append([int((v / 100) * GT_H), int((h / 100) * GT_W)])
145 | if h == EP and v_step < 0:
146 | v_step = -v_step
147 | v += v_step
148 | h_step = -abs(h_step)
149 | elif v == SP and h_step > 0:
150 | h += h_step
151 | h_step = -h_step
152 | v_step = abs(v_step)
153 | elif v == EP and h_step < 0:
154 | h_step = -h_step
155 | h += h_step
156 | v_step = -abs(v_step)
157 | elif h == SP and v_step > 0:
158 | v += v_step
159 | v_step = -v_step
160 | h_step = abs(h_step)
161 | else:
162 | h += h_step
163 | v += v_step
164 | elif method == 'Rscan':
165 | sigma = 0.05
166 | rand_h = np.random.normal(CP_H, sigma, len_sp).clip(0, EP_H)
167 | rand_w = np.random.normal(CP_W, sigma, len_sp).clip(0, EP_W)
168 | fv_sp = [[int(rh * GT_H), int(rw * GT_W)] for rh, rw in zip(rand_h, rand_w)]
169 | elif method == 'Nanascan': # Not done
170 | SP_H = 0
171 | EP_H = (GT_H - FV_H - 1) / GT_H
172 | Q1_H = (0.25 - (FV_H / GT_H)/2) if (0.25 - (FV_H / GT_H)/2) > 0 else SP_H
173 | Q2_H = (0.50 - (FV_H / GT_H)/2)
174 | Q3_H = (0.75 - (FV_H / GT_H)/2) if (0.75 + (FV_H / GT_H)/2) <= 1 else EP_H
175 | T1_H = (0.33 - (FV_H / GT_H)/2) if (0.33 - (FV_H / GT_H)/2) > 0 else SP_H
176 | T2_H = (0.66 - (FV_H / GT_H)/2) if (0.66 + (FV_H / GT_H)/2) <= 1 else EP_H
177 | SP_W = 0
178 | EP_W = (GT_W - FV_W - 1) / GT_W
179 | Q1_W = (0.25 - (FV_W / GT_W)/2) if (0.25 - (FV_W / GT_W)/2) > 0 else SP_W
180 | Q2_W = (0.50 - (FV_W / GT_W)/2)
181 | Q3_W = (0.75 - (FV_W / GT_W)/2) if (0.75 + (FV_W / GT_W)/2) <= 1 else EP_W
182 | T1_W = (0.33 - (FV_W / GT_W)/2) if (0.33 - (FV_W / GT_W)/2) > 0 else SP_W
183 | T2_W = (0.66 - (FV_W / GT_W)/2) if (0.66 + (FV_W / GT_W)/2) <= 1 else EP_W
184 |
185 | fv_sp = [[Q1_H, T1_W], [Q1_H, T2_W], [Q2_H, Q1_W], [Q2_H, Q2_W], [Q2_H, Q3_W], [Q3_H, T1_W], [Q3_H, T2_W]]
186 | fv_sp = [[int(v[0] * GT_H), int(v[1] * GT_W)] for v in fv_sp]
187 | random.shuffle(fv_sp)
188 | else:
189 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]]
190 |
191 | fv_sp = torch.tensor(fv_sp)
192 | if torch.is_tensor(GT_imgs):
193 | FV_imgs = []
194 | Ref_sps = []
195 | for t in range(len(GT_imgs)):
196 | #### With padding ####
197 | Ref_sp = torch.zeros_like(GT_imgs[t])
198 | Ref_sp[:, fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W] = 1
199 | Ref = GT_imgs[t] * Ref_sp
200 | FV_imgs.append(Ref)
201 | Ref_sps.append(Ref_sp)
202 | #### Without padding ####
203 | # FV_imgs.append(GT_imgs[t][:, fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W].clone())
204 | # FV_imgs = torch.stack(FV_imgs, dim=0)
205 | else:
206 | FV_imgs = []
207 | Ref_sps = []
208 | for t in range(len(GT_imgs)):
209 | #### With padding ####
210 | # Ref_sp = np.zeros_like(GT_imgs[t])
211 | H, W, C = GT_imgs[t].shape
212 | Ref_sp = np.zeros((H, W, 1))
213 | Ref_sp[fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W, :] = 1
214 | Ref = GT_imgs[t] * Ref_sp
215 | FV_imgs.append(Ref)
216 | Ref_sps.append(Ref_sp)
217 | #### Without padding ####
218 | # FV_imgs.append(GT_imgs[t][fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W, :].copy())
219 | # FV_imgs = np.stack(FV_imgs, dim=0)
220 |
221 | return FV_imgs, Ref_sps, fv_sp
222 | # return FV_imgs, fv_sp
223 |
224 | class TrainSet(data.Dataset):
225 |     '''
226 |     Reading the Vimeo training dataset (sep_trainlist.txt)
227 |     key example: sequences/00001/0001/im1.png
228 |     GT: Ground-Truth;
229 |     LQ: Low-Quality, e.g., low-resolution frames
230 |     supports reading N HR frames, N = 3, 5, 7
231 |     '''
232 | def __init__(self, args):
233 | super(TrainSet, self).__init__()
234 | self.args = args
235 |
236 | self.LR_dir_list = []
237 | self.GT_dir_list = []
238 | GT_list = open(os.path.join(args.dataset_dir, 'sep_trainlist.txt'), 'r')
239 | LR_root = args.dataset_dir.replace('90k', '90k_BD')
240 | for line in GT_list.readlines():
241 | self.GT_dir_list.append(os.path.join(args.dataset_dir, 'sequences', line.strip()))
242 | self.LR_dir_list.append(os.path.join(LR_root, 'sequences', line.strip()))
243 |
244 | self.transform = Compose([ToTensor()])
245 |
246 | def __getitem__(self, index):
247 | #### Configs
248 | scale = self.args.scale
249 | GT_size = self.args.GT_size
250 | LR_size = GT_size // scale
251 | FV_size = self.args.FV_size
252 |
253 | #### Bicubic downsampling
254 | ### GT
255 | GT_imgfiles = sorted(os.listdir(self.GT_dir_list[index]))
256 | GT_imgs = [np.array(PIL.Image.open(os.path.join(self.GT_dir_list[index], img))) for img in GT_imgfiles]
257 |
258 | ### LR and LR_sr
259 | H, W, C = GT_imgs[0].shape
260 | LR_imgs = [np.array(PIL.Image.fromarray(img).resize((W // scale, H // scale), PIL.Image.BICUBIC)) for img in GT_imgs]
261 |
262 | #### Random cropping
263 | H, W, C = LR_imgs[0].shape
264 | rnd_h = random.randint(0, max(0, H - LR_size))
265 | rnd_w = random.randint(0, max(0, W - LR_size))
266 | LR_imgs = [v[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :] for v in LR_imgs]
267 |
268 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
269 | GT_imgs = [v[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] for v in GT_imgs]
270 |
271 | ### Ref(FV) and Ref_sr
272 | Ref, Ref_sp, _ = fovea_generator(GT_imgs, method='Nanascan', FV_HW=(FV_size, FV_size))
273 |
274 | #### Stacking
275 | GT_imgs = np.stack(GT_imgs, axis=0)
276 | LR_imgs = np.stack(LR_imgs, axis=0)
277 | Ref = np.stack(Ref, axis=0)
278 | Ref_sp = np.stack(Ref_sp, axis=0)
279 |
280 | GT_imgs = GT_imgs.astype(np.float32) / 255.
281 | LR_imgs = LR_imgs.astype(np.float32) / 255.
282 | Ref = Ref.astype(np.float32) / 255.
283 | Ref_sp = Ref_sp.astype(np.bool_)
284 |
285 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float()
286 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float()
287 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float()
288 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2)))).float()
289 |
290 | if torch.rand(1) < 0.5:
291 | GT_imgs = F.hflip(GT_imgs)
292 | LR_imgs = F.hflip(LR_imgs)
293 | Ref = F.hflip(Ref)
294 | Ref_sp = F.hflip(Ref_sp)
295 |
296 | if torch.rand(1) < 0.5:
297 | GT_imgs = F.vflip(GT_imgs)
298 | LR_imgs = F.vflip(LR_imgs)
299 | Ref = F.vflip(Ref)
300 | Ref_sp = F.vflip(Ref_sp)
301 |
302 | return {'LR': LR_imgs,
303 | 'HR': GT_imgs,
304 | 'Ref': Ref,
305 | 'Ref_sp': Ref_sp}
306 |
307 | def __len__(self):
308 | return len(self.GT_dir_list)
309 |
310 | class EvalSet(data.Dataset):
311 |     '''
312 |     Reading the Vimeo evaluation dataset (sep_testlist.txt)
313 |     key example: sequences/00001/0001/im1.png
314 |     GT: Ground-Truth;
315 |     LQ: Low-Quality, e.g., low-resolution frames
316 |     supports reading N HR frames, N = 3, 5, 7
317 |     '''
318 | def __init__(self, args):
319 | super(EvalSet, self).__init__()
320 | self.args = args
321 |
322 | self.LR_dir_list = []
323 | self.GT_dir_list = []
324 | GT_list = open(os.path.join(args.dataset_dir, 'sep_testlist.txt'), 'r')
325 | LR_root = args.dataset_dir.replace('90k', '90k_BD')
326 | for line in GT_list.readlines():
327 | self.GT_dir_list.append(os.path.join(args.dataset_dir, 'sequences', line.strip()))
328 | self.LR_dir_list.append(os.path.join(LR_root, 'sequences', line.strip()))
329 |
330 | self.transform = Compose([ToTensor()])
331 |
332 | def __getitem__(self, index):
333 | #### Configs
334 | scale = self.args.scale
335 | GT_size = self.args.GT_size
336 | LR_size = GT_size // scale
337 | FV_size = self.args.FV_size
338 |
339 | #### Bicubic downsampling
340 | ### GT
341 | GT_imgfiles = sorted(os.listdir(self.GT_dir_list[index]))
342 | GT_imgs = [np.array(PIL.Image.open(os.path.join(self.GT_dir_list[index], img))) for img in GT_imgfiles]
343 |
344 | ### LR and LR_sr
345 | H, W, C = GT_imgs[0].shape
346 | LR_imgs = [np.array(PIL.Image.fromarray(img).resize((W // scale, H // scale), PIL.Image.BICUBIC)) for img in GT_imgs]
347 |
348 | ### Ref(FV) and Ref_sr
349 | Ref, Ref_sp, _ = fovea_generator(GT_imgs, method='Nanascan', FV_HW=(FV_size, FV_size))
350 |
351 | #### Stacking
352 | GT_imgs = np.stack(GT_imgs, axis=0)
353 | LR_imgs = np.stack(LR_imgs, axis=0)
354 | Ref = np.stack(Ref, axis=0)
355 | Ref_sp = np.stack(Ref_sp, axis=0)
356 |
357 | #### Scaling
358 | GT_imgs = GT_imgs.astype(np.float32) / 255.
359 | LR_imgs = LR_imgs.astype(np.float32) / 255.
360 | Ref = Ref.astype(np.float32) / 255.
361 | Ref_sp = Ref_sp.astype(np.bool_)
362 |
363 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float()
364 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float()
365 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float()
366 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2))))
367 |
368 | return {'LR': LR_imgs,
369 | 'HR': GT_imgs,
370 | 'Ref': Ref,
371 | 'Ref_sp': Ref_sp}
372 |
373 | def __len__(self):
374 | return len(self.GT_dir_list)
375 |
376 | class TestSet(data.Dataset):
377 |     '''
378 |     Reading the Vimeo test dataset (slow_testset.txt)
379 |     key example: sequences/00001/0001/im1.png
380 |     GT: Ground-Truth;
381 |     LQ: Low-Quality, e.g., low-resolution frames
382 |     supports reading N HR frames, N = 3, 5, 7
383 |     '''
384 | def __init__(self, args):
385 | super(TestSet, self).__init__()
386 | self.args = args
387 |
388 | self.LR_dir_list = []
389 | self.GT_dir_list = []
390 | GT_list = open(os.path.join(args.dataset_dir, 'slow_testset.txt'), 'r')
391 | LR_root = args.dataset_dir.replace('90k', '90k_LR')
392 | for line in GT_list.readlines():
393 | self.GT_dir_list.append(os.path.join(args.dataset_dir, 'sequences', line.strip()))
394 | self.LR_dir_list.append(os.path.join(LR_root, 'sequences', line.strip()))
395 |
396 | def __getitem__(self, index):
397 | #### Configs
398 | scale = self.args.scale
399 |
400 | #### Bicubic downsampling
401 | ### GT
402 | GT_imgfiles = sorted(os.listdir(self.GT_dir_list[index]))
403 | GT_imgs = [np.array(PIL.Image.open(os.path.join(self.GT_dir_list[index], img))) for img in GT_imgfiles]
404 | GT_H, GT_W = GT_imgs[0].shape[:2]
405 | LR_H, LR_W = GT_H // scale, GT_W // scale
406 | FV_size = self.args.FV_size
407 | ### LR and LR_sr
408 | LR_imgs = [np.array(PIL.Image.fromarray(img).resize((LR_W, LR_H), PIL.Image.BICUBIC)) for img in GT_imgs]
409 |
410 | ### Ref(FV) and Ref_sr
411 | Ref, Ref_sp, fv_sp = fovea_generator(GT_imgs, method='Hscan', step=0.2, FV_HW=(FV_size, FV_size))
412 |
413 | #### Stacking
414 | GT_imgs = np.stack(GT_imgs, axis=0)
415 | LR_imgs = np.stack(LR_imgs, axis=0)
416 | Ref = np.stack(Ref, axis=0)
417 | Ref_sp = np.stack(Ref_sp, axis=0)
418 |
419 | #### Scaling
420 | GT_imgs = GT_imgs.astype(np.float32) / 255.
421 | LR_imgs = LR_imgs.astype(np.float32) / 255.
422 | Ref = Ref.astype(np.float32) / 255.
423 | Ref_sp = Ref_sp.astype(np.bool_)
424 |
425 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float()
426 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float()
427 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float()
428 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2))))
429 |
430 | return {'LR': LR_imgs,
431 | 'HR': GT_imgs,
432 | 'Ref': Ref,
433 | 'Ref_sp': Ref_sp,
434 | 'FV_sp': fv_sp}
435 |
436 | def __len__(self):
437 | return len(self.GT_dir_list)
--------------------------------------------------------------------------------
/demo/sigma10.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/demo/sigma10.gif
--------------------------------------------------------------------------------
/demo/sigma100.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/demo/sigma100.gif
--------------------------------------------------------------------------------
/demo/sigma50.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/demo/sigma50.gif
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | ### evaluation
2 | python3 main.py --save_dir ./eval/REDS/FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_cra \
3 | --reset True \
4 | --num_gpu 1 \
5 | --gpu_id 0 \
6 | --log_file_name eval.log \
7 | --eval True \
8 | --eval_save_results True \
9 | --num_workers 1 \
10 | --scale 8 \
11 | --cra true \
12 | --mrcf true \
13 | --hr_dcn true \
14 | --offset_prop true \
15 | --N_frames 15 \
16 | --FV_size 96 \
17 | --dataset Reds \
18 | --dataset_dir /DATA/REDS_sharp/ \
19 | --model_path ./train/REDS/FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_cra/model/ \
20 | --visdom_port 8803 \
21 | --visdom_view 0301_FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_cra
--------------------------------------------------------------------------------
/gen_img.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 |
4 | model_code = 13
5 | hr_dcn = False
6 | offset_prop = False
7 | split_ratio = 3
8 | model_name = 'FVSR_x8_simple_v{}_hrdcn_{}_offsetprop_{}_fnet{}'.format(model_code, 'y' if hr_dcn else 'n',
9 | 'y' if offset_prop else 'n',
10 | '_{}outof4'.format(4-split_ratio) if model_code == 18 else '')
11 | # print('Current model name: {}'.format(model_name))
12 |
13 | dir_root = 'test_png'
14 | gt_root = '{}/GroundTruth/'.format(dir_root)
15 | model_names = os.listdir(dir_root)
16 | for model_name in model_names:
17 |     if model_name in ('GroundTruth', 'eval_video', 'results'):
18 | continue
19 | print('Current model name: {}'.format(model_name))
20 | img_root = '{}/{}/'.format(dir_root, model_name)
21 | save_root = '{}/{}/'.format(dir_root, 'results')
22 | if not os.path.exists(save_root):
23 | os.makedirs(save_root)
24 |
25 | #### 000 Past Foveated Region
26 | video_num = 0
27 | img_0_num = 66
28 | img_1_num = 75
29 | img_dir_path = os.path.join(gt_root, str(video_num))
30 | save_path = os.path.join(save_root, model_name, 'pastfv')
31 | if not os.path.exists(save_path):
32 | os.makedirs(save_path)
33 | begin_x_0 = 20
34 | begin_y_0 = 350
35 | begin_x_1 = 550
36 | begin_y_1 = 350
37 | begin_x_2 = 5
38 | begin_y_2 = 210
39 | img_w = 350
40 | img_h = 350
41 | img_w_0 = 80
42 | img_h_0 = 80
43 | hr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num))
44 | hr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num))
45 | hr_img_0 = cv2.imread(hr_img_0)
46 | hr_img_1 = cv2.imread(hr_img_1)
47 | GT_H, GT_W, _ = hr_img_0.shape
48 | cv2.rectangle(hr_img_0, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), (255, 51, 153), 3)
49 | cv2.rectangle(hr_img_1, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), ( 51, 255, 153), 3)
50 | cv2.rectangle(hr_img_1, (begin_x_1, begin_y_1), (begin_x_1+img_w, begin_y_1+img_h), ( 51, 153, 255), 3)
51 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_0_num)), hr_img_0)
52 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_1_num)), hr_img_1)
53 |
54 | dirs = os.listdir(os.path.join(img_root, str(video_num)))
55 | for dir in dirs:
56 | if dir == 'traj.png':
57 | continue
58 | img_dir_path = os.path.join(img_root, str(video_num), dir)
59 | sr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num))
60 | sr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num))
61 | sr_img_0 = cv2.imread(sr_img_0)
62 | sr_img_1 = cv2.imread(sr_img_1)
63 | sr_img_2 = sr_img_1.copy()
64 | if dir == 'results':
65 | sr_img_0 = sr_img_0[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :]
66 | sr_img_1 = sr_img_1[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :]
67 | sr_img_2 = sr_img_2[begin_y_1:begin_y_1+img_h, begin_x_1:begin_x_1+img_w, :]
68 | sr_img_3 = sr_img_1[begin_y_2:begin_y_2+img_h_0, begin_x_2:begin_x_2+img_w_0, :].copy()
69 | sr_img_4 = sr_img_1.copy()
70 | cv2.rectangle(sr_img_4, (begin_x_2, begin_y_2), (begin_x_2+img_w_0, begin_y_2+img_h_0), (255, 51, 153), 3)
71 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_3.png'.format(img_1_num, dir)), sr_img_3)
72 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_1_line.png'.format(img_1_num, dir)), sr_img_4)
73 | else:
74 | H, W, _ = sr_img_0.shape
75 | a = H / GT_H
76 | b = W / GT_W
77 | sr_img_0 = sr_img_0[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :]
78 | sr_img_1 = sr_img_1[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :]
79 | sr_img_2 = sr_img_2[int(begin_y_1*a):int((begin_y_1+img_h)*a), int(begin_x_1*b):int((begin_x_1+img_w)*b), :]
80 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_0_num, dir)), sr_img_0)
81 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_1.png'.format(img_1_num, dir)), sr_img_1)
82 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_2.png'.format(img_1_num, dir)), sr_img_2)
83 |
84 | #### 011 Whole Region
85 | video_num = 11
86 | img_0_num = 30
87 | img_1_num = 36
88 | img_dir_path = os.path.join(gt_root, str(video_num))
89 | save_path = os.path.join(save_root, model_name, 'whole')
90 | if not os.path.exists(save_path):
91 | os.makedirs(save_path)
92 | begin_x_0 = 200
93 | begin_y_0 = 100
94 | begin_x_1 = 500
95 | begin_y_1 = 100
96 | begin_x_2 = 360
97 | begin_y_2 = 100
98 | begin_x_3 = 60
99 | begin_y_3 = 90
100 | begin_x_4 = 60
101 | begin_y_4 = 40
102 | begin_x_5 = 200
103 | begin_y_5 = 60
104 | img_w = 750
105 | img_h = 450
106 | img_w_ = 360
107 | img_h_ = 216
108 | img_w__ = 100
109 | img_h__ = 120
110 | hr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num))
111 | hr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num))
112 | hr_img_0 = cv2.imread(hr_img_0)
113 | hr_img_1 = cv2.imread(hr_img_1)
114 | hr_img_2 = hr_img_0.copy()
115 | hr_img_3 = hr_img_1.copy()
116 | GT_H, GT_W, _ = hr_img_0.shape
117 | cv2.rectangle(hr_img_0, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), (255, 51, 153), 3)
118 | cv2.rectangle(hr_img_1, (begin_x_1, begin_y_1), (begin_x_1+img_w, begin_y_1+img_h), ( 51, 153, 255), 3)
119 | cv2.rectangle(hr_img_2, (begin_x_2, begin_y_2), (begin_x_2+img_w, begin_y_2+img_h), (255, 51, 153), 3)
120 | cv2.rectangle(hr_img_3, (begin_x_2, begin_y_2), (begin_x_2+img_w, begin_y_2+img_h), ( 51, 153, 255), 3)
121 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_0_num)), hr_img_0)
122 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_1_num)), hr_img_1)
123 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line_2.png'.format(img_0_num)), hr_img_2)
124 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line_2.png'.format(img_1_num)), hr_img_3)
125 |
126 | dirs = os.listdir(os.path.join(img_root, str(video_num)))
127 | for dir in dirs:
128 | if dir == 'traj.png':
129 | continue
130 | img_dir_path = os.path.join(img_root, str(video_num), dir)
131 | sr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num))
132 | sr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num))
133 | sr_img_0 = cv2.imread(sr_img_0)
134 | sr_img_1 = cv2.imread(sr_img_1)
135 | sr_img_2 = sr_img_0.copy()
136 | sr_img_3 = sr_img_1.copy()
137 | if dir == 'results':
138 | sr_img_0 = sr_img_0[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :]
139 | sr_img_1 = sr_img_1[begin_y_1:begin_y_1+img_h, begin_x_1:begin_x_1+img_w, :]
140 | sr_img_2 = sr_img_2[begin_y_2:begin_y_2+img_h, begin_x_2:begin_x_2+img_w, :]
141 | sr_img_3 = sr_img_3[begin_y_2:begin_y_2+img_h, begin_x_2:begin_x_2+img_w, :]
142 | cv2.rectangle(sr_img_2, (begin_x_3, begin_y_3), (begin_x_3+img_w_, begin_y_3+img_h_), ( 51, 153, 255), 3)
143 | cv2.rectangle(sr_img_3, (begin_x_3, begin_y_3), (begin_x_3+img_w_, begin_y_3+img_h_), ( 51, 153, 255), 3)
144 | sr_img_4 = sr_img_2[begin_y_3:begin_y_3+img_h_, begin_x_3:begin_x_3+img_w_, :]
145 | sr_img_5 = sr_img_3[begin_y_3:begin_y_3+img_h_, begin_x_3:begin_x_3+img_w_, :]
146 | sr_img_6 = sr_img_5[begin_y_4:begin_y_4+img_h__, begin_x_4:begin_x_4+img_w__, :].copy()
147 | sr_img_7 = sr_img_5[begin_y_5:begin_y_5+img_h__, begin_x_5:begin_x_5+img_w__, :].copy()
148 | cv2.rectangle(sr_img_5, (begin_x_4, begin_y_4), (begin_x_4+img_w__, begin_y_4+img_h__), ( 51, 153, 255), 3)
149 | cv2.rectangle(sr_img_5, (begin_x_5, begin_y_5), (begin_x_5+img_w__, begin_y_5+img_h__), ( 51, 153, 255), 3)
150 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_3.png'.format(img_0_num, dir)), sr_img_4)
151 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_3.png'.format(img_1_num, dir)), sr_img_5)
152 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_4.png'.format(img_1_num, dir)), sr_img_6)
153 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_5.png'.format(img_1_num, dir)), sr_img_7)
154 | else:
155 | H, W, _ = sr_img_0.shape
156 | a = H / GT_H
157 | b = W / GT_W
158 | sr_img_0 = sr_img_0[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :]
159 | sr_img_1 = sr_img_1[int(begin_y_1*a):int((begin_y_1+img_h)*a), int(begin_x_1*b):int((begin_x_1+img_w)*b), :]
160 | sr_img_2 = sr_img_2[int(begin_y_2*a):int((begin_y_2+img_h)*a), int(begin_x_2*b):int((begin_x_2+img_w)*b), :]
161 | sr_img_3 = sr_img_3[int(begin_y_2*a):int((begin_y_2+img_h)*a), int(begin_x_2*b):int((begin_x_2+img_w)*b), :]
162 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_0_num, dir)), sr_img_0)
163 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_1_num, dir)), sr_img_1)
164 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_2.png'.format(img_0_num, dir)), sr_img_2)
165 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_2.png'.format(img_1_num, dir)), sr_img_3)
166 |
167 | #### 015 Title
168 | video_num = 15
169 | img_0_num = 31
170 | img_1_num = 36
171 | img_2_num = 43
172 | img_dir_path = os.path.join(gt_root, str(video_num))
173 | save_path = os.path.join(save_root, model_name, 'title')
174 | if not os.path.exists(save_path):
175 | os.makedirs(save_path)
176 | begin_x_0 = 350
177 | begin_y_0 = 140
178 | begin_x_1 = 700
179 | begin_y_1 = 140
180 | img_w = 500
181 | img_h = 300
182 | hr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num))
183 | hr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num))
184 | hr_img_2 = os.path.join(img_dir_path, '{:03d}.png'.format(img_2_num))
185 | hr_img_0 = cv2.imread(hr_img_0)
186 | hr_img_1 = cv2.imread(hr_img_1)
187 | hr_img_2 = cv2.imread(hr_img_2)
188 | GT_H, GT_W, _ = hr_img_0.shape
189 | cv2.rectangle(hr_img_0, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), ( 255, 51, 153), 3)
190 | cv2.rectangle(hr_img_1, (begin_x_1, begin_y_1), (begin_x_1+img_w, begin_y_1+img_h), ( 51, 153, 255), 3)
191 | cv2.rectangle(hr_img_2, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), ( 51, 153, 255), 3)
192 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_0_num)), hr_img_0)
193 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_1_num)), hr_img_1)
194 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_2_num)), hr_img_2)
195 |
196 | dirs = os.listdir(os.path.join(img_root, str(video_num)))
197 | for dir in dirs:
198 | if dir == 'traj.png':
199 | continue
200 | img_dir_path = os.path.join(img_root, str(video_num), dir)
201 | sr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num))
202 | sr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num))
203 | sr_img_2 = os.path.join(img_dir_path, '{:03d}.png'.format(img_2_num))
204 | sr_img_0 = cv2.imread(sr_img_0)
205 | sr_img_1 = cv2.imread(sr_img_1)
206 | sr_img_2 = cv2.imread(sr_img_2)
207 | if dir == 'results':
208 | sr_img_0 = sr_img_0[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :]
209 | sr_img_1 = sr_img_1[begin_y_1:begin_y_1+img_h, begin_x_1:begin_x_1+img_w, :]
210 | sr_img_2 = sr_img_2[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :]
211 | else:
212 | H, W, _ = sr_img_0.shape
213 | a = H / GT_H
214 | b = W / GT_W
215 | sr_img_0 = sr_img_0[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :]
216 | sr_img_1 = sr_img_1[int(begin_y_1*a):int((begin_y_1+img_h)*a), int(begin_x_1*b):int((begin_x_1+img_w)*b), :]
217 | sr_img_2 = sr_img_2[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :]
218 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_0_num, dir)), sr_img_0)
219 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_1_num, dir)), sr_img_1)
220 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_2_num, dir)), sr_img_2)
221 |
222 | #### 020 Gaussian
223 | video_num = 20
224 | img_0_num = 99
225 | img_dir_path = os.path.join(img_root, str(video_num), 'results')
226 | save_path = os.path.join(save_root, model_name, 'title')
227 | if not os.path.exists(save_path):
228 | os.makedirs(save_path)
229 | begin_x_0 = 320
230 | begin_y_0 = 180
231 | img_w = 640
232 | img_h = 360
233 | hr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num))
234 | hr_img_0 = cv2.imread(hr_img_0)
235 | GT_H, GT_W, _ = hr_img_0.shape
236 | cv2.rectangle(hr_img_0, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), ( 255, 51, 153), 3)
237 | cv2.imwrite(os.path.join(save_path, '{:03d}_sr_line.png'.format(img_0_num)), hr_img_0)
238 |
239 | dirs = os.listdir(os.path.join(img_root, str(video_num)))
240 | for dir in dirs:
241 | if dir == 'traj.png':
242 | continue
243 | img_dir_path = os.path.join(img_root, str(video_num), dir)
244 | sr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num))
245 | sr_img_0 = cv2.imread(sr_img_0)
246 | if dir == 'results':
247 | sr_img_0 = sr_img_0[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :]
248 | else:
249 | H, W, _ = sr_img_0.shape
250 | a = H / GT_H
251 | b = W / GT_W
252 | sr_img_0 = sr_img_0[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :]
253 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_0_num, dir)), sr_img_0)
--------------------------------------------------------------------------------
/gen_video.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import cv2
6 |
7 | if __name__ == '__main__':
8 | # fourcc = cv2.VideoWriter_fourcc(*'MP4V')
9 | # out = cv2.VideoWriter('test_video_arcane.mp4', fourcc, 20.0, (1920, 1080))
10 | save_dir = 'old_tree_x1'
11 | if not os.path.exists(save_dir):
12 | os.mkdir(save_dir)
13 | cap = cv2.VideoCapture('old_tree.mp4')
14 | gen_frames = []
15 | n = 0
16 | scale = 8
17 | while cap.isOpened():
18 | ret, frame = cap.read()
19 |         if not ret:
20 | break
21 | H, W, C = frame.shape
22 | n += 1
23 | # gen_frames.append(frame)
24 | # if H != 1080 or W != 1920:
25 | # frame = cv2.resize(frame, (1920, 1080), interpolation=cv2.INTER_CUBIC)
26 | frame = frame[(H - 1080)//2:(H + 1080)//2, (W - 1920)//2:(W + 1920)//2, :]
27 | # frame = cv2.resize(frame, (W//scale, H//scale), interpolation=cv2.INTER_CUBIC)
28 | # cv2.putText(frame, str(n), (10, 120), cv2.FONT_HERSHEY_DUPLEX, 1, (0, 255, 255), 1, cv2.LINE_AA)
29 | # cv2.imshow('frame', frame)
30 | # if n <= 800:
31 | # continue
32 | # if n > 1200:
33 | # break
34 | # if cv2.waitKey(1) == ord('q'):
35 | # break
36 | print('{}\r'.format(n), end='')
37 | cv2.imwrite(os.path.join(save_dir, '{:08d}.png'.format(n)), frame)
38 |
39 | cap.release()
40 | cv2.destroyAllWindows()
41 |
42 | # print(n)
43 | # with get_writer('fungtsun_confuse.gif', mode="I", fps=20) as writer:
44 | # for i in range(n):
45 | # if i > 4:
46 | # out.write(gen_frames[i])
47 | # writer.append_data(gen_frames[i][:,:,::-1])
48 |
--------------------------------------------------------------------------------
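Line 26 above center-crops each decoded frame to 1920x1080 instead of resizing, which leaves the kept pixels untouched. An equivalent standalone helper (hypothetical, not in the repo), assuming frames at least as large as the crop:

```python
# Hypothetical helper equivalent to gen_video.py's center crop:
# frame[(H - 1080)//2:(H + 1080)//2, (W - 1920)//2:(W + 1920)//2, :]
import numpy as np

def center_crop(frame, crop_h=1080, crop_w=1920):
    H, W, _ = frame.shape
    top, left = (H - crop_h) // 2, (W - crop_w) // 2
    return frame[top:top + crop_h, left:left + crop_w, :]

print(center_crop(np.zeros((1440, 2560, 3), np.uint8)).shape)  # (1080, 1920, 3)
```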
/gif_combine.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image,ImageSequence
3 | import numpy as np
4 | import cv2
5 | from imageio import imread, imsave, get_writer
6 |
7 | im_1 = Image.open('DCN_3and1.gif')
8 | im_2 = Image.open('DCN_4.gif')
9 |
10 | def iter_frames(im):
11 | try:
12 |         i = 0
13 | while 1:
14 | im.seek(i)
15 | imframe = im.copy()
16 | if i == 0:
17 | palette = imframe.getpalette()
18 | else:
19 | imframe.putpalette(palette)
20 | yield imframe
21 | i += 1
22 | except EOFError:
23 | pass
24 |
25 | gif_1 = []
26 | gif_2 = []
27 | n_frames = 0
28 | for i,frame in enumerate(ImageSequence.Iterator(im_1),1):
29 | frame = frame.convert('RGB')
30 | # frame.save(os.path.join('test.png'))
31 | # gif_1.append(np.array(Image.open('test.png')))
32 | gif_1.append(np.array(frame))
33 | n_frames += 1
34 | # cv2.imshow('image', gif_1[-1][:,:,::-1])
35 | # cv2.waitKey(10)
36 |
37 | for i,frame in enumerate(ImageSequence.Iterator(im_2),1):
38 | frame = frame.convert('RGB')
39 |     # frame.save(os.path.join('test.png'))
40 | # gif_2.append(np.array(Image.open('test.png')))
41 | gif_2.append(np.array(frame))
42 | # cv2.imshow('image', gif_2[-1][:,:,::-1])
43 | # cv2.waitKey(10)
44 |
45 | with get_writer('test_1.gif', mode="I", fps=7) as writer:
46 | for n in range(n_frames - 10):
47 | writer.append_data(np.concatenate((gif_2[n][260:620,270:590,:],gif_1[n][260:620,270:590,:]), axis=1))
48 | cv2.imshow('image', np.concatenate((gif_2[n][260:620,270:590,:],gif_1[n][260:620,270:590,:]), axis=1))
49 | cv2.waitKey(100)
50 |
51 | # with get_writer('test_1.gif', mode="I", fps=7) as writer:
52 | # for n in range(n_frames):
53 | # writer.append_data(gif_1[n][260:620,220:640,:])
54 |
--------------------------------------------------------------------------------
/loss/loss.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 |
7 | def reduce_loss(loss, reduction):
8 | """Reduce loss as specified.
9 |
10 | Args:
11 | loss (Tensor): Elementwise loss tensor.
12 | reduction (str): Options are "none", "mean" and "sum".
13 |
14 | Returns:
15 | Tensor: Reduced loss tensor.
16 | """
17 | reduction_enum = F._Reduction.get_enum(reduction)
18 | # none: 0, elementwise_mean:1, sum: 2
19 | if reduction_enum == 0:
20 | return loss
21 | if reduction_enum == 1:
22 | return loss.mean()
23 |
24 | return loss.sum()
25 |
26 | def mask_reduce_loss(loss, weight=None, reduction='mean', sample_wise=False):
27 | """Apply element-wise weight and reduce loss.
28 |
29 | Args:
30 | loss (Tensor): Element-wise loss.
31 | weight (Tensor): Element-wise weights. Default: None.
32 | reduction (str): Same as built-in losses of PyTorch. Options are
33 | "none", "mean" and "sum". Default: 'mean'.
34 |         sample_wise (bool): Whether to calculate the loss sample-wise. This
35 |             argument only takes effect when `reduction` is 'mean' and `weight`
36 |             (argument of `forward()`) is not None. It first reduces the loss
37 |             per sample with 'mean', then averages over all samples.
38 | Default: False.
39 |
40 | Returns:
41 | Tensor: Processed loss values.
42 | """
43 | # if weight is specified, apply element-wise weight
44 | if weight is not None:
45 | assert weight.dim() == loss.dim()
46 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
47 | loss = loss * weight
48 |
49 | # if weight is not specified or reduction is sum, just reduce the loss
50 | if weight is None or reduction == 'sum':
51 | loss = reduce_loss(loss, reduction)
52 | # if reduction is mean, then compute mean over masked region
53 | elif reduction == 'mean':
54 | # expand weight from N1HW to NCHW
55 | if weight.size(1) == 1:
56 | weight = weight.expand_as(loss)
57 | # small value to prevent division by zero
58 | eps = 1e-12
59 |
60 | # perform sample-wise mean
61 | if sample_wise:
62 | weight = weight.sum(dim=[1, 2, 3], keepdim=True) # NCHW to N111
63 | loss = (loss / (weight + eps)).sum() / weight.size(0)
64 | # perform pixel-wise mean
65 | else:
66 | loss = loss.sum() / (weight.sum() + eps)
67 |
68 | return loss
69 |
70 | def masked_loss(loss_func):
71 | """Create a masked version of a given loss function.
72 |
73 | To use this decorator, the loss function must have the signature like
74 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
75 | element-wise loss without any reduction. This decorator will add weight
76 | and reduction arguments to the function. The decorated function will have
77 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
78 | avg_factor=None, **kwargs)`.
79 |
80 | :Example:
81 |
82 | >>> import torch
83 | >>> @masked_loss
84 | >>> def l1_loss(pred, target):
85 | >>> return (pred - target).abs()
86 |
87 | >>> pred = torch.Tensor([0, 2, 3])
88 | >>> target = torch.Tensor([1, 1, 1])
89 | >>> weight = torch.Tensor([1, 0, 1])
90 |
91 | >>> l1_loss(pred, target)
92 | tensor(1.3333)
93 | >>> l1_loss(pred, target, weight)
94 | tensor(1.5000)
95 | >>> l1_loss(pred, target, reduction='none')
96 | tensor([1., 1., 2.])
97 | >>> l1_loss(pred, target, weight, reduction='sum')
98 | tensor(3.)
99 | """
100 |
101 | @functools.wraps(loss_func)
102 | def wrapper(pred,
103 | target,
104 | weight=None,
105 | reduction='mean',
106 | sample_wise=False,
107 | **kwargs):
108 | # get element-wise loss
109 | loss = loss_func(pred, target, **kwargs)
110 | loss = mask_reduce_loss(loss, weight, reduction, sample_wise)
111 | return loss
112 |
113 | return wrapper
114 |
115 | @masked_loss
116 | def charbonnier_loss(pred, target, eps=1e-12):
117 | """Charbonnier loss.
118 | Args:
119 | pred (Tensor): Prediction Tensor with shape (n, c, h, w).
120 |         target (Tensor): Target Tensor with shape (n, c, h, w).
121 | Returns:
122 | Tensor: Calculated Charbonnier loss.
123 | """
124 | return torch.sqrt((pred - target)**2 + eps)
125 |
126 | class CharbonnierLoss(nn.Module):
127 | """Charbonnier loss (one variant of Robust L1Loss, a differentiable
128 | variant of L1Loss).
129 | Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
130 | Super-Resolution".
131 | Args:
132 | loss_weight (float): Loss weight for L1 loss. Default: 1.0.
133 | reduction (str): Specifies the reduction to apply to the output.
134 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
135 |         sample_wise (bool): Whether to calculate the loss sample-wise. This
136 |             argument only takes effect when `reduction` is 'mean' and `weight`
137 |             (argument of `forward()`) is not None. It first reduces the loss
138 |             per sample with 'mean', then averages over all samples.
139 | Default: False.
140 | eps (float): A value used to control the curvature near zero.
141 | Default: 1e-12.
142 | """
143 |
144 | def __init__(self,
145 | loss_weight=1.0,
146 | reduction='mean',
147 | sample_wise=False,
148 | eps=1e-12):
149 | super().__init__()
150 | if reduction not in ['none', 'mean', 'sum']:
151 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
152 |                              "Supported ones are: 'none', 'mean' and 'sum'")
153 |
154 | self.loss_weight = loss_weight
155 | self.reduction = reduction
156 | self.sample_wise = sample_wise
157 | self.eps = eps
158 |
159 | def forward(self, pred, target, weight=None, **kwargs):
160 | """Forward Function.
161 | Args:
162 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
163 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
164 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
165 | weights. Default: None.
166 | """
167 | return self.loss_weight * charbonnier_loss(
168 | pred,
169 | target,
170 | weight,
171 | eps=self.eps,
172 | reduction=self.reduction,
173 | sample_wise=self.sample_wise)
174 |
175 |
176 |
177 | def get_loss_dict(args, logger):
178 | loss = {}
179 |     if (abs(args.rec_w) <= 1e-8):
180 |         raise SystemExit('NotImplementedError: ReconstructionLoss must exist!')
181 | else:
182 | loss['cb_loss'] = CharbonnierLoss()
183 |
184 | return loss
--------------------------------------------------------------------------------
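For reference, a minimal usage sketch of `CharbonnierLoss` with and without a spatial weight mask, run from the repo root; the tensors are random placeholders, not repo data:

```python
# Hypothetical usage sketch for loss/loss.py (not part of the repo).
import torch
from loss.loss import CharbonnierLoss

criterion = CharbonnierLoss(loss_weight=1.0, reduction='mean')

pred = torch.rand(2, 3, 64, 64)    # (N, C, H, W) prediction
target = torch.rand(2, 3, 64, 64)  # (N, C, H, W) ground truth

# Unmasked: mean Charbonnier distance over all elements.
loss_full = criterion(pred, target)

# Masked: only a 32x32 region contributes; mask_reduce_loss averages
# over the masked pixels rather than the whole image.
weight = torch.zeros(2, 1, 64, 64)
weight[:, :, 16:48, 16:48] = 1.0
loss_masked = criterion(pred, target, weight=weight)

print(loss_full.item(), loss_masked.item())
```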
/main.py:
--------------------------------------------------------------------------------
1 | #### take reference from
2 |
3 | from option import args
4 | from utils import mkExpDir
5 | from dataset import dataloader
6 | from model import CRFP
7 | from loss.loss import get_loss_dict
8 | from trainer import Trainer
9 |
10 | import math
11 | import os
12 | import time
13 | import torch
14 | import torch.nn as nn
15 | import warnings
16 | from tqdm import tqdm
17 | warnings.filterwarnings('ignore')
18 |
19 | if __name__ == '__main__':
20 |
21 | ### make save_dir
22 | _logger = mkExpDir(args)
23 |
24 | ### device and model
25 | if args.num_gpu == 1:
26 | device = torch.device('cpu') if args.cpu else torch.device('cuda:{}'.format(args.gpu_id))
27 | else:
28 | device = torch.device('cpu') if args.cpu else torch.device('cuda')
29 |
30 | # _model = CRFP.BasicFVSR(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
31 | # _model = CRFP.CRFP_simple_noDCN(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
32 | # _model = CRFP.CRFP_simple(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
33 | # _model = CRFP.CRFP(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
34 | _model = CRFP.CRFP_DSV(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
35 | # _model = CRFP.CRFP_DSV_CRA(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
36 |
37 | if ((not args.cpu) and (args.num_gpu > 1)):
38 | _model = nn.DataParallel(_model, list(range(args.num_gpu)))
39 |
40 | ### dataloader of training set and testing set
41 | _dataloader = dataloader.get_dataloader(args)
42 |
43 | ### loss
44 | _loss_all = get_loss_dict(args, _logger)
45 |
46 | ### trainer
47 | t = Trainer(args, _logger, _dataloader, _model, _loss_all)
48 | t.before_run()
49 | ### test / eval / train
50 | if (args.test):
51 | t.load(model_path=args.model_path)
52 | t.test_basicvsr()
53 | elif (args.eval):
54 | model_list = sorted(os.listdir(args.model_path))
55 | # model_list = model_list[::-1]
56 | for idx, m in enumerate(model_list):
57 | t.load(model_path=os.path.join(args.model_path, m))
58 | t.eval_basicvsr(idx)
59 | t.vis_plot_metric('eval')
60 | else:
61 | # t.load(model_path=args.model_path)
62 | for epoch in range(1, args.num_epochs+1):
63 | t.train_basicvsr(current_epoch=epoch)
64 | t.vis_plot_metric('train')
65 | # if (epoch % args.val_every == 0):
66 | # t.eval_basicvsr(current_epoch=epoch)
67 | # t.vis_plot_metric('eval')
68 | torch.cuda.empty_cache()
--------------------------------------------------------------------------------
/model/LTE.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | def pixel_unshuffle(input, downscale_factor):
6 |     '''
7 |     input: (batchSize, c, k*h, k*w)
8 |     downscale_factor: k
9 |     (batchSize, c, k*h, k*w) -> (batchSize, k*k*c, h, w)
10 |     '''
11 | c = input.shape[1]
12 |
13 | kernel = torch.zeros(size=[downscale_factor * downscale_factor * c,
14 | 1, downscale_factor, downscale_factor],
15 | device=input.device)
16 | for y in range(downscale_factor):
17 | for x in range(downscale_factor):
18 | kernel[x + y * downscale_factor::downscale_factor*downscale_factor, 0, y, x] = 1
19 | return F.conv2d(input, kernel, stride=downscale_factor, groups=c)
20 |
21 | class PixelUnshuffle(nn.Module):
22 | def __init__(self, downscale_factor):
23 | super(PixelUnshuffle, self).__init__()
24 | self.downscale_factor = downscale_factor
25 | def forward(self, input):
26 |         '''
27 |         input: (batchSize, c, k*h, k*w)
28 |         downscale_factor: k
29 |         (batchSize, c, k*h, k*w) -> (batchSize, k*k*c, h, w)
30 |         '''
31 |
32 | return pixel_unshuffle(input, self.downscale_factor)
33 |
34 | class LTE_simple_lr(torch.nn.Module):
35 | def __init__(self, mid_channels):
36 | super(LTE_simple_lr, self).__init__()
37 |
38 |         ### shallow conv feature extractor
39 | self.slice1 = torch.nn.Sequential(
40 | nn.Conv2d(3, mid_channels, 3, 1, 1),
41 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
42 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1),
43 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
44 | )
45 |
46 | def forward(self, x, islr=False):
47 | x = self.slice1(x)
48 | # x_lv3 = x
49 | # x_lv2 = x
50 | # x_lv1 = x
51 | return None, None, x
52 |
53 | class LTE_simple_hr(torch.nn.Module):
54 | def __init__(self, mid_channels):
55 | super(LTE_simple_hr, self).__init__()
56 |
57 |         ### shallow conv feature extractor
58 | self.slice1 = torch.nn.Sequential(
59 | nn.Conv2d(6, mid_channels, 3, 1, 1),
60 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
61 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1),
62 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
63 | )
64 | self.slice2 = torch.nn.Sequential(
65 | nn.MaxPool2d(2, 2),
66 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1),
67 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
68 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1),
69 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
70 | )
71 | self.slice3 = torch.nn.Sequential(
72 | nn.MaxPool2d(2, 2),
73 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1),
74 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
75 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1),
76 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
77 | )
78 |
79 | self.conv_lv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
80 | self.conv_lv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
81 | self.conv_lv3 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
82 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
83 |
84 | def forward(self, x, islr=False):
85 | if islr:
86 | x = self.slice1(x)
87 | x_lv3 = self.lrelu(self.conv_lv3(x))
88 | x = self.slice2(x)
89 | x_lv2 = self.lrelu(self.conv_lv2(x))
90 | x = self.slice3(x)
91 | x_lv1 = self.lrelu(self.conv_lv1(x))
92 | else:
93 | x_lv3 = x
94 | x = self.slice2(x)
95 | x_lv2 = x
96 | x = self.slice3(x)
97 | x_lv1 = x
98 | return x_lv1, x_lv2, x_lv3
99 |
100 | class LTE_simple_hr_single(torch.nn.Module):
101 | def __init__(self, mid_channels):
102 | super(LTE_simple_hr_single, self).__init__()
103 |
104 |         ### shallow conv feature extractor
105 | self.slice1 = torch.nn.Sequential(
106 | nn.Conv2d(6, mid_channels, 3, 1, 1),
107 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
108 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1),
109 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
110 | )
111 |
112 | def forward(self, x, islr=False):
113 | x = self.slice1(x)
114 | # x_lv3 = x
115 | # x_lv2 = x
116 | # x_lv1 = x
117 | return None, None, x
118 |
119 | class LTE_simple_hr_ps(torch.nn.Module):
120 | def __init__(self, mid_channels):
121 | super(LTE_simple_hr_ps, self).__init__()
122 |
123 |         ### shallow conv feature extractor
124 | self.slice1 = torch.nn.Sequential(
125 | nn.Conv2d(6, mid_channels, 3, 1, 1),
126 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
127 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1),
128 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
129 | )
130 | self.slice2 = torch.nn.Sequential(
131 | PixelUnshuffle(4),
132 | nn.Conv2d(mid_channels*16, mid_channels*4, 3, 1, 1),
133 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
134 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1),
135 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
136 | )
137 | self.slice3 = torch.nn.Sequential(
138 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1),
139 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
140 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1),
141 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
142 | )
143 | self.slice4 = torch.nn.Sequential(
144 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1),
145 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
146 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1),
147 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
148 | )
149 |
150 | self.conv_lv0 = nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1)
151 | self.conv_lv1 = nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1)
152 | self.conv_lv2 = nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1)
153 | self.conv_lv3 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
154 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
155 |
156 | def forward(self, x):
157 | x = self.slice1(x)
158 | x_lv3 = self.lrelu(self.conv_lv3(x))
159 | x = self.slice2(x)
160 | x_lv2 = self.lrelu(self.conv_lv2(x))
161 | x = self.slice3(x)
162 | x_lv1 = self.lrelu(self.conv_lv1(x))
163 | x = self.slice4(x)
164 | x_lv0 = self.lrelu(self.conv_lv0(x))
165 |
166 | return x_lv0, x_lv1, x_lv2, x_lv3
167 |
168 | class LTE_simple_hr_v1(torch.nn.Module):
169 | def __init__(self, mid_channels):
170 | super(LTE_simple_hr_v1, self).__init__()
171 |
172 |         ### shallow conv feature extractor
173 | self.slice1 = torch.nn.Sequential(
174 | nn.Conv2d(6, mid_channels//4, 3, 1, 1),
175 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
176 | nn.Conv2d(mid_channels//4, mid_channels//4, 3, 1, 1),
177 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
178 | )
179 | self.slice2 = torch.nn.Sequential(
180 | nn.MaxPool2d(2, 2),
181 | nn.Conv2d(mid_channels//4, mid_channels//2, 3, 1, 1),
182 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
183 | nn.Conv2d(mid_channels//2, mid_channels//2, 3, 1, 1),
184 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
185 | )
186 | self.slice3 = torch.nn.Sequential(
187 | nn.MaxPool2d(2, 2),
188 | nn.Conv2d(mid_channels//2, mid_channels//1, 3, 1, 1),
189 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
190 | nn.Conv2d(mid_channels//1, mid_channels//1, 3, 1, 1),
191 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
192 | )
193 |
194 | self.conv_lv3 = nn.Conv2d(mid_channels//4, mid_channels//4, 3, 1, 1)
195 | self.conv_lv2 = nn.Conv2d(mid_channels//2, mid_channels//2, 3, 1, 1)
196 | self.conv_lv1 = nn.Conv2d(mid_channels//1, mid_channels//1, 3, 1, 1)
197 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
198 |
199 | def forward(self, x, islr=False):
200 | if islr:
201 | x = self.slice1(x)
202 | x_lv3 = self.lrelu(self.conv_lv3(x))
203 | x = self.slice2(x)
204 | x_lv2 = self.lrelu(self.conv_lv2(x))
205 | x = self.slice3(x)
206 | x_lv1 = self.lrelu(self.conv_lv1(x))
207 | else:
208 | x_lv3 = x
209 | x = self.slice2(x)
210 | x_lv2 = x
211 | x = self.slice3(x)
212 | x_lv1 = x
213 | return x_lv1, x_lv2, x_lv3
214 |
215 | class LTE_simple_hr_x8(torch.nn.Module):
216 | def __init__(self, mid_channels):
217 | super(LTE_simple_hr_x8, self).__init__()
218 |
219 |         ### shallow conv feature extractor
220 | self.slice1 = torch.nn.Sequential(
221 | nn.Conv2d(6, 64, 3, 1, 1),
222 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
223 | nn.Conv2d(64, 64, 3, 1, 1),
224 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
225 | )
226 | self.slice2 = torch.nn.Sequential(
227 | nn.MaxPool2d(2, 2),
228 | nn.Conv2d(64, 64, 3, 1, 1),
229 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
230 | nn.Conv2d(64, 64, 3, 1, 1),
231 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
232 | )
233 | self.slice3 = torch.nn.Sequential(
234 | nn.MaxPool2d(2, 2),
235 | nn.Conv2d(64, 64, 3, 1, 1),
236 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
237 | nn.Conv2d(64, 64, 3, 1, 1),
238 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
239 | )
240 | self.slice4 = torch.nn.Sequential(
241 | nn.MaxPool2d(2, 2),
242 | nn.Conv2d(64, 64, 3, 1, 1),
243 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
244 | nn.Conv2d(64, 64, 3, 1, 1),
245 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
246 | )
247 |
248 | self.conv_lv0 = nn.Conv2d(64, 64, 3, 1, 1)
249 | self.conv_lv1 = nn.Conv2d(64, 64, 3, 1, 1)
250 | self.conv_lv2 = nn.Conv2d(64, 64, 3, 1, 1)
251 | self.conv_lv3 = nn.Conv2d(64, 64, 3, 1, 1)
252 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
253 | # self.maxpool = nn.MaxPool2d(2, 2)
254 |
255 | def forward(self, x, islr=False):
256 | if islr:
257 | x = self.slice1(x)
258 | x_lv3 = self.lrelu(self.conv_lv3(x))
259 | x = self.slice2(x)
260 | x_lv2 = self.lrelu(self.conv_lv2(x))
261 | x = self.slice3(x)
262 | x_lv1 = self.lrelu(self.conv_lv1(x))
263 | x = self.slice4(x)
264 | x_lv0 = self.lrelu(self.conv_lv0(x))
265 | else:
266 | x_lv3 = x
267 | x = self.slice2(x)
268 | x_lv2 = x
269 | x = self.slice3(x)
270 | x_lv1 = x
271 | x = self.slice4(x)
272 | x_lv0 = x
273 | return x_lv0, x_lv1, x_lv2, x_lv3
274 |
--------------------------------------------------------------------------------
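As a sanity check, the hand-rolled `pixel_unshuffle` above should agree with PyTorch's built-in `F.pixel_unshuffle` (available since PyTorch 1.8), since both lay out the k*k sub-grids row-major per input channel. A hypothetical check, not part of the repo:

```python
# Hypothetical equivalence check for model/LTE.py's pixel_unshuffle.
import torch
import torch.nn.functional as F
from model.LTE import pixel_unshuffle

x = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)
y = pixel_unshuffle(x, 2)                          # (1, 4, 2, 2)
print(y.shape)                                     # torch.Size([1, 4, 2, 2])
print(torch.allclose(y, F.pixel_unshuffle(x, 2)))  # True
```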
/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def str2bool(v):
4 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
5 | return True
6 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
7 | return False
8 | else:
9 | raise argparse.ArgumentTypeError('Boolean value expected.')
10 |
11 | parser = argparse.ArgumentParser(description='MRCF')
12 |
13 | ### visdom setting
14 | parser.add_argument('--visdom_port', type=int, default=8801,
15 | help='Visdom execution port')
16 | parser.add_argument('--visdom_view', type=str, default='MRCF',
17 | help='Visdom execution view')
18 |
19 | ### log setting
20 | parser.add_argument('--save_dir', type=str, default='save_dir',
21 | help='Directory to save log, arguments, models and images')
22 | parser.add_argument('--reset', type=str2bool, default=False,
23 | help='Delete save_dir to create a new one')
24 | parser.add_argument('--log_file_name', type=str, default='MRCF.log',
25 | help='Log file name')
26 | parser.add_argument('--logger_name', type=str, default='MRCF',
27 | help='Logger name')
28 |
29 | ### device setting
30 | parser.add_argument('--cpu', type=str2bool, default=False,
31 | help='Use CPU to run code')
32 | parser.add_argument('--num_gpu', type=int, default=1,
33 | help='The number of GPU used in training')
34 | parser.add_argument('--gpu_id', type=int, default=0,
35 | help='The id of GPU used in training')
36 |
37 | ### dataset setting
38 | parser.add_argument('--dataset', type=str, default='REDS',
39 | help='Which dataset to train and test')
40 | parser.add_argument('--dataset_dir', type=str, default='/Data/REDS_sharp/',
41 | help='Directory of dataset')
42 |
43 | ### dataloader setting
44 | parser.add_argument('--num_workers', type=int, default=4,
45 | help='The number of workers when loading data')
46 |
47 | ### model setting
48 | parser.add_argument('--num_res_blocks', type=str, default='4+4+4+4',
49 | help='The number of residual blocks in each stage')
50 | parser.add_argument('--n_feats', type=int, default=64,
51 | help='The number of channels in network')
52 | parser.add_argument('--res_scale', type=float, default=1.,
53 | help='Residual scale')
54 | parser.add_argument('--cra', type=str2bool, default=True,
55 | help='Use cra module or not')
56 | parser.add_argument('--mrcf', type=str2bool, default=True,
57 | help='MRCF or SRCF')
58 | parser.add_argument('--y_only', type=str2bool, default=False,
59 | help='Output Y-channel only or RGB-channels')
60 | parser.add_argument('--hr_dcn', type=str2bool, default=True,
61 | help='Move on DCN to highest spatial dimension or not')
62 | parser.add_argument('--offset_prop', type=str2bool, default=True,
63 | help='Propagate offset feature among each DCN or not')
64 |
65 | ### loss setting
66 | parser.add_argument('--rec_w', type=float, default=1.,
67 | help='The weight of reconstruction loss')
68 |
69 | ### optimizer setting
70 | parser.add_argument('--beta1', type=float, default=0.9,
71 | help='The beta1 in Adam optimizer')
72 | parser.add_argument('--beta2', type=float, default=0.999,
73 | help='The beta2 in Adam optimizer')
74 | parser.add_argument('--eps', type=float, default=1e-12,
75 | help='The eps in Adam optimizer')
76 | parser.add_argument('--lr_rate', type=float, default=1e-4,
77 | help='Learning rate')
78 | parser.add_argument('--lr_rate_flow', type=float, default=2.5e-5,
79 |                     help='Learning rate of the optical flow estimator')
80 | parser.add_argument('--decay', type=float, default=999999,
81 |                     help='Learning rate decay step')
82 | parser.add_argument('--gamma', type=float, default=0.5,
83 | help='Learning rate decay factor for step decay')
84 |
85 | ### training setting
86 | parser.add_argument('--batch_size', type=int, default=8,
87 | help='Training batch size')
88 | parser.add_argument('--GT_size', type=int, default=256,
89 | help='Training GT size')
90 | parser.add_argument('--FV_size', type=int, default=80,
91 | help='Training FV size')
92 | parser.add_argument('--scale', type=int, default=4,
93 | help='Training upsampling scale')
94 | parser.add_argument('--N_frames', type=int, default=15,
95 | help='Training number of frames')
96 | parser.add_argument('--train_crop_size', type=int, default=40,
97 | help='Training data crop size')
98 | parser.add_argument('--num_init_epochs', type=int, default=2,
99 | help='The number of init epochs which are trained with only reconstruction loss')
100 | parser.add_argument('--num_epochs', type=int, default=1,
101 | help='The number of training epochs')
102 | parser.add_argument('--print_every', type=int, default=1,
103 | help='Print period')
104 | parser.add_argument('--save_every', type=int, default=999999,
105 | help='Save period')
106 | parser.add_argument('--val_every', type=int, default=999999,
107 | help='Validation period')
108 |
109 | ### evaluate / test / finetune setting
110 | parser.add_argument('--eval', type=str2bool, default=False,
111 | help='Evaluation mode')
112 | parser.add_argument('--eval_save_results', type=str2bool, default=False,
113 | help='Save each image during evaluation')
114 | parser.add_argument('--model_path', type=str, default=None,
115 | help='The path of model to evaluation')
116 | parser.add_argument('--test', type=str2bool, default=False,
117 | help='Test mode')
118 |
119 | args = parser.parse_args()
120 |
--------------------------------------------------------------------------------
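Since `option.py` exposes a plain argparse parser, the flags can also be exercised programmatically. A hypothetical sketch; note that importing `option` already calls `parse_args()` on `sys.argv`, so run it without extra command-line arguments:

```python
# Hypothetical sketch (not part of the repo) showing how the str2bool
# flags parse; option.py itself parses sys.argv on import.
from option import parser

args = parser.parse_args([
    '--scale', '8',
    '--FV_size', '96',
    '--hr_dcn', 'y',       # str2bool accepts yes/true/t/y/1
    '--offset_prop', 'n',  # and no/false/f/n/0
])
print(args.scale, args.FV_size, args.hr_dcn, args.offset_prop)
# -> 8 96 True False
```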
/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/overview.png
--------------------------------------------------------------------------------
/png2mp4.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import numpy as np
4 | from PIL import Image
5 | import cv2
6 | # from torchaudio import save_encinfo  # unused
7 |
8 | if __name__ == '__main__':
9 | fourcc = cv2.VideoWriter_fourcc(*'MP4V')
10 | model_code = 15
11 | hr_dcn = True
12 | offset_prop = True
13 | split_ratio = 3
14 | model_name = 'FVSR_x8_simple_v{}_hrdcn_{}_offsetprop_{}_fnet{}_gaussian'.format(model_code, 'y' if hr_dcn else 'n',
15 | 'y' if offset_prop else 'n',
16 | '_{}outof4'.format(4-split_ratio) if model_code == 18 else '')
17 | # model_name = 'Bicubic'
18 |
19 | video_nums = [0, 11, 15, 20]
20 | for video_num in video_nums:
21 | sr_png_dir = 'test_png/eval_video/{}/{}/results'.format(model_name, video_num)
22 | gt_png_dir = 'test_png/eval_video/GroundTruth/{}/'.format(video_num)
23 |
24 | save_dir = 'test_video/{}/{}'.format(model_name, video_num)
25 | if not os.path.exists(save_dir):
26 | os.makedirs(save_dir)
27 |
28 | gt_imgs = []
29 | sr_imgs = []
30 | files = os.listdir(sr_png_dir)
31 | files = sorted(files)
32 | for f in files:
33 | if '.gif' in f:
34 | continue
35 | img = cv2.imread(os.path.join(sr_png_dir, f))
36 | sr_imgs.append(img)
37 | files = os.listdir(gt_png_dir)
38 | files = sorted(files)
39 | for f in files:
40 | img = cv2.imread(os.path.join(gt_png_dir, f))
41 | gt_imgs.append(img)
42 |
43 | H, W, C = sr_imgs[0].shape
44 | out = cv2.VideoWriter(os.path.join(save_dir, 'sr.mp4'), fourcc, 20.0, (W, H))
45 | for i in range(len(sr_imgs)):
46 |             out.write(sr_imgs[i])
47 |         out.release()
48 | 
49 |         H, W, C = gt_imgs[0].shape
50 |         out = cv2.VideoWriter(os.path.join(save_dir, 'gt.mp4'), fourcc, 20.0, (W, H))
51 |         for i in range(len(gt_imgs)):
52 |             out.write(gt_imgs[i])
53 |         out.release()
--------------------------------------------------------------------------------
/pretrained_models/EGVSR_iter420000.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/pretrained_models/EGVSR_iter420000.pth
--------------------------------------------------------------------------------
/pretrained_models/fnet.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/pretrained_models/fnet.pth
--------------------------------------------------------------------------------
/pretrained_models/spynet_20210409-c6c1bd09.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/pretrained_models/spynet_20210409-c6c1bd09.pth
--------------------------------------------------------------------------------
/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True, batch_avg = False):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if batch_avg and ssim_map.dim() == 4:
35 | B, C, H, W = ssim_map.size()
36 | return ssim_map.view(B, -1).mean(1)
37 | elif size_average:
38 | return ssim_map.mean()
39 | else:
40 | return ssim_map.mean(1).mean(1).mean(1)
41 |
42 | class SSIM(torch.nn.Module):
43 | def __init__(self, window_size = 11, size_average = True):
44 | super(SSIM, self).__init__()
45 | self.window_size = window_size
46 | self.size_average = size_average
47 | self.channel = 1
48 | self.window = create_window(window_size, self.channel)
49 |
50 | def forward(self, img1, img2):
51 | (_, channel, _, _) = img1.size()
52 |
53 | if channel == self.channel and self.window.data.type() == img1.data.type():
54 | window = self.window
55 | else:
56 | window = create_window(self.window_size, channel)
57 |
58 | if img1.is_cuda:
59 | window = window.cuda(img1.get_device())
60 | window = window.type_as(img1)
61 |
62 | self.window = window
63 | self.channel = channel
64 |
65 |
66 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
67 |
68 | def ssim(img1, img2, window_size = 11, size_average = True, batch_avg = False):
69 | (_, channel, _, _) = img1.size()
70 | window = create_window(window_size, channel)
71 |
72 | if img1.is_cuda:
73 | window = window.cuda(img1.get_device())
74 | window = window.type_as(img1)
75 |
76 | return _ssim(img1, img2, window, window_size, channel, size_average, batch_avg)
77 |
--------------------------------------------------------------------------------
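A minimal usage sketch for the SSIM module above; random tensors stand in for real frames (not part of the repo):

```python
# Hypothetical usage sketch for pytorch_ssim (not part of the repo).
import torch
import pytorch_ssim

img1 = torch.rand(1, 3, 256, 256)
img2 = img1.clone()

# Identical images score ~1.0.
print(pytorch_ssim.ssim(img1, img2).item())

# batch_avg=True returns one score per batch element instead of a scalar
# (the same flag test_video.py passes through calc_psnr_and_ssim_cuda).
noisy = (img1 + 0.1 * torch.randn_like(img1)).clamp(0, 1)
print(pytorch_ssim.ssim(img1, noisy, batch_avg=True))
```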
/test.sh:
--------------------------------------------------------------------------------
1 | ### test
2 |
3 | ### test_video_with_reds
4 | python3 main.py --save_dir ./test/demo/FVSR_x4_mrcf/output/ \
5 | --model_path ./train/model_00001_002900.pt \
6 | --log_file_name test.log \
7 | --reset True \
8 | --test True \
9 | --num_workers 1 \
10 | --scale 4 \
11 | --cra true \
12 | --mrcf true \
13 | --N_frames 15 \
14 | --FV_size 96 \
15 | --dataset Reds \
16 | --dataset_dir /DATA/REDS_sharp/ \
17 | --visdom_port 8803 \
18 | --visdom_view FVSR_x4_mrcf
19 |
--------------------------------------------------------------------------------
/test_img_coor.py:
--------------------------------------------------------------------------------
1 | # importing the module
2 | import cv2
3 |
4 | # function to display the coordinates
5 | # of the points clicked on the image
6 | def click_event(event, x, y, flags, params):
7 |
8 | # checking for left mouse clicks
9 | if event == cv2.EVENT_LBUTTONDOWN:
10 |
11 | # displaying the coordinates
12 | # on the Shell
13 | print(x, ' ', y)
14 |
15 | # displaying the coordinates
16 | # on the image window
17 | font = cv2.FONT_HERSHEY_SIMPLEX
18 | cv2.putText(img, str(x) + ',' +
19 | str(y), (x,y), font,
20 | 1, (255, 0, 0), 2)
21 | cv2.imshow('image', img)
22 |
23 | # checking for right mouse clicks
24 | if event==cv2.EVENT_RBUTTONDOWN:
25 |
26 | # displaying the coordinates
27 | # on the Shell
28 | print(x, ' ', y)
29 |
30 | # displaying the coordinates
31 | # on the image window
32 | font = cv2.FONT_HERSHEY_SIMPLEX
33 | b = img[y, x, 0]
34 | g = img[y, x, 1]
35 | r = img[y, x, 2]
36 | cv2.putText(img, str(b) + ',' +
37 | str(g) + ',' + str(r),
38 | (x,y), font, 1,
39 | (255, 255, 0), 2)
40 | cv2.imshow('image', img)
41 |
42 | # driver function
43 | if __name__=="__main__":
44 |
45 | # reading the image
46 | # img_path = '/DATA/REDS_sharp/train/train/train_sharp/011/00000036.png'
47 | # img_path = '/DATA/REDS_sharp/train/train/train_sharp/020/00000099.png'
48 | img_path = '/home/si2/TTSR/env/test/test_png/results/FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_cra/pastfv/075_results_1.png'
49 | img = cv2.imread(img_path, 1)
50 |
51 | # displaying the image
52 | cv2.imshow('image', img)
53 |
54 | # setting mouse handler for the image
55 | # and calling the click_event() function
56 | cv2.setMouseCallback('image', click_event)
57 |
58 | # wait for a key to be pressed to exit
59 | cv2.waitKey(0)
60 |
61 | # close the window
62 | cv2.destroyAllWindows()
--------------------------------------------------------------------------------
/test_runtime.py:
--------------------------------------------------------------------------------
1 | from model import MRCF_runtime as MRCF
2 |
3 | import os
4 | import time
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from torch.profiler import profile, record_function, ProfilerActivity
9 | from pytorch_memlab import LineProfiler
10 | from pytorch_memlab import MemReporter
11 | from dcn_v2 import DCNv2
12 |
13 | def conv_identify(weight, bias):
14 | weight.data.zero_()
15 | bias.data.zero_()
16 | o, i, h, w = weight.shape
17 | y = h//2
18 | x = w//2
19 | for p in range(i):
20 | for q in range(o):
21 | if p == q:
22 | weight.data[q, p, y, x] = 1.0
23 |
24 | if __name__ == '__main__':
25 | device = torch.device('cuda')
26 |
27 | mid_channels = 32
28 | start = torch.cuda.Event(enable_timing=True)
29 | end = torch.cuda.Event(enable_timing=True)
30 | # print(torch.cuda.memory_allocated(0), '-3')
31 | # print(torch.cuda.max_memory_allocated(0), '-3')
32 | # torch.cuda.reset_max_memory_allocated(0)
33 | y_only = False
34 | hr_dcn = True
35 | offset_prop = True
36 | split_ratio = 3
37 | # model = MRCF.MRCF_simple_v0(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
38 | # model = MRCF.MRCF_simple_v13(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
39 | # model = MRCF.MRCF_simple_v13_nodcn(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
40 | # model = MRCF.MRCF_simple_v15(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
41 | model = MRCF.MRCF_simple_v18(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, split_ratio=split_ratio, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
42 | # model = MRCF.MRCF_simple_v18_nofv(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, split_ratio=split_ratio, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
43 |
44 | # model = MRCF.MRCF_simple(mid_channels=mid_channels, num_blocks=1, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device)
45 | # model = MRCF.MRCF_simple_v1_dcn2_v4_kai(mid_channels=32, y_only=y_only, spynet_pretrained='pretrained_models/250.pth', device=device).to(device)
46 | # model = MRCF.MRCF_simple_v1_dcn2_v4_kai(mid_channels=32, y_only=y_only, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device)
47 | # model = MRCF.MRCF_simple_v4(mid_channels=mid_channels, y_only=False, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device)
48 | # model = MRCF.MRCF_CRA_x8(mid_channels=mid_channels, num_blocks=1, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device)
49 | # model = MRCF.MRCF_CRA_x8_v1(mid_channels=mid_channels, num_blocks=1, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device)
50 | # model_spy = model.spynet
51 | # model_en_hr = model.encoder_hr
52 | # model_en_lr = model.encoder_lr
53 | # model = nn.Upsample(scale_factor=8, mode='bicubic', align_corners=False)
54 | # for k, v in model.named_parameters():
55 | # v.requires_grad_(False)
56 |
57 | # model = MRCF.ResidualBlocksWithInputConv(mid_channels * 2, mid_channels, 3).to(device)
58 | # model = nn.Sequential(
59 | # nn.Conv2d(mid_channels, mid_channels, 3, 1, 1),
60 | # nn.LeakyReLU(0.1, inplace=True),
61 | # nn.Conv2d(mid_channels, 3, 3, 1 ,1)).to(device)
62 |
63 | # group = 1
64 | # model = nn.Sequential(
65 | # nn.Conv2d(mid_channels*2+2, mid_channels, 3, 1, 1, bias=True),
66 | # nn.LeakyReLU(0.1, inplace=True),
67 | # nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True),
68 | # nn.LeakyReLU(0.1, inplace=True),
69 | # nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True),
70 | # nn.LeakyReLU(0.1, inplace=True)
71 | # ).to(device)
72 | # dcn_offset = nn.Conv2d(mid_channels, group*2*3*3, 3, 1, 1).to(device)
73 | # dcn_mask = nn.Conv2d(mid_channels, group*1*3*3, 3, 1, 1).to(device)
74 | # dcn = DCNv2(mid_channels, mid_channels, 3, stride=1, padding=1, dilation=1, deformable_groups=group).to(device)
75 | # dcn_offset.weight.data.zero_()
76 | # dcn_offset.bias.data.zero_()
77 | # dcn_mask.weight.data.zero_()
78 | # dcn_mask.bias.data.zero_()
79 | # conv_identify(dcn.weight, dcn.bias)
80 |
81 | scale = 1
82 | # HR_h = 720
83 | # HR_w = 1280
84 | HR_h = 1080
85 | HR_w = 1920
86 | # HR_h = 512
87 | # HR_w = 512
88 | LR_h = HR_h // 8
89 | LR_w = HR_w // 8
90 | FV_h = 96
91 | FV_w = 96
92 | WP_h = 720
93 | WP_w = 720
94 | # WP_h = 1080
95 | # WP_w = 1920
96 |
97 | t = 5
98 | repeat_time = 30
99 | warm_up = 10
100 | infer_time = 0
101 |
102 | model.eval()
103 | # dcn_offset.eval()
104 | # dcn_mask.eval()
105 | # dcn.eval()
106 |
107 | with torch.no_grad():
108 | # print(torch.cuda.memory_allocated(0), '-2')
109 | # print(torch.cuda.max_memory_allocated(0), '-2')
110 | # torch.cuda.reset_max_memory_allocated(0)
111 |
112 | # x = torch.rand(1, mid_channels, HR_h, HR_w).cuda()
113 | # i = torch.rand(1, mid_channels * 2, HR_h//scale, HR_w//scale).cuda()
114 |
115 | # f = torch.rand(1, 2, HR_h//scale, HR_w//scale).cuda()
116 | # i = torch.rand(1, mid_channels*2+2, HR_h//scale, HR_w//scale).cuda()
117 |
118 | # f = torch.rand(1, 2, HR_h//scale, HR_w//scale).cuda()
119 | # x = torch.rand(1, mid_channels, HR_h//scale, HR_w//scale).cuda()
120 |
121 | # i = torch.rand(1, 3, HR_h//scale, HR_w//scale).cuda()
122 | # f = model(i)
123 | # o = dcn_offset(f)
124 | # o = 10. * torch.tanh(o)
125 | # m = dcn_mask(f)
126 | # m = torch.sigmoid(m)
127 |
128 | lr = torch.rand(1, t, 3, LR_h, LR_w).cuda()
129 | fv = torch.rand(1, t, 3, FV_h, FV_w).cuda()
130 | # mk = torch.ones(1, t, 1, HR_h, HR_w).cuda()
131 |
132 | # ref = torch.rand(1, 3, LR_h, LR_w).cuda()
133 | # sup = torch.rand(1, 3, LR_h, LR_w).cuda()
134 |
135 | # x_lr = torch.rand(1, 3, LR_h, LR_w).cuda()
136 | # x_fv = torch.rand(1, 6, FV_h, FV_w).cuda()
137 |
138 | # print(torch.cuda.memory_allocated(0), '-1')
139 | # print(torch.cuda.max_memory_allocated(0), '-1')
140 | # torch.cuda.reset_max_memory_allocated(0)
141 |
142 | for idx in range(repeat_time):
143 | if idx < warm_up:
144 | infer_time = 0
145 | torch.cuda.synchronize()
146 | # start_time = time.time()
147 | start.record()
148 |
149 | y = model(lr, fv, warp_size=(WP_h, WP_w))
150 | # y = model(lr, fv)
151 | # y = model(lr, fv, mk)
152 | # y = model_spy(ref, sup)
153 |
154 | # y0, y1, y2 = model_en_lr(x_lr, islr=True)
155 | # y0, y1, y2 = model_en_hr(x_fv, islr=True)
156 |
157 | # y = MRCF.flow_warp(x, f.permute(0, 2, 3, 1))
158 |
159 | # y = dcn(f, o, m)
160 |
161 | # y = model(i)
162 | # o = dcn_offset(y)
163 | # o = 10. * torch.tanh(o)
164 | # f = torch.cat((f[:, 1:2, :, :], f[:, 0:1, :, :]), dim=1)
165 | # f = f.repeat(1, o.size(1) // 2, 1, 1)
166 | # o = o + f
167 | # m = dcn_mask(y)
168 | # m = torch.sigmoid(m)
169 |
170 | # y = model(f)
171 | # y = model(x)
172 | # y = model(x_lr)
173 |
174 | # print(torch.cuda.memory_allocated(0), '0')
175 | # print(torch.cuda.max_memory_allocated(0), '0')
176 | # torch.cuda.reset_max_memory_allocated(0)
177 | end.record()
178 | torch.cuda.synchronize()
179 | # infer_time += (time.time() - start_time)
180 | infer_time += (start.elapsed_time(end)/1000)
181 |
182 | # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
183 | # with record_function("model_inference"):
184 | # y = model(lr, fv, mk)
185 |
186 | print(y.shape, infer_time / (repeat_time - warm_up + 1) / t)
187 | # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=100))
188 | # prof.export_chrome_trace('./mrcf_profile.json')
189 | # reporter.report(verbose=True)
--------------------------------------------------------------------------------
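The loop above follows the standard CUDA-event timing pattern: warm-up iterations keep resetting the accumulator, and `torch.cuda.synchronize()` brackets each measurement so host-side timing matches device execution. A stripped-down sketch of just that pattern; the `nn.Conv2d` stand-in model is hypothetical and a GPU is required:

```python
# Minimal sketch (not part of the repo) of test_runtime.py's timing pattern.
import torch
import torch.nn as nn

device = torch.device('cuda')
model = nn.Conv2d(3, 3, 3, padding=1).to(device).eval()  # stand-in model
x = torch.rand(1, 3, 256, 256, device=device)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

repeat_time, warm_up, infer_time = 30, 10, 0.0
with torch.no_grad():
    for idx in range(repeat_time):
        if idx < warm_up:       # keep resetting until the GPU is warm
            infer_time = 0.0
        torch.cuda.synchronize()
        start.record()
        y = model(x)
        end.record()
        torch.cuda.synchronize()
        infer_time += start.elapsed_time(end) / 1000  # ms -> s

# The reset at idx == warm_up - 1 leaves repeat_time - warm_up + 1 timed
# runs, matching the divisor used in test_runtime.py.
print(infer_time / (repeat_time - warm_up + 1))
```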
/test_video.py:
--------------------------------------------------------------------------------
1 | from torch import save
2 | from model import MRCF_test as MRCF
3 | from model import LTE
4 | from utils import flow_to_color
5 | from dataset import dataloader
6 | from utils import calc_psnr_and_ssim_cuda, bgr2ycbcr
7 |
8 | import os
9 | import numpy as np
10 | import PIL
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from torch.utils.data import DataLoader
15 | import visdom
16 | from imageio import imread, imsave, get_writer
17 | from PIL import Image
18 | import cv2
19 | # from ptflops import get_model_complexity_info
20 | import warnings
21 | warnings.filterwarnings("ignore", category=UserWarning)
22 |
23 | def foveated_metric(LR, LR_fv, HR, mn, hw, crop, kernel_size, stride_size, eval_mode=False):
24 | m, n = mn
25 | h, w = hw
26 | crop_h, crop_w = crop
27 | HR_fold = F.unfold(HR.unsqueeze(0), kernel_size=(kernel_size, kernel_size), stride=stride_size) # [N, 3*11*11, Hr*Wr]
28 | LR_fv_fold = F.unfold(LR_fv.unsqueeze(0), kernel_size=(kernel_size, kernel_size), stride=stride_size) # [N, 3*11*11, Hr*Wr]
29 | B, C, N = HR_fold.size()
30 | HR_fold = HR_fold.permute(0, 2, 1).view(B*N , 3, kernel_size, kernel_size) # [N, 3*11*11, Hr*Wr]
31 | LR_fv_fold = LR_fv_fold.permute(0, 2, 1).view(B*N , 3, kernel_size, kernel_size)
32 | Hr = (h - kernel_size) // stride_size + 1
33 | Wr = (w - kernel_size) // stride_size + 1
34 |
35 | B, C, H, W = HR_fold.size()
36 | mask = torch.ones((B, 1, H, W)).float()
37 | psnr_score, ssim_score = calc_psnr_and_ssim_cuda(HR_fold, LR_fv_fold, mask, is_tensor=False, batch_avg=True)
38 | psnr_score = psnr_score.view(Hr, Wr)
39 | ssim_score = ssim_score.view(Hr, Wr)
40 |
41 | # psnr_y_idx = (torch.argmax(psnr_score) // Wr) * stride_size
42 | # psnr_x_idx = (torch.argmax(psnr_score) % Wr) * stride_size
43 | # ssim_y_idx = (torch.argmax(ssim_score) // Wr) * stride_size
44 | # ssim_x_idx = (torch.argmax(ssim_score) % Wr) * stride_size
45 | if not eval_mode:
46 | HR[:, m:m+crop_h, n] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_h))
47 | HR[:, m:m+crop_h, n+crop_w-1] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_h))
48 | HR[:, m, n:n+crop_w] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_w))
49 | HR[:, m+crop_h-1, n:n+crop_w] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_w))
50 |
51 | LR_fv[:, m:m+crop_h, n] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_h))
52 | LR_fv[:, m:m+crop_h, n+crop_w-1] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_h))
53 | LR_fv[:, m, n:n+crop_w] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_w))
54 | LR_fv[:, m+crop_h-1, n:n+crop_w] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_w))
55 |
56 | psnr_min = psnr_score.min()
57 | psnr_max = psnr_score.max()
58 | ssim_min = ssim_score.min()
59 | ssim_max = ssim_score.max()
60 | # psnr_score = (psnr_score - psnr_min) / (psnr_max - psnr_min)
61 | # ssim_score = (ssim_score - ssim_min) / (ssim_max - ssim_min)
62 | psnr_score = psnr_score / 100
63 | ssim_score = (ssim_score.clip(0, 1) - 0.7) / 0.3
64 |
65 | # psnr_score_discrete = torch.zeros_like(psnr_score)
66 | # ssim_score_discrete = torch.zeros_like(ssim_score)
67 |
68 | # psnr_score_discrete[psnr_score <= 1.0] = 1.0
69 | # psnr_score_discrete[psnr_score <= 0.9] = 0.9
70 | # psnr_score_discrete[psnr_score <= 0.8] = 0.8
71 | # psnr_score_discrete[psnr_score <= 0.7] = 0.7
72 | # psnr_score_discrete[psnr_score <= 0.6] = 0.6
73 | # psnr_score_discrete[psnr_score <= 0.5] = 0.5
74 | # psnr_score_discrete[psnr_score <= 0.4] = 0.4
75 | # psnr_score_discrete[psnr_score <= 0.3] = 0.3
76 | # psnr_score_discrete[psnr_score <= 0.2] = 0.2
77 | # psnr_score_discrete[psnr_score <= 0.1] = 0.1
78 |
79 | # ssim_score_discrete[ssim_score <= 1.0] = 1.0
80 | # ssim_score_discrete[ssim_score <= 0.9] = 0.9
81 | # ssim_score_discrete[ssim_score <= 0.8] = 0.8
82 | # ssim_score_discrete[ssim_score <= 0.7] = 0.7
83 | # ssim_score_discrete[ssim_score <= 0.6] = 0.6
84 | # ssim_score_discrete[ssim_score <= 0.5] = 0.5
85 | # ssim_score_discrete[ssim_score <= 0.4] = 0.4
86 | # ssim_score_discrete[ssim_score <= 0.3] = 0.3
87 | # ssim_score_discrete[ssim_score <= 0.2] = 0.2
88 | # ssim_score_discrete[ssim_score <= 0.1] = 0.1
89 |
90 | # self.viz.viz.image(HR.cpu().numpy(), win='{}'.format('HR'), opts=dict(title='{}, Image size : {}'.format('HR', HR.size())))
91 | # self.viz.viz.image(LR.cpu().numpy(), win='{}'.format('LR'), opts=dict(title='{}, Image size : {}'.format('LR', LR.size())))
92 | # self.viz.viz.image(LR_fv.cpu().numpy(), win='{}'.format('FV'), opts=dict(title='{}, Image size : {}'.format('FV', LR_fv.size())))
93 | # self.viz.viz.image(psnr_score.cpu().numpy(), win='{}'.format('PSNR_score'), opts=dict(title='{}, Image size : {}'.format('PSNR_score', psnr_score.size())))
94 | # self.viz.viz.image(ssim_score.cpu().numpy(), win='{}'.format('SSIM_score'), opts=dict(title='{}, Image size : {}'.format('SSIM_score', ssim_score.size())))
95 | # self.viz.viz.image(psnr_score_discrete.cpu().numpy(), win='{}'.format('PSNR_score_discrete'), opts=dict(title='{}, Image size : {}'.format('PSNR_score_discrete', psnr_score_discrete.size())))
96 | # self.viz.viz.image(ssim_score_discrete.cpu().numpy(), win='{}'.format('SSIM_score_discrete'), opts=dict(title='{}, Image size : {}'.format('SSIM_score_discrete', ssim_score_discrete.size())))
97 |
98 | return psnr_score, ssim_score, (psnr_min, psnr_max), (ssim_min, ssim_max)
99 |
100 | def rgb2yuv(rgb, y_only=True):
101 | # rgb_ = rgb.permute(0,2,3,1)
102 | # A = torch.tensor([[0.299, -0.14714119,0.61497538],
103 | # [0.587, -0.28886916, -0.51496512],
104 | # [0.114, 0.43601035, -0.10001026]])
105 | # yuv = torch.tensordot(rgb_,A,1).transpose(0,2)
106 | r = rgb[:, 0, :, :]
107 | g = rgb[:, 1, :, :]
108 | b = rgb[:, 2, :, :]
109 |
110 | y = 0.299 * r + 0.587 * g + 0.114 * b
111 | u = -0.147 * r - 0.289 * g + 0.436 * b
112 | v = 0.615 * r - 0.515 * g - 0.100 * b
113 | yuv = torch.stack([y,u,v], dim=1)
114 | if y_only:
115 | return y.unsqueeze(1)
116 | else:
117 | return yuv
118 |
119 | def yuv2rgb(yuv):
120 | y = yuv[:, 0, :, :]
121 | u = yuv[:, 1, :, :]
122 | v = yuv[:, 2, :, :]
123 |
124 |     r = y + 1.14 * v                # coefficient for u is 0
125 |     g = y - 0.396 * u - 0.581 * v
126 |     b = y + 2.029 * u               # coefficient for v is 0
127 | rgb = torch.stack([r,g,b], 1)
128 |
129 | return rgb
130 |
131 | if __name__ == '__main__':
132 |
133 | device = torch.device('cuda')
134 | viz = visdom.Visdom(server='140.113.212.214', port=8803, env='Gen_video')
135 | # fourcc = cv2.VideoWriter_fourcc('m','p','4','v')
136 | # out = cv2.VideoWriter('test_arcane_simple.mp4', fourcc, 30.0, (1080, 1920))
137 |
138 | dataset_name = 'REDS'
139 | # dataset_name = 'old_tree'
140 | # dataset_name = 'arcane'
141 | regional_dcn = False
142 | eval_mode = True
143 | model_code = 15
144 | model_epoch = 99
145 | y_only = False
146 | hr_dcn = True
147 | offset_prop = True
148 | split_ratio = 3
149 | sigma = 50
150 | dcn_size = 720
151 | model_name = 'FVSR_x8_simple_v{}_hrdcn_{}_offsetprop_{}_fnet{}'.format(model_code, 'y' if hr_dcn else 'n',
152 | 'y' if offset_prop else 'n',
153 | '_{}outof4'.format(4-split_ratio) if model_code == 18 else '')
154 | print('Current model name: {}, Epoch: {}'.format(model_name, model_epoch))
155 | video_num = [ 0, 11, 15, 20]
156 | # video_num = [ 0, 1, 6, 17]
157 | if eval_mode:
158 | fv_st_idx = [0, 0, 0, 0]
159 | else:
160 | fv_st_idx = [66, 30, 31, 0]
161 | # fv_st_idx = [100, 100, 100, 100]
162 | video_set = 'train'
163 | # video_set = 'val'
164 | model_path = 'train/REDS/{}/model/'.format(model_name)
165 | model_saves = os.listdir(model_path)
166 | model_save = [v for v in model_saves if '{:05d}'.format(model_epoch) in v]
167 | assert len(model_save) == 1
168 | model_save = model_save[0]
169 | model_name += '_gaussian'
170 | if eval_mode:
171 | save_dir = 'test_png/eval_video/'
172 | else:
173 | save_dir = 'test_png/'
174 | if not os.path.exists(save_dir):
175 | os.makedirs(save_dir)
176 | if not os.path.exists(os.path.join(save_dir, model_name)):
177 | os.makedirs(os.path.join(save_dir, model_name))
178 |
179 | # model = MRCF.MRCF_CRA_x8(mid_channels=64, num_blocks=1, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device)
180 | if model_code == 13:
181 | model = MRCF.MRCF_simple_v13(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
182 | # model = MRCF.MRCF_simple_v13_nodcn(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
183 | elif model_code == 15:
184 | model = MRCF.MRCF_simple_v15(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
185 | elif model_code == 18:
186 | model = MRCF.MRCF_simple_v18(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, split_ratio=split_ratio, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
187 | # model = MRCF.MRCF_simple_v18_cra(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, split_ratio=split_ratio, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
188 | elif model_code == 0:
189 | model = MRCF.MRCF_simple_v0(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device)
190 |
191 | model_state_dict = model.state_dict()
192 | model_state_dict_save = {k.replace('basic_', 'basic_module.'):v for k,v in torch.load(os.path.join(model_path, model_save)).items() if k.replace('basic_', 'basic_module.') in model_state_dict}
193 | # model_state_dict_save = {k.replace('module.',''):v for k,v in torch.load(model_path).items() if k.replace('module.','') in model_state_dict}
194 | # for k in model_state_dict.keys():
195 | # print(k)
196 | # print('-----')
197 | # for k in model_state_dict_save.keys():
198 | # print(k)
199 | # print(model_save)
200 | # for k,v in torch.load(model_path).items():
201 | # print(k)
202 | model_state_dict.update(model_state_dict_save)
203 | model.load_state_dict(model_state_dict, strict=True)
204 |
205 | psnr_whole_list = []
206 | ssim_whole_list = []
207 | psnr_outskirt_list = []
208 | ssim_outskirt_list = []
209 | psnr_past_list = []
210 | ssim_past_list = []
211 | psnr_fovea_list = []
212 | ssim_fovea_list = []
213 |
214 | for v_idx, v in enumerate(video_num):
215 | if dataset_name == 'REDS':
216 | GT_img_dir = '/DATA/REDS_sharp/{}/{}/{}_sharp/{:03d}/'.format(video_set, video_set, video_set, v)
217 | LR_img_dir = '/DATA/REDS_sharp_BI_x8/{}/{}/{}_sharp/{:03d}/'.format(video_set, video_set, video_set, v)
218 | else:
219 | GT_img_dir = '{}_x1'.format(dataset_name)
220 | LR_img_dir = '{}_x8'.format(dataset_name)
221 | print('Data location: {}'.format(GT_img_dir))
222 |
223 | lr_frames = []
224 | hr_frames = []
225 | GT_imgs = []
226 | LR_imgs = []
227 | LRSR_imgs = []
228 | GT_files = os.listdir(GT_img_dir)
229 | LR_files = os.listdir(LR_img_dir)
230 | GT_files = sorted(GT_files)
231 | LR_files = sorted(LR_files)
232 | for file in GT_files:
233 | img = cv2.imread(os.path.join(GT_img_dir, file))
234 | GT_imgs.append(img[:1072, :1920, :])
235 | # img = cv2.cvtColor(img.copy(), cv2.COLOR_RGB2BGR)
236 | # hr_frames.append(img)
237 | H_, W_, _ = GT_imgs[0].shape
238 | for file in LR_files:
239 | img = cv2.imread(os.path.join(LR_img_dir, file))
240 | LR_imgs.append(img[:134, :240, :])
241 | LRSR_imgs.append(np.array(PIL.Image.fromarray(img).resize((W_, H_), PIL.Image.BICUBIC)))
242 | # img = cv2.cvtColor(img.copy(), cv2.COLOR_RGB2BGR)
243 | # lr_frames.append(img)
244 |
245 | gen_frames = []
246 | gt_frames = []
247 | lr_frames = []
248 | psnr_score_list = []
249 | ssim_score_list = []
250 | psnr_score_bicubic_list = []
251 | ssim_score_bicubic_list = []
252 | # fv_size = 144
253 | fv_size = 96
254 | dx = 0
255 | dy = 0
256 | psnr_min = 1000
257 | psnr_max = 0
258 | ssim_min = 1000
259 | ssim_max = 0
260 |
261 | if dataset_name == 'arcane':
262 | st_x = 760
263 | st_y = 300
264 | ed_x = 1160
265 | ed_y = 500
266 | else:
267 | st_x = 360 + dx
268 | st_y = 300 + dy
269 | ed_x = 720 + dx
270 | ed_y = 500 + dy
271 |
272 | cur_x = ed_x
273 | cur_y = ed_y
274 | step_x = 20
275 | step_y = 0
276 | n_frames = 100
277 |
278 | bd_length = 10
279 | rg_w = dcn_size
280 | rg_h = dcn_size
281 |
282 | GT_imgs = GT_imgs[:n_frames]
283 | LR_imgs = LR_imgs[:n_frames]
284 | LRSR_imgs = LRSR_imgs[:n_frames]
285 | # with get_writer(os.path.join(save_dir,'test_{}_{}_{:03}_{}_bicubic.gif'.format(model_name, dataset_name, model_epoch, int(y_only))), mode="I", fps=7) as writer:
286 | # for n in range(n_frames):
287 | # writer.append_data(LRSR_imgs[n][:,:,::-1])
288 | # with get_writer(os.path.join(save_dir,'test_{}_{}_{:03}_{}_gt.gif'.format(model_name, dataset_name, model_epoch, int(y_only))), mode="I", fps=7) as writer:
289 | # for n in range(n_frames):
290 | # writer.append_data(GT_imgs[n][:,:,::-1])
291 |
292 | GT_imgs = np.stack(GT_imgs, axis=0)
293 | LR_imgs = np.stack(LR_imgs, axis=0)
294 | LRSR_imgs = np.stack(LRSR_imgs, axis=0)
295 | GT_imgs = GT_imgs.astype(np.float32) / 255.
296 | LR_imgs = LR_imgs.astype(np.float32) / 255.
297 | LRSR_imgs = LRSR_imgs.astype(np.float32) / 255.
298 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float()
299 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float()
300 | LRSR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LRSR_imgs, (0, 3, 1, 2)))).float()
301 | N, C, H, W = GT_imgs.size()
302 |
303 | kernel = np.array([ [1, 1, 1],
304 | [1, 1, 1],
305 | [1, 1, 1] ], dtype=np.float32)
306 | kernel_tensor = torch.Tensor(np.expand_dims(np.expand_dims(kernel, 0), 0)) # size: (1, 1, 3, 3)
307 | mk_list = []
308 | mk_one = torch.ones((1, 1, H, W)).to(device)
309 | x_array = sigma * np.random.randn(N) + (W / 2)
310 | y_array = sigma * np.random.randn(N) + (H / 2)
311 | white_paper = np.ones((H,W,3), np.uint8) * 255
312 | traj_list = []
313 |
314 | model.eval()
315 | with torch.no_grad():
316 | for n in range(N):
317 | print(n, '\r', end='')
318 | lr = LR_imgs[n:n+1].unsqueeze(0).to(device)
319 | lrsr = LRSR_imgs[n:n+1].unsqueeze(0).to(device).clone()
320 | gt = GT_imgs[n:n+1].unsqueeze(0).to(device).clone()
321 | fv = torch.zeros_like(gt).to(device)
322 | mk = torch.zeros((1, 1, 1, H, W)).to(device)
323 | fg = torch.zeros((1, 1, 1, H, W)).to(device)
324 |
325 | #### Raster scan
326 | # N_H = H // fv_size
327 | # N_W = W // fv_size
328 | # SP_H = H / N_H
329 | # SP_W = W / N_W
330 | # fv_sp = []
331 | # x_i = n % N_W
332 | # y_i = (n // N_W) % N_H
333 | # cur_y = int((1+y_i)*SP_H - (SP_H + fv_size)//2)
334 | # cur_x = int((1+x_i)*SP_W - (SP_W + fv_size)//2)
335 |
336 | #### Gaussian span
337 | cur_y = int(y_array[n]) - fv_size//2
338 | cur_x = int(x_array[n]) - fv_size//2
339 | traj_list.append((cur_y, cur_x))
340 |
341 | if n >= fv_st_idx[v_idx]:
342 | fv[:, :, :, cur_y:cur_y+fv_size, cur_x:cur_x+fv_size] = gt[:, :, :, cur_y:cur_y+fv_size, cur_x:cur_x+fv_size]
343 | mk[:, :, :, cur_y:cur_y+fv_size, cur_x:cur_x+fv_size] = 1
344 |
345 | mk_fv = mk.clone()
346 | mk_fv[:, :, :, cur_y:cur_y+fv_size, cur_x:cur_x+fv_size] = 1
347 | mk_out = mk_fv.clone().squeeze(0)
348 | for _ in range(10):
349 | mk_out = torch.clamp(F.conv2d(mk_out, kernel_tensor.to(mk_out.device), padding=(1, 1)), 0, 1)
350 | mk_out = torch.logical_and(torch.logical_not(mk), mk_out)
351 |
352 | st_rg_x = max(cur_x+(fv_size//2)-(rg_w//2), 0)
353 | ed_rg_x = min(cur_x+(fv_size//2)+(rg_w//2), 1920)
354 | st_rg_y = max(cur_y+(fv_size//2)-(rg_h//2), 0)
355 | ed_rg_y = min(cur_y+(fv_size//2)+(rg_h//2), 1080)
356 | if regional_dcn:
357 | fg[:, :, :, st_rg_y:ed_rg_y, st_rg_x:ed_rg_x] = 1
358 | else:
359 | fg = torch.ones((1, 1, 1, H, W)).to(device)
360 |
361 | sr = model(lrs=lr, fvs=fv, mks=mk, fgs=fg)
362 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.squeeze(0), gt.squeeze(0), mk_one)
363 | psnr_whole_list.append(psnr)
364 | ssim_whole_list.append(ssim)
365 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.squeeze(0), gt.squeeze(0), mk_fv)
366 | psnr_fovea_list.append(psnr)
367 | ssim_fovea_list.append(ssim)
368 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.squeeze(0), gt.squeeze(0), mk_out)
369 | psnr_outskirt_list.append(psnr)
370 | ssim_outskirt_list.append(ssim)
371 | if n > 0:
372 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.squeeze(0), gt.squeeze(0), mk_past)
373 | psnr_past_list.append(psnr)
374 | ssim_past_list.append(ssim)
375 |
376 | mk_list.append(mk_out.squeeze(0))
377 | if len(mk_list) > 3:
378 | mk_list.pop(0)
379 | mk_past = torch.sum(torch.cat(mk_list, dim=1), dim=1, keepdim=True).clip(0, 1)
380 |
381 | psnr_score, ssim_score, psnr, ssim = foveated_metric(lr[0,0], sr[0,0], gt[0,0].clone(), (cur_y, cur_x), (H, W), (fv_size, fv_size), kernel_size=10, stride_size=5, eval_mode=eval_mode)
382 | psnr_score_list.append((psnr_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy().astype(np.uint8))
383 | ssim_score_list.append((ssim_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy().astype(np.uint8))
384 | psnr_score, ssim_score, psnr, ssim = foveated_metric(lr[0,0], lrsr[0,0], gt[0,0], (cur_y, cur_x), (H, W), (fv_size, fv_size), kernel_size=10, stride_size=5, eval_mode=eval_mode)
385 | if psnr[0] < psnr_min:
386 | psnr_min = psnr[0]
387 | if psnr[1] > psnr_max:
388 | psnr_max = psnr[1]
389 | if ssim[0] < ssim_min:
390 | ssim_min = ssim[0]
391 | if ssim[1] > ssim_max:
392 | ssim_max = ssim[1]
393 | psnr_score_bicubic_list.append((psnr_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy().astype(np.uint8))
394 | ssim_score_bicubic_list.append((ssim_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy().astype(np.uint8))
395 |
396 | if y_only:
397 | B, N, C, H, W = lrsr.size()
398 | lrsr = lrsr.view(B*N, C, H, W)
399 | B, N, C, H, W = sr.size()
400 | sr = sr.view(B*N, C, H, W)
401 | lrsr = rgb2yuv(lrsr, y_only=False)
402 | sr = yuv2rgb(torch.cat((sr[:,0:1,:,:], lrsr[:,1:3,:,:]), dim=1))
403 |
404 | sr = (sr * 255.).clip(0., 255.)
405 | lr = (lrsr * 255.).clip(0., 255.)
406 | gt = (gt * 255.).clip(0., 255.)
407 | sr = np.transpose(sr.squeeze().clone().detach().round().cpu().numpy().astype(np.uint8), (1, 2, 0))
408 | lr = np.transpose(lr.squeeze().clone().detach().round().cpu().numpy().astype(np.uint8), (1, 2, 0))
409 | gt = np.transpose(gt.squeeze().clone().detach().round().cpu().numpy().astype(np.uint8), (1, 2, 0))
410 | sr = cv2.cvtColor(sr, cv2.COLOR_RGB2BGR)
411 | H, W, C = sr.shape
412 | # lr = cv2.resize(cv2.cvtColor(lr, cv2.COLOR_RGB2BGR), (W, H), cv2.INTER_CUBIC)
413 | lr = cv2.cvtColor(lr, cv2.COLOR_RGB2BGR)
414 | gt = cv2.cvtColor(gt, cv2.COLOR_RGB2BGR)
415 |
416 | sr_copy = sr.copy()
417 | # cv2.rectangle(sr, (cur_x, cur_y), (cur_x+fv_size, cur_y+fv_size), (51, 51, 255), 3)
418 | # cv2.rectangle(sr, (st_rg_x, st_rg_y), (ed_rg_x, ed_rg_y), (255, 51, 51), 3)
419 | # cv2.line(sr, (cur_x+fv_size//2-5, cur_y+fv_size//2), (cur_x+fv_size//2+5, cur_y+fv_size//2), (51, 51, 255), 3)
420 | # cv2.line(sr, (cur_x+fv_size//2, cur_y+fv_size//2-5), (cur_x+fv_size//2, cur_y+fv_size//2+5), (51, 51, 255), 3)
421 |
422 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x, :]
423 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x-1, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x-1, :]
424 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+1, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+1, :]
425 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x-2, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x-2, :]
426 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+2, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+2, :]
427 |
428 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size, :]
429 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size-1, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size-1, :]
430 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size+1, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size+1, :]
431 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size-2, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size-2, :]
432 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size+2, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size+2, :]
433 |
434 | # sr[cur_y, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y, cur_x+bd_length:cur_x+fv_size-bd_length, :]
435 | # sr[cur_y-1, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y-1, cur_x+bd_length:cur_x+fv_size-bd_length, :]
436 | # sr[cur_y+1, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+1, cur_x+bd_length:cur_x+fv_size-bd_length, :]
437 | # sr[cur_y-2, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y-2, cur_x+bd_length:cur_x+fv_size-bd_length, :]
438 | # sr[cur_y+2, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+2, cur_x+bd_length:cur_x+fv_size-bd_length, :]
439 |
440 | # sr[cur_y+fv_size, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size, cur_x+bd_length:cur_x+fv_size-bd_length, :]
441 | # sr[cur_y+fv_size-1, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size-1, cur_x+bd_length:cur_x+fv_size-bd_length, :]
442 | # sr[cur_y+fv_size+1, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size+1, cur_x+bd_length:cur_x+fv_size-bd_length, :]
443 | # sr[cur_y+fv_size-2, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size-2, cur_x+bd_length:cur_x+fv_size-bd_length, :]
444 | # sr[cur_y+fv_size+2, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size+2, cur_x+bd_length:cur_x+fv_size-bd_length, :]
445 |
446 | # sr[st_rg_y+bd_length:ed_rg_y-bd_length, st_rg_x-2:st_rg_x+3, :] = sr_copy[st_rg_y+bd_length:ed_rg_y-bd_length, st_rg_x-2:st_rg_x+3, :]
447 | # sr[st_rg_y+bd_length:ed_rg_y-bd_length, ed_rg_x-2:ed_rg_x+3, :] = sr_copy[st_rg_y+bd_length:ed_rg_y-bd_length, ed_rg_x-2:ed_rg_x+3, :]
448 | # sr[st_rg_y-2:st_rg_y+3, st_rg_x+bd_length:ed_rg_x-bd_length, :] = sr_copy[st_rg_y-2:st_rg_y+3, st_rg_x+bd_length:ed_rg_x-bd_length, :]
449 | # sr[ed_rg_y-2:ed_rg_y+3, st_rg_x+bd_length:ed_rg_x-bd_length, :] = sr_copy[ed_rg_y-2:ed_rg_y+3, st_rg_x+bd_length:ed_rg_x-bd_length, :]
450 | # cv2.rectangle(sr, (0, 100), (0+fv_size, 100+fv_size), (51, 51, 255), 3)
451 | # sr = cv2.cvtColor(sr, cv2.COLOR_BGR2RGB)
452 | gen_frames.append(sr.copy())
453 | lr_frames.append(lr.copy())
454 | gt_frames.append(gt.copy())
455 | cur_x += step_x
456 | cur_y += step_y
457 | # viz.image(sr.transpose(2, 0, 1), win='{}'.format('sr'), opts=dict(title='{}, Image size : {}'.format('sr', sr.shape)))
458 | # viz.image(lr.transpose(2, 0, 1), win='{}'.format('lr'), opts=dict(title='{}, Image size : {}'.format('lr', lr.shape)))
459 | # viz.image(gt.transpose(2, 0, 1), win='{}'.format('gt'), opts=dict(title='{}, Image size : {}'.format('gt', gt.shape)))
460 |
461 | if cur_x >= ed_x and cur_y <= st_y:
462 | step_x = 0
463 | step_y = 20
464 | elif cur_x >= ed_x and cur_y >= ed_y:
465 | step_x = -20
466 | step_y = 0
467 | elif cur_x <= st_x and cur_y >= ed_y:
468 | step_x = 0
469 | step_y = -20
470 | elif cur_x <= st_x and cur_y <= st_y:
471 | step_x = 20
472 | step_y = 0
473 |
474 | # for (idy, idx) in traj_list:
475 | # white_paper = cv2.circle(white_paper, (idx,idy), radius=10, color=(0, 0, 255), thickness=-1)
476 |
477 | model.clear_states()
478 | if dataset_name == 'REDS':
479 | #### Reconstructed results
480 | if not os.path.exists(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'results')):
481 | os.makedirs(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'results'))
482 | for i in range(len(gen_frames)):
483 | cv2.imwrite(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'results', '{:03d}.png'.format(i)), gen_frames[i][:,:,::-1])
484 | with get_writer(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'results', 'results.gif'), mode="I", fps=7) as writer:
485 | for n in range(len(gen_frames)):
486 | writer.append_data(gen_frames[n])
487 | if not os.path.exists(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'psnr')):
488 | os.makedirs(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'psnr'))
489 | for i in range(len(gen_frames)):
490 | cv2.imwrite(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'psnr', '{:03d}.png'.format(i)), psnr_score_list[i])
491 | if not os.path.exists(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'ssim')):
492 | os.makedirs(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'ssim'))
493 | for i in range(len(gen_frames)):
494 | cv2.imwrite(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'ssim', '{:03d}.png'.format(i)), ssim_score_list[i])
495 | # cv2.imwrite(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'traj.png'), white_paper)
496 | #### Bicubic upsample results
497 | if not os.path.exists(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]))):
498 | os.makedirs(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'results'))
499 | os.makedirs(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'psnr'))
500 | os.makedirs(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'ssim'))
501 | for i in range(len(lr_frames)):
502 | cv2.imwrite(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'results', '{:03d}.png'.format(i)), lr_frames[i][:,:,::-1])
503 | for i in range(len(lr_frames)):
504 | cv2.imwrite(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'psnr', '{:03d}.png'.format(i)), psnr_score_bicubic_list[i])
505 | for i in range(len(lr_frames)):
506 | cv2.imwrite(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'ssim', '{:03d}.png'.format(i)), ssim_score_bicubic_list[i])
507 | #### GroundTruth
508 | if not os.path.exists(os.path.join(save_dir, 'GroundTruth', str(video_num[v_idx]))):
509 | os.makedirs(os.path.join(save_dir, 'GroundTruth', str(video_num[v_idx])))
510 | for i in range(len(lr_frames)):
511 | cv2.imwrite(os.path.join(save_dir, 'GroundTruth', str(video_num[v_idx]), '{:03d}.png'.format(i)), gt_frames[i][:,:,::-1])
512 | else:
513 | if not os.path.exists(os.path.join(save_dir, model_name, dataset_name, 'results')):
514 | os.makedirs(os.path.join(save_dir, model_name, dataset_name, 'results'))
515 | for i in range(len(gen_frames)):
516 | cv2.imwrite(os.path.join(save_dir, model_name, dataset_name, 'results', '{:03d}.png'.format(i)), gen_frames[i][:,:,::-1])
517 | if not os.path.exists(os.path.join(save_dir, model_name, dataset_name, 'psnr')):
518 | os.makedirs(os.path.join(save_dir, model_name, dataset_name, 'psnr'))
519 | for i in range(len(gen_frames)):
520 | cv2.imwrite(os.path.join(save_dir, model_name, dataset_name, 'psnr', '{:03d}.png'.format(i)), psnr_score_list[i])
521 | if not os.path.exists(os.path.join(save_dir, model_name, dataset_name, 'ssim')):
522 | os.makedirs(os.path.join(save_dir, model_name, dataset_name, 'ssim'))
523 | for i in range(len(gen_frames)):
524 | cv2.imwrite(os.path.join(save_dir, model_name, dataset_name, 'ssim', '{:03d}.png'.format(i)), ssim_score_list[i])
525 | # cv2.imwrite(os.path.join(save_dir, model_name, dataset_name, 'traj.png'), white_paper)
526 | break
527 |
528 | print('PSNR_MIN: {}, PSNR_MAX: {}'.format(psnr_min, psnr_max))
529 | print('SSIM_MIN: {}, SSIM_MAX: {}'.format(ssim_min, ssim_max))
530 | # with get_writer(os.path.join(save_dir,'test_{}_{}_{:03}_{}.gif'.format(model_name, dataset_name, model_epoch, int(y_only))), mode="I", fps=7) as writer:
531 | # for n in range(n_frames):
532 | # # out.write(gen_frames[n][:,:,::-1])
533 | # writer.append_data(gen_frames[n])
534 |
535 | # with get_writer('test_output_hr.gif', mode="I", fps=10) as writer:
536 | # for n in range(n_frames):
537 | # writer.append_data(hr_frames[n])
538 |
539 | # with get_writer('test_output_lr.gif', mode="I", fps=10) as writer:
540 | # for n in range(n_frames):
541 | # writer.append_data(lr_frames[n])
542 |
543 | print('PSNR_W: {}, SSIM_W: {}'.format(sum(psnr_whole_list)/len(psnr_whole_list), sum(ssim_whole_list)/len(ssim_whole_list)))
544 | print('PSNR_F: {}, SSIM_F: {}'.format(sum(psnr_fovea_list)/len(psnr_fovea_list), sum(ssim_fovea_list)/len(ssim_fovea_list)))
545 | print('PSNR_P: {}, SSIM_P: {}'.format(sum(psnr_past_list)/len(psnr_past_list), sum(ssim_past_list)/len(ssim_past_list)))
546 | print('PSNR_O: {}, SSIM_O: {}'.format(sum(psnr_outskirt_list)/len(psnr_outskirt_list), sum(ssim_outskirt_list)/len(ssim_outskirt_list)))
547 |
--------------------------------------------------------------------------------
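
The script above samples a gaze center per frame from a Gaussian (`sigma * randn + image center`), stamps an `fv_size` square of ground truth into the foveal buffer, and builds the "outskirt" ring by repeatedly dilating the mask with a 3x3 all-ones convolution. A minimal sketch of that mask bookkeeping, with `fv_size` and the dilation count as illustrative defaults (the loop above uses 96 and 10):

```
# Sketch of the foveal/outskirt mask construction used in the test loop.
# `fv_size` and `dilate_iters` are illustrative; the script uses 96 and 10.
import torch
import torch.nn.functional as F

def make_masks(H, W, cx, cy, fv_size=96, dilate_iters=10):
    mk = torch.zeros(1, 1, H, W)
    y0 = max(cy - fv_size // 2, 0)
    x0 = max(cx - fv_size // 2, 0)
    mk[:, :, y0:y0 + fv_size, x0:x0 + fv_size] = 1
    kernel = torch.ones(1, 1, 3, 3)          # all-ones kernel acts as dilation
    mk_out = mk.clone()
    for _ in range(dilate_iters):
        mk_out = torch.clamp(F.conv2d(mk_out, kernel, padding=1), 0, 1)
    ring = torch.logical_and(~mk.bool(), mk_out.bool())  # dilated band minus fovea
    return mk, ring
```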
/test_video_quality.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # file_dir="FVSR_x8_simple_v0_hrdcn_n_offsetprop_n_fnet"
4 | # file_dir="FVSR_x8_simple_v13_hrdcn_n_offsetprop_n_fnet"
5 | # file_dir="FVSR_x8_simple_v13_hrdcn_n_offsetprop_n_fnet_nodcn"
6 | # file_dir="FVSR_x8_simple_v15_hrdcn_n_offsetprop_n_fnet"
7 | # file_dir="FVSR_x8_simple_v15_hrdcn_n_offsetprop_y_fnet"
8 | # file_dir="FVSR_x8_simple_v15_hrdcn_y_offsetprop_n_fnet"
9 | # file_dir="FVSR_x8_simple_v15_hrdcn_y_offsetprop_y_fnet"
10 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4"
11 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_nofv"
12 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_fast"
13 | file_dir="FVSR_x8_simple_v15_hrdcn_y_offsetprop_y_fnet_gaussian"
14 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_gaussian"
15 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_gaussian_regional"
16 | # file_dir="Bicubic"
17 | video_num=$1
18 | ffmpeg \
19 | -r 24 -i test_video/$file_dir/$video_num/gt.mp4 \
20 | -r 24 -i test_video/$file_dir/$video_num/sr.mp4 \
21 | -lavfi "[0:v]setpts=PTS-STARTPTS[reference]; \
22 | [1:v]scale=1280:720:flags=bicubic,setpts=PTS-STARTPTS[distorted]; \
23 | [distorted][reference]libvmaf=log_fmt=xml:log_path=/dev/stdout:model_path=/home/si2/vmaf/model/vmaf_v0.6.1.json" \
24 | -f null - > test_video/$file_dir/$video_num/eval.log
25 | echo $file_dir
--------------------------------------------------------------------------------
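
The filter graph above streams libvmaf's XML report into `eval.log`. A small helper for reading the pooled VMAF mean back out; it assumes libvmaf's XML log contains a `<metric name="vmaf" ... mean="...">` entry, so treat the regex as an assumption and adjust it if your build writes a different layout:

```
# Hypothetical log reader: extract the pooled VMAF mean from eval.log.
# Assumes libvmaf's XML schema with a <metric name="vmaf" ... mean="..."> entry.
import re
import sys

def vmaf_mean(log_path):
    text = open(log_path).read()
    m = re.search(r'name="vmaf"[^>]*\bmean="([0-9.]+)"', text)
    return float(m.group(1)) if m else None

if __name__ == '__main__':
    print(vmaf_mean(sys.argv[1]))
```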
/train.sh:
--------------------------------------------------------------------------------
1 | #### Train BasicVSR with Reds
2 | python3 main.py --save_dir ./train/REDS/FVSR_x8_simple_v1_dcn2_v11 \
3 | --reset True \
4 | --log_file_name train.log \
5 | --num_gpu 4 \
6 | --gpu_id 0 \
7 | --num_workers 9 \
8 | --dataset Reds \
9 | --dataset_dir /DATA/REDS_sharp/ \
10 | --model_path ./train/REDS/FVSR_x8_simple_v1_dcn2_v10/model/model_00002_005600.pt \
11 | --n_feats 64 \
12 | --lr_rate 2e-4 \
13 | --lr_rate_flow 2.5e-5 \
14 | --rec_w 1 \
15 | --scale 8 \
16 | --cra true \
17 | --mrcf true \
18 | --batch_size 8 \
19 | --FV_size 128 \
20 | --GT_size 256 \
21 | --N_frames 15 \
22 | --y_only false \
23 | --num_init_epochs 2 \
24 | --num_epochs 80 \
25 | --print_every 200 \
26 | --save_every 100 \
27 | --val_every 1 \
28 | --visdom_port 8803 \
29 | --visdom_view 1227_FVSR_x8_simple_v1_dcn2_v11
30 |
31 | ### simple_v1 dk=1, fd=32-> 8, lv1, range=10, dcn_gp=16
32 | ### simple_v2 dk=3, fd=32-> 8, lv1, range=10, dcn_gp=16
33 | ### simple_v3 dk=3, fd=64->16, lv1, range=10, dcn_gp=16
34 |
35 | ### simple_v4 dk=1, fd=32-> 8, lv3, range=10, dcn_gp= 1
36 | ### simple_v5 dk=3, fd=32-> 8, lv3, range=10, dcn_gp= 1
37 | ### simple_v6 dk=3, fd=32-> 8, lv3, range=80, dcn_gp= 1
38 | ### simple_v7 dk=3, fd=32-> 8, lv3, range=80, dcn_gp= 4
39 |
40 | ### simple_v8 dk=3, fd=64->16, lv3, range=80, dcn_gp= 4
41 | ### simple_v9 dk=3, fd=64-> 8, lv3, range=80, dcn_gp= 4
42 | ### simple_v10 dk=3, fd=32->16, lv3, range=80, dcn_gp= 4
43 | ### simple_v11 dk=1, fd=32->16, lv3, range=80, dcn_gp=16
44 |
45 | ### simple_duf dk=1, fd=32-> 8, lv1, range=10, dcn_gp=16
46 |
47 | ### simple_v12 dk=1, fd=32->32, lv3, range=80, dcn_gp= 4
48 | ### simple_v13 dk=1, fd=32->32, lv3, range=80, dcn_gp= 4, dcn * 2, res * 2
49 |
50 | ### dcn3_v1 dcn * 3, dk=1
51 | ### dcn3_v2 dcn * 3, dk=3(offset using repeat)
52 | ### dcn3_v3 dcn * 3, dk=3(offset & mask using repeat)
53 |
54 | ### dcn2_v1 dcn * 2, dk=1
55 | ### dcn2_v2 dcn * 2, dk=1, offset finetune(mean of generated offset)
56 | ### dcn2_v3 dcn * 2, dk=1, upsample * 2
57 | ### dcn2_v4 dcn * 2, dk=1, res_block * 2
58 | ### dcn2_v5 dcn * 2, dk=1, branch out
59 | ### dcn2_v6 dcn * 2, dk=1, deeper downsample layer(conv2d * 2)
60 | ### dcn2_v7 dcn * 2, dk=1, deeper downsample layer(conv2d * 4)
61 | ### dcn2_v8 dcn * 2, dk=1, v4 deeper downsample layer(PS * 2)
62 | ### dcn2_v9 dcn * 2, dk=1, channel dimension = 32
63 | ### dcn2_v10 dcn * 4, dk=1, channel dimension = 32
64 | ### dcn2_v11 pca , dk=1, channel dimension = 32
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | from utils import calc_psnr_and_ssim_cuda, bgr2ycbcr
2 |
3 | import os
4 | import numpy as np
5 | from imageio import imread, imsave, get_writer
6 | from PIL import Image
7 | import random
8 | import time
9 | from math import cos, pi
10 | from tqdm import tqdm
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.optim as optim
16 | import torchvision.utils as utils
17 | import visdom
18 |
19 | def rgb2yuv(rgb, y_only=True):
20 | # rgb_ = rgb.permute(0,2,3,1)
21 | # A = torch.tensor([[0.299, -0.14714119,0.61497538],
22 | # [0.587, -0.28886916, -0.51496512],
23 | # [0.114, 0.43601035, -0.10001026]])
24 | # yuv = torch.tensordot(rgb_,A,1).transpose(0,2)
25 | r = rgb[:, 0, :, :]
26 | g = rgb[:, 1, :, :]
27 | b = rgb[:, 2, :, :]
28 |
29 | y = 0.299 * r + 0.587 * g + 0.114 * b
30 | u = -0.147 * r - 0.289 * g + 0.436 * b
31 | v = 0.615 * r - 0.515 * g - 0.100 * b
32 | yuv = torch.stack([y,u,v], dim=1)
33 | if y_only:
34 | return y.unsqueeze(1)
35 | else:
36 | return yuv
37 |
38 | def yuv2rgb(yuv):
39 | y = yuv[:, 0, :, :]
40 | u = yuv[:, 1, :, :]
41 | v = yuv[:, 2, :, :]
42 |
43 |     r = y + 1.14 * v # coefficient for u is 0
44 |     g = y + -0.396 * u - 0.581 * v
45 |     b = y + 2.029 * u # coefficient for v is 0
46 | rgb = torch.stack([r,g,b], 1)
47 |
48 | return rgb
49 |
50 | def get_position_from_periods(iteration, cumulative_periods):
51 | """Get the position from a period list.
52 | It will return the index of the right-closest number in the period list.
53 | For example, the cumulative_periods = [100, 200, 300, 400],
54 | if iteration == 50, return 0;
55 | if iteration == 210, return 2;
56 | if iteration == 300, return 3.
57 | Args:
58 | iteration (int): Current iteration.
59 | cumulative_periods (list[int]): Cumulative period list.
60 | Returns:
61 | int: The position of the right-closest number in the period list.
62 | """
63 | for i, period in enumerate(cumulative_periods):
64 | if iteration < period:
65 | return i
66 | raise ValueError(f'Current iteration {iteration} exceeds '
67 | f'cumulative_periods {cumulative_periods}')
68 |
69 |
70 | def annealing_cos(start, end, factor, weight=1):
71 | """Calculate annealing cos learning rate.
72 | Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
73 | percentage goes from 0.0 to 1.0.
74 | Args:
75 | start (float): The starting learning rate of the cosine annealing.
76 |         end (float): The ending learning rate of the cosine annealing.
77 | factor (float): The coefficient of `pi` when calculating the current
78 | percentage. Range from 0.0 to 1.0.
79 | weight (float, optional): The combination factor of `start` and `end`
80 |             when calculating the actual starting learning rate. Defaults to 1.
81 | """
82 | cos_out = cos(pi * factor) + 1
83 | return end + 0.5 * weight * (start - end) * cos_out
84 |
85 | class Visdom_exe(object):
86 | def __init__(self, port='8907', env='main'):
87 | self.port = port
88 | self.env = env
89 | self.viz = visdom.Visdom(server='140.113.212.214', port=port, env=env)
90 |
91 | def plot_metric(self, loss=[], psnr=[], ssim=[], psnr_cuda=[], ssim_cuda=[], psnr_y_cuda=[], ssim_y_cuda=[], phase='train'):
92 | if len(loss) != 0:
93 | self.viz.line(X=[*range(len(loss))], Y=loss, win='{}_LOSS'.format(phase), opts={'title':'{}_LOSS'.format(phase)})
94 | if len(psnr) != 0:
95 | self.viz.line(X=[*range(len(psnr))], Y=psnr, win='{}_PSNR'.format(phase), opts={'title':'{}_PSNR'.format(phase)})
96 | if len(ssim) != 0:
97 | self.viz.line(X=[*range(len(ssim))], Y=ssim, win='{}_SSIM'.format(phase), opts={'title':'{}_SSIM'.format(phase)})
98 | if len(psnr_cuda) != 0:
99 | self.viz.line(X=[*range(len(psnr_cuda))], Y=psnr_cuda, win='{}_PSNR_cuda'.format(phase), opts={'title':'{}_PSNR_cuda'.format(phase)})
100 | if len(ssim_cuda) != 0:
101 | self.viz.line(X=[*range(len(ssim_cuda))], Y=ssim_cuda, win='{}_SSIM_cuda'.format(phase), opts={'title':'{}_SSIM_cuda'.format(phase)})
102 | if len(psnr_y_cuda) != 0:
103 | self.viz.line(X=[*range(len(psnr_y_cuda))], Y=psnr_y_cuda, win='{}_PSNR_y_cuda'.format(phase), opts={'title':'{}_PSNR_y_cuda'.format(phase)})
104 | if len(ssim_y_cuda) != 0:
105 | self.viz.line(X=[*range(len(ssim_y_cuda))], Y=ssim_y_cuda, win='{}_SSIM_y_cuda'.format(phase), opts={'title':'{}_SSIM_y_cuda'.format(phase)})
106 |
107 | class Trainer():
108 | def __init__(self, args, logger, dataloader, model, loss_all):
109 | self.viz = Visdom_exe(args.visdom_port, args.visdom_view)
110 | self.args = args
111 | self.logger = logger
112 | self.dataloader = dataloader
113 | self.model = model
114 | self.loss_all = loss_all
115 | # self.device = torch.device('cpu') if args.cpu else torch.device('cuda:{}'.format(args.gpu_id))
116 | self.device = torch.device('cpu') if args.cpu else torch.device('cuda')
117 |
118 | self.cur_clip = 0
119 |
120 | #### CosineAnnealing settings ####
121 | self.cur_iter = 0
122 | self.by_epoch = False
123 | self.periods = [600000]
124 | self.min_lr = 1e-7
125 | self.restart_weights = [1]
126 | self.cumulative_periods = [
127 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
128 | ]
129 | #### CosineAnnealing settings ####
130 |
131 | self.params = [
132 | {"params": [p for n, p in (self.model.named_parameters() if
133 | args.num_gpu==1 else self.model.module.named_parameters())
134 | if ('spynet' not in n)],
135 | "lr": args.lr_rate
136 | },
137 | {"params": [p for n, p in (self.model.named_parameters() if
138 | args.num_gpu==1 else self.model.module.named_parameters())
139 | if ('spynet' in n)],
140 | "lr": args.lr_rate_flow
141 | }
142 | ]
143 |
144 | #### TTSR settings ####
145 | # self.optimizer = optim.Adam(self.params, lr=args.lr_rate, betas=(args.beta1, args.beta2), eps=args.eps)
146 | # self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.decay, gamma=self.args.gamma)
147 |
148 | #### BasicVSR settings ####
149 | self.optimizer = optim.Adam(self.params, lr=args.lr_rate, betas=(args.beta1, args.beta2), eps=args.eps)
150 |
151 | self.max_psnr = 0.
152 | self.max_psnr_epoch = 0
153 | self.max_ssim = 0.
154 | self.max_ssim_epoch = 0
155 |
156 | self.max_y_psnr = 0.
157 | self.max_y_psnr_epoch = 0
158 | self.max_y_ssim = 0.
159 | self.max_y_ssim_epoch = 0
160 |
161 | self.train_loss_list = []
162 | self.train_psnr_list = []
163 | self.train_ssim_list = []
164 | self.train_psnr_cuda_list = []
165 | self.train_ssim_cuda_list = []
166 | self.train_psnr_y_cuda_list = []
167 | self.train_ssim_y_cuda_list = []
168 |
169 | self.eval_loss_list = []
170 | self.eval_psnr_list = []
171 | self.eval_ssim_list = []
172 | self.eval_psnr_cuda_list = []
173 | self.eval_ssim_cuda_list = []
174 | self.eval_psnr_y_cuda_list = []
175 | self.eval_ssim_y_cuda_list = []
176 |
177 | self.test_psnr = []
178 | self.test_ssim = []
179 | self.test_lr = []
180 | self.test_hr = []
181 | self.test_sr = []
182 |
183 | self.print_network(self.model)
184 |
185 | def load(self, model_path=None):
186 | if (model_path):
187 | self.logger.info('load_model_path: ' + model_path)
188 |
189 | model_state_dict = self.model.state_dict()
190 | for k in model_state_dict.keys():
191 | print(k)
192 | print('-----')
193 | model_state_dict_save = {k.replace('basic_', 'basic_module.'):v for k,v in torch.load(model_path).items() if k.replace('basic_', 'basic_module.') in model_state_dict}
194 | # model_state_dict_save = {k.replace('basic_','basic_module.'):v for k,v in torch.load(model_path).items()}
195 | # model_state_dict_save = {k:v for k,v in torch.load(model_path, map_location=self.device)['state_dict'].items()}
196 | for k in model_state_dict_save.keys():
197 | print(k)
198 | model_state_dict.update(model_state_dict_save)
199 | self.model.load_state_dict(model_state_dict, strict=True)
200 |
201 | def prepare(self, sample_batched):
202 | for key in sample_batched.keys():
203 | sample_batched[key] = sample_batched[key].to(self.device)
204 | return sample_batched
205 |
206 | def train_basicvsr(self, current_epoch=0):
207 | self.model.train()
208 | loss_list = []
209 | psnr_cuda_list = []
210 | ssim_cuda_list = []
211 | psnr_y_cuda_list = []
212 | ssim_y_cuda_list = []
213 | loop = tqdm(enumerate(self.dataloader['train']), total=len(self.dataloader['train']), leave=False)
214 | for i_batch, sample_batched in loop:
215 | sample_batched = self.prepare(sample_batched)
216 | lr = sample_batched['LR']
217 | lr_sr = sample_batched['LR_sr']
218 | hr = sample_batched['HR']
219 | ref = sample_batched['Ref']
220 | # ref_sr = sample_batched['Ref_sr']
221 | ref_sp = sample_batched['Ref_sp']
222 |
223 | if self.cur_iter < 5000:
224 | for k, v in self.model.named_parameters():
225 | if 'spynet' in k:
226 | v.requires_grad_(False)
227 | elif self.cur_iter == 5000:
228 | #### train all the parameters
229 | self.model.requires_grad_(True)
230 |
231 | self.before_train_iter()
232 |
233 | sr = self.model(lrs=lr, fvs=ref, mks=ref_sp)
234 | B, N, C, H, W = sr.size()
235 | sr = sr.view(B*N, C, H, W)
236 | B, N, C, H, W = hr.size()
237 | hr = hr.view(B*N, C, H, W)
238 |
239 | if self.args.y_only:
240 | B, N, C, H, W = lr_sr.size()
241 | lr_sr = lr_sr.view(B*N, C, H, W)
242 | lr_sr = rgb2yuv(lr_sr, y_only=False)
243 | sr = yuv2rgb(torch.cat((sr[:,0:1,:,:], lr_sr[:,1:3,:,:]), dim=1))
244 |
245 | rec_loss = self.args.rec_w * self.loss_all['cb_loss'](sr, hr)
246 | loss = rec_loss
247 |
248 | self.optimizer.zero_grad()
249 | loss.backward()
250 | self.optimizer.step()
251 | loss_list.append(loss.cpu().item())
252 |
253 | ### calculate psnr and ssim
254 | #### RGB domain ####
255 | # B, N, C, H, W = sr.size()
256 | mk_ = torch.ones((B*N, 1, H, W)).to(sr.device)
257 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.detach(), hr.detach(), mk_)
258 | psnr_cuda_list.append(psnr.cpu().item())
259 | ssim_cuda_list.append(ssim.cpu().item())
260 |
261 | # #### YCbCr domain ####
262 | # B, N, C, H, W = sr.size()
263 | psnr, ssim = calc_psnr_and_ssim_cuda(bgr2ycbcr(sr.permute(0, 2, 3, 1).detach(), y_only=True), \
264 | bgr2ycbcr(hr.permute(0, 2, 3, 1).detach(), y_only=True), mk_)
265 | # psnr, ssim = calc_psnr_and_ssim_cuda(sr.view(B*N, C, H, W).permute(0, 2, 3, 1).detach(), \
266 | # hr.view(B*N, C, H, W).permute(0, 2, 3, 1).detach(), mk_)
267 | psnr_y_cuda_list.append(psnr.cpu().item())
268 | ssim_y_cuda_list.append(ssim.cpu().item())
269 |
270 | loop.set_description(f"Epoch[{current_epoch}/{self.args.num_epochs}](Train)")
271 | # loop.set_postfix(loss=loss.item(), psnr=psnr.item(), ssim=ssim.item())
272 | loop.set_postfix(loss=loss.item(), lr=self.optimizer.param_groups[0]['lr'], lr_flow=self.optimizer.param_groups[1]['lr'])
273 |
274 | self.cur_iter += 1
275 |
276 | if self.cur_iter % self.args.save_every == 0:
277 | tmp = self.model.state_dict()
278 | model_state_dict = {key.replace('module.',''): tmp[key] for key in tmp}
279 | model_name = self.args.save_dir.strip('/')+'/model/model_{}_{}.pt'.format(str(current_epoch).zfill(5), str(self.cur_iter).zfill(6))
280 | torch.save(model_state_dict, model_name)
281 | self.train_loss_list.append(sum(loss_list)/len(loss_list))
282 | self.train_psnr_cuda_list.append(sum(psnr_cuda_list)/len(psnr_cuda_list))
283 | self.train_ssim_cuda_list.append(sum(ssim_cuda_list)/len(ssim_cuda_list))
284 | self.train_psnr_y_cuda_list.append(sum(psnr_y_cuda_list)/len(psnr_y_cuda_list))
285 | self.train_ssim_y_cuda_list.append(sum(ssim_y_cuda_list)/len(ssim_y_cuda_list))
286 | loss_list.clear()
287 | psnr_cuda_list.clear()
288 | ssim_cuda_list.clear()
289 | psnr_y_cuda_list.clear()
290 | ssim_y_cuda_list.clear()
291 |
292 | # if self.cur_iter % self.args.print_every == 0:
293 | # self.test_basicvsr(save_img=False)
294 |
295 | def eval_basicvsr(self, current_epoch=0):
296 | self.model.eval()
297 | psnr_cuda_list = []
298 | ssim_cuda_list = []
299 | psnr_y_cuda_list = []
300 | ssim_y_cuda_list = []
301 | # kernel = np.array([ [1, 1, 1],
302 | # [1, 1, 1],
303 | # [1, 1, 1] ], dtype=np.float32)
304 | # kernel_tensor = torch.Tensor(np.expand_dims(np.expand_dims(kernel, 0), 0)) # size: (1, 1, 3, 3)
305 | loop = tqdm(enumerate(self.dataloader['eval']), total=len(self.dataloader['eval']), leave=False)
306 | with torch.no_grad():
307 | for i_batch, sample_batched in loop:
308 | sample_batched = self.prepare(sample_batched)
309 | lr = sample_batched['LR']
310 | lr_sr = sample_batched['LR_sr']
311 | hr = sample_batched['HR']
312 | ref = sample_batched['Ref']
313 | ref_sp = sample_batched['Ref_sp']
314 |
315 | B, N, C, H, W = ref_sp.size()
316 | mk = torch.split(ref_sp.view(B*N, 1, H, W).clone().detach(), 1, dim=0)
317 | mk_bd = ref_sp.view(B*N, 1, H, W).float().clone().detach()
318 | sr = self.model(lrs=lr, fvs=ref, mks=ref_sp)
319 | # LR_fv = (sr * 255.).clip(0., 255.).squeeze().clone().detach()
320 | # for n in range(N):
321 | # self.viz.viz.image(LR_fv[n , :, :, :].cpu().numpy(), win='{}'.format('FV'), opts=dict(title='{}, Image size : {}'.format('FV', LR_fv.size())))
322 | # sr = lr_sr
323 |
324 | ### calculate psnr and ssim
325 | B, N, C, H, W = sr.size()
326 | sr = sr.view(B*N, C, H, W)
327 | B, N, C, H, W = hr.size()
328 | hr = hr.view(B*N, C, H, W)
329 |
330 | if self.args.y_only:
331 | B, N, C, H, W = lr_sr.size()
332 | lr_sr = lr_sr.view(B*N, C, H, W)
333 | lr_sr = rgb2yuv(lr_sr, y_only=False)
334 | sr = yuv2rgb(torch.cat((sr[:,0:1,:,:], lr_sr[:,1:3,:,:]), dim=1))
335 |
336 | # LR_fv = (sr * 255.).clip(0., 255.).squeeze().clone().detach()
337 | # for i in range(B*N):
338 | # self.viz.viz.image(LR_fv[i].cpu().numpy(), win='{}'.format('FV'), opts=dict(title='{}, Image size : {}'.format('FV', LR_fv.size())))
339 |
340 | # mk = torch.split(ref_sp.view(B*N, 1, H, W), 1, dim=0)
341 | # mk_bd = ref_sp.view(B*N, 1, H, W).float()
342 | # for _ in range(10):
343 | # mk_bd = torch.clamp(F.conv2d(mk_bd, kernel_tensor.to(mk_bd.device), padding=(1, 1)), 0, 1)
344 | # mk_bd = torch.split(mk_bd, 1, dim=0)
345 | # past_len = 3
346 | # if i_batch % 50 == 0:
347 | # mk_pre = [mk_bd[0]]
348 | mk_ = torch.ones((B*N, 1, H, W)).to(sr.device)
349 | for idx in range(0, N):
350 | if idx == 0 and i_batch % 50 == 0:
351 | continue
352 | sr_ = sr[idx, :, :, :].unsqueeze(0)
353 | hr_ = hr[idx, :, :, :].unsqueeze(0)
354 | # mk_ = mk[idx]
355 | # mk_ = torch.logical_and(mk_, torch.logical_not(mk[idx]))
356 | # mk_ = torch.logical_or(mk_, mk[idx - 1])
357 | # mk_ = torch.logical_and(torch.logical_not(mk[idx]), mk_bd[idx])
358 | # mk_ = torch.sum(torch.cat(mk_pre, dim=1), dim=1, keepdim=True).clip(0, 1)
359 | # psnr, ssim = calc_psnr_and_ssim_cuda(sr.view(B*N, C, H, W).detach(), hr.view(B*N, C, H, W).detach())
360 | # psnr, ssim = calc_psnr_and_ssim_cuda((sr_*mk_).detach(), (hr_*mk_).detach())
361 | psnr, ssim = calc_psnr_and_ssim_cuda(sr_.detach(), hr_.detach(), mk_)
362 | psnr_cuda_list.append(psnr.cpu().item())
363 | ssim_cuda_list.append(ssim.cpu().item())
364 | psnr, ssim = calc_psnr_and_ssim_cuda(bgr2ycbcr(sr_.permute(0, 2, 3, 1).detach(), y_only=True), \
365 | bgr2ycbcr(hr_.permute(0, 2, 3, 1).detach(), y_only=True), mk_)
366 | # psnr, ssim = calc_psnr_and_ssim_cuda(sr_.detach(), \
367 | # hr_.detach(), mk_)
368 | psnr_y_cuda_list.append(psnr.cpu().item())
369 | ssim_y_cuda_list.append(ssim.cpu().item())
370 |
371 | # mk_pre.append(mk_bd[idx])
372 | # if len(mk_pre) > past_len:
373 | # mk_pre.pop(0)
374 |
375 | psnr_ave = sum(psnr_cuda_list)/len(psnr_cuda_list)
376 | ssim_ave = sum(ssim_cuda_list)/len(ssim_cuda_list)
377 | psnr_ave_y = sum(psnr_y_cuda_list)/len(psnr_y_cuda_list)
378 | ssim_ave_y = sum(ssim_y_cuda_list)/len(ssim_y_cuda_list)
379 | loop.set_description(f"Epoch[{current_epoch}/{self.args.num_epochs}](Eval)")
380 | loop.set_postfix(psnr=psnr_ave, ssim=ssim_ave, psnr_y=psnr_ave_y, ssim_y=ssim_ave_y)
381 |
382 | psnr_ave = sum(psnr_cuda_list)/len(psnr_cuda_list)
383 | ssim_ave = sum(ssim_cuda_list)/len(ssim_cuda_list)
384 | self.eval_psnr_cuda_list.append(psnr_ave)
385 | self.eval_ssim_cuda_list.append(ssim_ave)
386 | self.logger.info('Ref PSNR (now): %.3f \t SSIM (now): %.4f' %(psnr_ave, ssim_ave))
387 | if (psnr_ave > self.max_psnr):
388 | self.max_psnr = psnr_ave
389 | self.max_psnr_epoch = current_epoch
390 | if (ssim_ave > self.max_ssim):
391 | self.max_ssim = ssim_ave
392 | self.max_ssim_epoch = current_epoch
393 | self.logger.info('Ref PSNR (max): %.3f (%d) \t SSIM (max): %.4f (%d)'
394 | %(self.max_psnr, self.max_psnr_epoch, self.max_ssim, self.max_ssim_epoch))
395 |
396 | psnr_y_ave = sum(psnr_y_cuda_list)/len(psnr_y_cuda_list)
397 | ssim_y_ave = sum(ssim_y_cuda_list)/len(ssim_y_cuda_list)
398 | self.eval_psnr_y_cuda_list.append(psnr_y_ave)
399 | self.eval_ssim_y_cuda_list.append(ssim_y_ave)
400 | self.logger.info('Ref PSNR_Y (now): %.3f \t SSIM_Y (now): %.4f' %(psnr_y_ave, ssim_y_ave))
401 | if (psnr_y_ave > self.max_y_psnr):
402 | self.max_y_psnr = psnr_y_ave
403 | self.max_y_psnr_epoch = current_epoch
404 | if (ssim_y_ave > self.max_y_ssim):
405 | self.max_y_ssim = ssim_y_ave
406 | self.max_y_ssim_epoch = current_epoch
407 | self.logger.info('Ref PSNR_Y (max): %.3f (%d) \t SSIM_Y (max): %.4f (%d)'
408 | %(self.max_y_psnr, self.max_y_psnr_epoch, self.max_y_ssim, self.max_y_ssim_epoch))
409 |
410 | psnr_cuda_list.clear()
411 | ssim_cuda_list.clear()
412 | psnr_y_cuda_list.clear()
413 | ssim_y_cuda_list.clear()
414 |
415 | def test_basicvsr(self, save_img=True):
416 | if save_img:
417 | self.print_network(self.model)
418 | self.logger.info('Test process...')
419 |
420 | crop_h = self.args.FV_size
421 | crop_w = self.args.FV_size
422 | # kernel_size = self.args.FV_size
423 | kernel_size = 10
424 | stride_size = 5
425 | self.model.eval()
426 | with torch.no_grad():
427 | for i_batch, sample_batched in enumerate(self.dataloader['test']):
428 | sample_batched = self.prepare(sample_batched)
429 | lr = sample_batched['LR']
430 | lr_sr = sample_batched['LR_sr']
431 | hr = sample_batched['HR']
432 | ref = sample_batched['Ref']
433 | ref_sp = sample_batched['Ref_sp']
434 | fv_sp = sample_batched['FV_sp']
435 | B, N, C, H, W = hr.size()
436 |
437 | sr = self.model(lrs=lr, fvs=ref, mks=ref_sp)
438 | if self.args.y_only:
439 | B, N, C, H, W = sr.size()
440 | sr = sr.view(B*N, C, H, W)
441 | B, N, C, H, W = lr_sr.size()
442 | lr_sr = lr_sr.view(B*N, C, H, W)
443 | lr_sr = rgb2yuv(lr_sr, y_only=False)
444 | sr = yuv2rgb(torch.cat((sr[:,0:1,:,:], lr_sr[:,1:3,:,:]), dim=1))
445 |
446 | LR = (lr * 255.).squeeze().clone().detach()
447 | HR = (hr * 255.).squeeze().clone().detach()
448 | LR_fv = (sr * 255.).clip(0., 255.).squeeze().clone().detach()
449 | Ref_sp = fv_sp.squeeze().clone().detach()
450 |
451 | # direction = -1
452 | # scan_step = 8
453 | # accm_step = W - scan_step
454 |
455 | for n in range(N):
456 | psnr_score, ssim_score = self.foveated_metric(LR[n , :, :, :], LR_fv[n , :, :, :], HR[n , :, :, :], Ref_sp[n], (H, W), (crop_h, crop_w), kernel_size, stride_size)
457 | self.test_psnr.append((psnr_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy())
458 | self.test_ssim.append((ssim_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy())
459 | # self.result_comp(LR_fv[n , :, :, :], accm_step)
460 | # accm_step += direction * scan_step
461 | # if accm_step < 0:
462 | # direction *= -1
463 | # accm_step += direction * scan_step
464 | # elif accm_step >= W:
465 | # direction *= -1
466 | # accm_step += direction * scan_step
467 | time.sleep(0.1)
468 |
469 | print('Process: {:.2f} % ...\r'.format(i_batch*100/len(self.dataloader['test'])), end='')
470 |
471 | for n in range(N):
472 | self.test_lr.append(LR[n, :, :, :].round().cpu().numpy())
473 | self.test_hr.append(HR[n, :, :, :].round().cpu().numpy())
474 | self.test_sr.append(LR_fv[n, :, :, :].round().cpu().numpy())
475 |
476 | if save_img:
477 | if len(self.test_sr) == 100:
478 | N = 100
479 | save_path = os.path.join(self.args.save_dir, 'save_results', '{:05d}'.format(self.cur_clip))
480 | if not os.path.isdir(save_path):
481 | os.mkdir(save_path)
482 | with get_writer(os.path.join(save_path, '{:05d}.gif'.format(i_batch)), mode="I", fps=5) as writer:
483 | for n in range(N):
484 | imsave(os.path.join(save_path, '{}_sr.png'.format(n)), np.transpose(self.test_sr[n], (1, 2, 0)).astype(np.uint8))
485 | writer.append_data(np.transpose(self.test_sr[n], (1, 2, 0)).astype(np.uint8))
486 | with get_writer(os.path.join(save_path, '{:05d}_gt.gif'.format(i_batch)), mode="I", fps=5) as writer:
487 | for n in range(N):
488 | imsave(os.path.join(save_path, '{}_hr.png'.format(n)), np.transpose(self.test_hr[n], (1, 2, 0)).astype(np.uint8))
489 | writer.append_data(np.transpose(self.test_hr[n], (1, 2, 0)).astype(np.uint8))
490 | with get_writer(os.path.join(save_path, '{:05d}_lr.gif'.format(i_batch)), mode="I", fps=5) as writer:
491 | for n in range(N):
492 | imsave(os.path.join(save_path, '{}_lr.png'.format(n)), np.transpose(self.test_lr[n], (1, 2, 0)).astype(np.uint8))
493 | writer.append_data(np.transpose(self.test_lr[n], (1, 2, 0)).astype(np.uint8))
494 | with get_writer(os.path.join(save_path, '{:05d}_psnr.gif'.format(i_batch)), mode="I", fps=5) as writer:
495 | for n in range(N):
496 | imsave(os.path.join(save_path, '{}_psnr.png'.format(n)), self.test_psnr[n].astype(np.uint8))
497 | writer.append_data(self.test_psnr[n].astype(np.uint8))
498 | with get_writer(os.path.join(save_path, '{:05d}_ssim.gif'.format(i_batch)), mode="I", fps=5) as writer:
499 | for n in range(N):
500 | imsave(os.path.join(save_path, '{}_ssim.png'.format(n)), self.test_ssim[n].astype(np.uint8))
501 | writer.append_data(self.test_ssim[n].astype(np.uint8))
502 | self.test_lr.clear()
503 | self.test_hr.clear()
504 | self.test_sr.clear()
505 | self.test_psnr.clear()
506 | self.test_ssim.clear()
507 | self.cur_clip += 1
508 | else:
509 | break
510 |
511 | if save_img:
512 | self.logger.info('Test over.')
513 |
514 | def test_basicvsr_scan(self, save_img=True):
515 | if save_img:
516 | self.print_network(self.model)
517 | self.logger.info('Test process...')
518 | self.logger.info('lr path: %s' %(self.args.lr_path))
519 | self.logger.info('ref path: %s' %(self.args.ref_path))
520 |
521 | self.model.eval()
522 | with torch.no_grad():
523 | for i_batch, sample_batched in enumerate(self.dataloader['test']):
524 | sample_batched = self.prepare(sample_batched)
525 | lr = sample_batched['LR']
526 | hr = sample_batched['HR']
527 | ref = sample_batched['Ref']
528 | ref_sp = sample_batched['Ref_sp']
529 | fv_sp = sample_batched['FV_sp']
530 | B, N, C, H, W = hr.size()
531 |
532 | sr = self.model(lrs=lr, fvs=ref, mks=ref_sp)
533 | #### Save images ####
534 | if save_img:
535 |                     sr_save = (sr * 255.).clip(0., 255.).squeeze()
536 | SR_t = sr
537 | sr_save = np.transpose(sr_save.round().cpu().detach().numpy(), (0, 2, 3, 1)).astype(np.uint8)
538 | for t in range(N):
539 | save_path = os.path.join(self.args.save_dir, 'save_results', '{:05d}'.format(i_batch))
540 | if not os.path.isdir(save_path):
541 | os.mkdir(save_path)
542 | imsave(os.path.join(save_path, '{}.png'.format(t)), sr_save[t])
543 | #### Save images ####
544 | LR = (lr * 255.).squeeze().clone().detach()
545 | HR = (hr * 255.).squeeze().clone().detach()
546 | LR_fv = (sr * 255.).clip(0., 255.).squeeze().clone().detach()
547 | Ref_sp = fv_sp.squeeze().clone().detach()
548 | for n in range(N):
549 | self.result_comp(LR_fv[n , :, :, :])
550 | time.sleep(0.1)
551 | if save_img:
552 | with get_writer(os.path.join(save_path, '{:05d}.gif'.format(i_batch)), mode="I", fps=5) as writer:
553 | for n in range(N):
554 | writer.append_data(np.transpose(LR_fv[n].round().cpu().detach().numpy(), (1, 2, 0)).astype(np.uint8))
555 | print('Process: {:.2f} % ...\r'.format(i_batch*100/len(self.dataloader['test'])), end='')
556 | else:
557 | break
558 |
559 | if save_img:
560 | self.logger.info('Test over.')
561 |
562 | def vis_plot_metric(self, phase):
563 | if phase == 'train':
564 | self.viz.plot_metric(loss=self.train_loss_list, psnr=self.train_psnr_list, ssim=self.train_ssim_list, \
565 | psnr_cuda=self.train_psnr_cuda_list, ssim_cuda=self.train_ssim_cuda_list, \
566 | psnr_y_cuda=self.train_psnr_y_cuda_list, ssim_y_cuda=self.train_ssim_y_cuda_list, \
567 | phase='train')
568 | elif phase == 'eval':
569 | self.viz.plot_metric(loss=self.eval_loss_list, psnr=self.eval_psnr_list, ssim=self.eval_ssim_list, \
570 | psnr_cuda=self.eval_psnr_cuda_list, ssim_cuda=self.eval_ssim_cuda_list, \
571 | psnr_y_cuda=self.eval_psnr_y_cuda_list, ssim_y_cuda=self.eval_ssim_y_cuda_list, \
572 | phase='eval')
573 |
574 | def _get_network_description(self, net):
575 | """Get the string and total parameters of the network"""
576 | if isinstance(net, nn.DataParallel):
577 | net = net.module
578 | return str(net), sum(map(lambda x: x.numel(), net.parameters()))
579 |
580 | def print_network(self, net):
581 | """Print the str and parameter number of a network.
582 | Args:
583 | net (nn.Module)
584 | """
585 | net_str, net_params = self._get_network_description(net)
586 | if isinstance(net, nn.DataParallel):
587 | net_cls_str = (f'{net.__class__.__name__} - '
588 | f'{net.module.__class__.__name__}')
589 | else:
590 | net_cls_str = f'{net.__class__.__name__}'
591 |
592 | self.logger.info(
593 | f'Network: {net_cls_str}, with parameters: {net_params:,d}')
594 | self.logger.info(net_str)
595 |
596 | def before_run(self):
597 | # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved,
598 | # it will be set according to the optimizer params
599 | for group in self.optimizer.param_groups:
600 | group.setdefault('initial_lr', group['lr'])
601 | self.base_lr = [
602 | group['initial_lr'] for group in self.optimizer.param_groups
603 | ]
604 |
605 | def before_train_iter(self):
606 | self.regular_lr = self.get_regular_lr()
607 | self._set_lr(self.regular_lr)
608 |
609 | def get_regular_lr(self):
610 | return [self.get_lr(_base_lr) for _base_lr in self.base_lr]
611 |
612 | def get_lr(self, base_lr):
613 | progress = self.cur_iter
614 | target_lr = self.min_lr
615 |
616 | idx = get_position_from_periods(progress, self.cumulative_periods)
617 | current_weight = self.restart_weights[idx]
618 | nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
619 | current_periods = self.periods[idx]
620 |
621 | alpha = min((progress - nearest_restart) / current_periods, 1)
622 | return annealing_cos(base_lr, target_lr, alpha, current_weight)
623 |
624 | def _set_lr(self, lr_groups):
625 | for param_group, lr in zip(self.optimizer.param_groups, lr_groups):
626 | param_group['lr'] = lr
627 |
628 | def foveated_metric(self, LR, LR_fv, HR, mn, hw, crop, kernel_size, stride_size):
629 | m, n = mn
630 | h, w = hw
631 | crop_h, crop_w = crop
632 |         HR_fold = F.unfold(HR.unsqueeze(0), kernel_size=(kernel_size, kernel_size), stride=stride_size) # [1, 3*k*k, Hr*Wr]
633 |         LR_fv_fold = F.unfold(LR_fv.unsqueeze(0), kernel_size=(kernel_size, kernel_size), stride=stride_size) # [1, 3*k*k, Hr*Wr]
634 | B, C, N = HR_fold.size()
635 |         HR_fold = HR_fold.permute(0, 2, 1).view(B*N , 3, kernel_size, kernel_size) # [1*Hr*Wr, 3, k, k]
636 | LR_fv_fold = LR_fv_fold.permute(0, 2, 1).view(B*N , 3, kernel_size, kernel_size)
637 | Hr = (h - kernel_size) // stride_size + 1
638 | Wr = (w - kernel_size) // stride_size + 1
639 |
640 | B, C, H, W = HR_fold.size()
641 | mask = torch.ones((B, 1, H, W)).float()
642 | psnr_score, ssim_score = calc_psnr_and_ssim_cuda(HR_fold, LR_fv_fold, mask, is_tensor=False, batch_avg=True)
643 | psnr_score = psnr_score.view(Hr, Wr)
644 | ssim_score = ssim_score.view(Hr, Wr)
645 |
646 | psnr_y_idx = (torch.argmax(psnr_score) // Wr) * stride_size
647 | psnr_x_idx = (torch.argmax(psnr_score) % Wr) * stride_size
648 | ssim_y_idx = (torch.argmax(ssim_score) // Wr) * stride_size
649 | ssim_x_idx = (torch.argmax(ssim_score) % Wr) * stride_size
650 |
651 | LR_fv[:, m:m+crop_h, n] = torch.tensor([255., 0., 0.]).unsqueeze(1).repeat((1,crop_h))
652 | LR_fv[:, m:m+crop_h, n+crop_w-1] = torch.tensor([255., 0., 0.]).unsqueeze(1).repeat((1,crop_h))
653 | LR_fv[:, m, n:n+crop_w] = torch.tensor([255., 0., 0.]).unsqueeze(1).repeat((1,crop_w))
654 | LR_fv[:, m+crop_h-1, n:n+crop_w] = torch.tensor([255., 0., 0.]).unsqueeze(1).repeat((1,crop_w))
655 |
656 | psnr_score_discrete = torch.zeros_like(psnr_score)
657 | ssim_score_discrete = torch.zeros_like(ssim_score)
658 |
659 | psnr_score = (psnr_score - psnr_score.min()) / (psnr_score.max() - psnr_score.min())
660 | ssim_score = (ssim_score - ssim_score.min()) / (ssim_score.max() - ssim_score.min())
661 |
662 | psnr_score_discrete[psnr_score <= 1.0] = 1.0
663 | psnr_score_discrete[psnr_score <= 0.9] = 0.9
664 | psnr_score_discrete[psnr_score <= 0.8] = 0.8
665 | psnr_score_discrete[psnr_score <= 0.7] = 0.7
666 | psnr_score_discrete[psnr_score <= 0.6] = 0.6
667 | psnr_score_discrete[psnr_score <= 0.5] = 0.5
668 | psnr_score_discrete[psnr_score <= 0.4] = 0.4
669 | psnr_score_discrete[psnr_score <= 0.3] = 0.3
670 | psnr_score_discrete[psnr_score <= 0.2] = 0.2
671 | psnr_score_discrete[psnr_score <= 0.1] = 0.1
672 |
673 | ssim_score_discrete[ssim_score <= 1.0] = 1.0
674 | ssim_score_discrete[ssim_score <= 0.9] = 0.9
675 | ssim_score_discrete[ssim_score <= 0.8] = 0.8
676 | ssim_score_discrete[ssim_score <= 0.7] = 0.7
677 | ssim_score_discrete[ssim_score <= 0.6] = 0.6
678 | ssim_score_discrete[ssim_score <= 0.5] = 0.5
679 | ssim_score_discrete[ssim_score <= 0.4] = 0.4
680 | ssim_score_discrete[ssim_score <= 0.3] = 0.3
681 | ssim_score_discrete[ssim_score <= 0.2] = 0.2
682 | ssim_score_discrete[ssim_score <= 0.1] = 0.1
683 |
684 | # self.viz.viz.image(HR.cpu().numpy(), win='{}'.format('HR'), opts=dict(title='{}, Image size : {}'.format('HR', HR.size())))
685 | # self.viz.viz.image(LR.cpu().numpy(), win='{}'.format('LR'), opts=dict(title='{}, Image size : {}'.format('LR', LR.size())))
686 | self.viz.viz.image(LR_fv.cpu().numpy(), win='{}'.format('FV'), opts=dict(title='{}, Image size : {}'.format('FV', LR_fv.size())))
687 | # self.viz.viz.image(psnr_score.cpu().numpy(), win='{}'.format('PSNR_score'), opts=dict(title='{}, Image size : {}'.format('PSNR_score', psnr_score.size())))
688 | # self.viz.viz.image(ssim_score.cpu().numpy(), win='{}'.format('SSIM_score'), opts=dict(title='{}, Image size : {}'.format('SSIM_score', ssim_score.size())))
689 | # self.viz.viz.image(psnr_score_discrete.cpu().numpy(), win='{}'.format('PSNR_score_discrete'), opts=dict(title='{}, Image size : {}'.format('PSNR_score_discrete', psnr_score_discrete.size())))
690 | # self.viz.viz.image(ssim_score_discrete.cpu().numpy(), win='{}'.format('SSIM_score_discrete'), opts=dict(title='{}, Image size : {}'.format('SSIM_score_discrete', ssim_score_discrete.size())))
691 |
692 | return psnr_score, ssim_score
693 |
694 |     def result_comp(self, LR_fv, SP_W=0):  # default keeps the single-argument call in test_basicvsr_scan valid
695 | C, H, W = LR_fv.size()
696 | LR_fv[:, :, SP_W] = torch.tensor([255., 255., 255.]).unsqueeze(1).repeat((1,H))
697 | self.viz.viz.image(LR_fv.cpu().numpy(), win='{}'.format('LR_fv'), opts=dict(title='{}, Image size : {}'.format('LR_fv', LR_fv.size())))
698 |
--------------------------------------------------------------------------------
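
`get_lr` above combines `get_position_from_periods` and `annealing_cos` into a cosine schedule with warm restarts. A quick sanity-check sketch using those helpers on a toy two-period schedule; the periods and weights here are illustrative, not the training values (trainer.py uses a single 600000-iteration period):

```
# Toy dump of the cosine-restart schedule built from the helpers above.
# Importing trainer assumes its module-level dependencies (torch, visdom,
# tqdm, ...) are installed.
from trainer import annealing_cos, get_position_from_periods

periods = [100, 200]            # illustrative restart periods
cumulative = [100, 300]         # running sums of `periods`
restart_weights = [1.0, 0.5]    # second cycle restarts at half amplitude
base_lr, min_lr = 2e-4, 1e-7

def lr_at(it):
    idx = get_position_from_periods(it, cumulative)
    nearest = 0 if idx == 0 else cumulative[idx - 1]
    alpha = min((it - nearest) / periods[idx], 1)
    return annealing_cos(base_lr, min_lr, alpha, restart_weights[idx])

for it in (0, 50, 99, 100, 200, 299):
    print(it, lr_at(it))   # base_lr at 0, decays toward min_lr, restarts at 100
```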
/untar_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | tar xvf train_model.tar -C train/REDS/
4 | rm train_model.tar
5 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import logging
4 | import cv2
5 | import os
6 | import shutil
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.autograd import Variable
12 | import pytorch_ssim as pytorch_ssim
13 |
14 | class Logger(object):
15 | def __init__(self, log_file_name, logger_name, log_level=logging.DEBUG):
16 | ### create a logger
17 | self.__logger = logging.getLogger(logger_name)
18 |
19 | ### set the log level
20 | self.__logger.setLevel(log_level)
21 |
22 | ### create a handler to write log file
23 | file_handler = logging.FileHandler(log_file_name)
24 |
25 | ### create a handler to print on console
26 | console_handler = logging.StreamHandler()
27 |
28 | ### define the output format of handlers
29 | formatter = logging.Formatter('[%(asctime)s] - [%(filename)s file line:%(lineno)d] - %(levelname)s: %(message)s')
30 | file_handler.setFormatter(formatter)
31 | console_handler.setFormatter(formatter)
32 |
33 | ### add handler to logger
34 | self.__logger.addHandler(file_handler)
35 | self.__logger.addHandler(console_handler)
36 |
37 | def get_log(self):
38 | return self.__logger
39 |
40 |
41 | def mkExpDir(args):
42 | if (os.path.exists(args.save_dir)):
43 | if (not args.reset):
44 | raise SystemExit('Error: save_dir "' + args.save_dir + '" already exists! Please set --reset True to delete the folder.')
45 | else:
46 | shutil.rmtree(args.save_dir)
47 |
48 | os.makedirs(args.save_dir)
49 | # os.makedirs(os.path.join(args.save_dir, 'img'))
50 |
51 | if ((not args.eval) and (not args.test)):
52 | os.makedirs(os.path.join(args.save_dir, 'model'))
53 |
54 | if ((args.eval and args.eval_save_results) or args.test):
55 | os.makedirs(os.path.join(args.save_dir, 'save_results'))
56 |
57 | args_file = open(os.path.join(args.save_dir, 'args.txt'), 'w')
58 | for k, v in vars(args).items():
59 | args_file.write(k.rjust(30,' ') + '\t' + str(v) + '\n')
60 |
61 | _logger = Logger(log_file_name=os.path.join(args.save_dir, args.log_file_name),
62 | logger_name=args.logger_name).get_log()
63 |
64 | return _logger
65 |
66 |
67 | class MeanShift(nn.Conv2d):
68 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
69 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
70 | std = torch.Tensor(rgb_std)
71 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
72 | self.weight.data.div_(std.view(3, 1, 1, 1))
73 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
74 | self.bias.data.div_(std)
75 | # self.requires_grad = False
76 | self.weight.requires_grad = False
77 | self.bias.requires_grad = False
78 |
79 |
80 | def calc_psnr(img1, img2):
81 | ### args:
82 | # img1: [h, w, c], range [0, 255]
83 | # img2: [h, w, c], range [0, 255]
84 | diff = (img1 - img2) / 255.0
85 | diff[:,:,0] = diff[:,:,0] * 65.738 / 256.0
86 | diff[:,:,1] = diff[:,:,1] * 129.057 / 256.0
87 | diff[:,:,2] = diff[:,:,2] * 25.064 / 256.0
88 |
89 | diff = np.sum(diff, axis=2)
90 | mse = np.mean(np.power(diff, 2))
91 | return -10 * math.log10(mse)
92 |
93 |
94 | def calc_ssim(img1, img2):
95 | def ssim(img1, img2):
96 | C1 = (0.01 * 255)**2
97 | C2 = (0.03 * 255)**2
98 |
99 | img1 = img1.astype(np.float64)
100 | img2 = img2.astype(np.float64)
101 | kernel = cv2.getGaussianKernel(11, 1.5)
102 | window = np.outer(kernel, kernel.transpose())
103 |
104 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
105 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
106 | mu1_sq = mu1**2
107 | mu2_sq = mu2**2
108 | mu1_mu2 = mu1 * mu2
109 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
110 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
111 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
112 |
113 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
114 | (sigma1_sq + sigma2_sq + C2))
115 | return ssim_map.mean()
116 |
117 |     ### args:
118 |     # img1: [h, w, c], range [0, 255]
119 |     # img2: [h, w, c], range [0, 255]
120 |     # produces the same outputs as MATLAB's SSIM implementation
121 |     if img1.shape != img2.shape:
122 |         raise ValueError('Input images must have the same dimensions.')
123 |     border = 0
124 |     img1_y = np.dot(img1, [65.738, 129.057, 25.064]) / 256.0 + 16.0
125 |     img2_y = np.dot(img2, [65.738, 129.057, 25.064]) / 256.0 + 16.0
126 |     h, w = img1.shape[:2]
127 |     img1_y = img1_y[border:h-border, border:w-border]
128 |     img2_y = img2_y[border:h-border, border:w-border]
129 | 
130 |     # np.dot reduces [h, w, c] to a 2-D Y image, so this branch is the usual path
131 |     if img1_y.ndim == 2:
132 |         return ssim(img1_y, img2_y)
133 |     elif img1.ndim == 3:
134 |         if img1.shape[2] == 3:
135 |             # SSIM is computed per channel, then averaged
136 |             ssims = [ssim(img1[..., i], img2[..., i]) for i in range(3)]
137 |             return np.array(ssims).mean()
138 |         elif img1.shape[2] == 1:
139 |             return ssim(np.squeeze(img1), np.squeeze(img2))
140 |         else:
141 |             raise ValueError('Wrong input image dimensions.')
142 |
143 |
144 | def calc_psnr_and_ssim(sr, hr):
145 | ### args:
146 | # sr: pytorch tensor, range [-1, 1]
147 | # hr: pytorch tensor, range [-1, 1]
148 |
149 |     ### prepare data
150 | sr = (sr+1.) * 127.5
151 | hr = (hr+1.) * 127.5
152 | if (sr.size() != hr.size()):
153 | h_min = min(sr.size(2), hr.size(2))
154 | w_min = min(sr.size(3), hr.size(3))
155 | sr = sr[:, :, :h_min, :w_min]
156 | hr = hr[:, :, :h_min, :w_min]
157 |
158 | img1 = np.transpose(sr.squeeze().round().cpu().numpy(), (1,2,0))
159 | img2 = np.transpose(hr.squeeze().round().cpu().numpy(), (1,2,0))
160 |
161 | psnr = calc_psnr(img1, img2)
162 | ssim = calc_ssim(img1, img2)
163 |
164 | return psnr, ssim
165 |
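    | # Sketch for calc_psnr_and_ssim above, assuming the documented [-1, 1] contract:
    | #
    | #     sr = torch.rand(1, 3, 64, 64) * 2 - 1   # network output in [-1, 1]
    | #     hr = torch.rand(1, 3, 64, 64) * 2 - 1   # ground truth in [-1, 1]
    | #     psnr, ssim_val = calc_psnr_and_ssim(sr, hr)
    | 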
166 | def psnr_cuda(img1, img2, mask, batch_avg=False):
167 |     #### Image range [0, 1] ####
168 |     if batch_avg:
169 |         # one PSNR value per sample, averaged over the full frame (mask unused here)
170 |         B, C, H, W = img1.size()
171 |         mse = ((img1 - img2) ** 2).view(B, -1).mean(1)
172 |         # identical frames fall back to the MSE of a single value off by 1/255
173 |         psnr = torch.where(mse == 0, -20 * torch.log10(torch.sqrt((1/255.)**2 / torch.tensor(C*H*W).to(mse.device))),
174 |                            -20 * torch.log10(torch.sqrt(mse)))
175 |     else:
176 |         # masked PSNR: the MSE is averaged over masked (foveated) pixels only
177 |         B, C, H, W = img1.size()
178 |         mse = (((img1 - img2) ** 2) * mask).sum() / (mask.float().sum() * C)
179 |         if mse == 0:
180 |             psnr = -20 * torch.log10(torch.sqrt((1/255.)**2 / torch.prod(torch.tensor(img1.size()))))
181 |         else:
182 |             psnr = -20 * torch.log10(torch.sqrt(mse.cpu()))
183 | 
184 |     #### For inputs in [0, 255], use instead: psnr = 20. * torch.log10(255. / torch.sqrt(mse)) ####
185 |     return psnr
186 |
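    | # Masked-PSNR sketch for psnr_cuda above: a 0/1 mask selects the foveated
    | # region, so only those pixels contribute to the MSE (shapes illustrative):
    | #
    | #     sr = torch.rand(1, 3, 64, 64)      # range [0, 1]
    | #     hr = torch.rand(1, 3, 64, 64)
    | #     mask = torch.ones(1, 1, 64, 64)    # 1 inside the fovea, 0 outside
    | #     print(psnr_cuda(sr, hr, mask))
    | 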
187 | def gaussian(window_size, sigma):
188 | gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
189 | return gauss/gauss.sum()
190 |
191 | def create_window(window_size, channel):
192 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
193 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
194 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
195 | return window
196 |
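    | # The 2-D window above is the outer product of two 1-D Gaussians (sigma=1.5),
    | # expanded to one depthwise filter per channel for the grouped conv in _ssim:
    | #
    | #     w = create_window(11, channel=3)   # shape [3, 1, 11, 11]; each slice sums to 1
    | 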
197 | def _ssim(img1, img2, mask, window, window_size, channel, size_average = True, batch_avg = False):
198 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
199 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
200 |
201 | mu1_sq = mu1.pow(2)
202 | mu2_sq = mu2.pow(2)
203 | mu1_mu2 = mu1*mu2
204 |
205 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
206 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
207 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
208 |
209 | C1 = 0.01**2
210 | C2 = 0.03**2
211 | #### Image range [0, 255] ####
212 | # C1 = (0.01 * 255)**2
213 | # C2 = (0.03 * 255)**2
214 |
215 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
216 |     # ssim_map: [B, C, H, W]
217 |     B, C, H, W = ssim_map.size()
218 | 
219 |     if batch_avg and ssim_map.dim() == 4:
220 |         # one mean SSIM value per sample in the batch
221 |         return ssim_map.view(B, -1).mean(1)
222 | elif size_average:
223 |         # masked mean: average the SSIM map over masked pixels only
224 | return (ssim_map*mask).sum() / (mask.float().sum() * C)
225 | else:
226 | return ssim_map.mean(1).mean(1).mean(1)
227 |
228 | def ssim(img1, img2, mask, window_size = 11, size_average = True, batch_avg = False):
229 | (_, channel, _, _) = img1.size()
230 | window = create_window(window_size, channel)
231 |
232 | if img1.is_cuda:
233 | window = window.cuda(img1.get_device())
234 | window = window.type_as(img1)
235 |
236 | return _ssim(img1, img2, mask, window, window_size, channel, size_average, batch_avg)
237 |
238 | def ssim_cuda(img1, img2, mask, batch_avg=False):
239 | #### Image range [0, 1] ####
240 | return ssim(img1, img2, mask, batch_avg=batch_avg)
241 |
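    | # ssim_cuda mirrors psnr_cuda: same [0, 1] input range and mask semantics, e.g.
    | #
    | #     print(ssim_cuda(sr, hr, mask))   # tensors as in the psnr_cuda sketch above
    | 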
242 | def calc_psnr_and_ssim_cuda(sr, hr, mask, is_tensor=True, batch_avg=False):
243 |     #### Convert image range to [0, 1] ####
244 |     # heuristic: a value span above 2 implies [0, 255] inputs,
245 |     # a span above 1 implies [-1, 1] inputs
246 |     min_val, max_val = hr.min(), hr.max()
247 |     if ((max_val - min_val) > 2):
248 |         sr = sr / 255.
249 |         hr = hr / 255.
250 |     elif ((max_val - min_val) > 1):
251 |         sr = (sr + 1.) / 2.
252 |         hr = (hr + 1.) / 2.
253 |     #### To evaluate in the [0, 255] range instead, scale both tensors by 255. ####
254 |     return psnr_cuda(sr, hr, mask, batch_avg=batch_avg), ssim_cuda(sr, hr, mask, batch_avg=batch_avg)
255 |
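    | # End-to-end metric sketch (ranges are detected by the heuristic above;
    | # tensors as in the earlier sketches):
    | #
    | #     psnr, ssim_val = calc_psnr_and_ssim_cuda(sr * 255., hr * 255., mask)
    | 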
256 | def _convert_input_type_range(img):
257 | """Convert the type and range of the input image.
258 | It converts the input image to np.float32 type and range of [0, 1].
259 | It is mainly used for pre-processing the input image in colorspace
260 |     conversion functions such as rgb2ycbcr and ycbcr2rgb.
261 | Args:
262 | img (ndarray): The input image. It accepts:
263 | 1. np.uint8 type with range [0, 255];
264 | 2. np.float32 type with range [0, 1].
265 | Returns:
266 | (ndarray): The converted image with type of np.float32 and range of
267 | [0, 1].
268 | """
269 | img_type = img.dtype
270 | img = img.astype(np.float32)
271 | if img_type == np.float32:
272 | pass
273 | elif img_type == np.uint8:
274 | img /= 255.
275 | else:
276 | raise TypeError('The img type should be np.float32 or np.uint8, '
277 | f'but got {img_type}')
278 | return img
279 |
280 |
281 | def _convert_output_type_range(img, dst_type):
282 | """Convert the type and range of the image according to dst_type.
283 | It converts the image to desired type and range. If `dst_type` is np.uint8,
284 | images will be converted to np.uint8 type with range [0, 255]. If
285 | `dst_type` is np.float32, it converts the image to np.float32 type with
286 | range [0, 1].
287 |     It is mainly used for post-processing images in colorspace conversion
288 | functions such as rgb2ycbcr and ycbcr2rgb.
289 | Args:
290 | img (ndarray): The image to be converted with np.float32 type and
291 | range [0, 255].
292 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
293 | converts the image to np.uint8 type with range [0, 255]. If
294 | dst_type is np.float32, it converts the image to np.float32 type
295 | with range [0, 1].
296 | Returns:
297 | (ndarray): The converted image with desired type and range.
298 | """
299 | if dst_type not in (np.uint8, np.float32):
300 | raise TypeError('The dst_type should be np.float32 or np.uint8, '
301 | f'but got {dst_type}')
302 | if dst_type == np.uint8:
303 | img = img.round()
304 | else:
305 | img /= 255.
306 | return img.astype(dst_type)
307 |
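    | # Round-trip sketch for the two helpers above: uint8 in, float32 [0, 1] for
    | # processing, back to uint8. Note _convert_output_type_range expects a float
    | # image already scaled to [0, 255]:
    | #
    | #     img = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
    | #     f = _convert_input_type_range(img)                     # float32, [0, 1]
    | #     out = _convert_output_type_range(f * 255., np.uint8)   # uint8, [0, 255]
    | 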
308 | def bgr2ycbcr(img, y_only=False):
309 | """Convert a BGR image to YCbCr image.
310 | The bgr version of rgb2ycbcr.
311 | It implements the ITU-R BT.601 conversion for standard-definition
312 | television. See more details in
313 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
314 | It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
315 | In OpenCV, it implements a JPEG conversion. See more details in
316 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
317 |     Args:
318 |         img (Tensor): The input image as a [B, H, W, C] float torch tensor,
319 |             expected in range [0, 1]; unlike the original ndarray version,
320 |             no uint8/float32 type conversion is performed here.
321 |         y_only (bool): Whether to only return Y channel. Default: False.
322 |     Returns:
323 |         Tensor: The converted YCbCr image as a [B, C, H, W] tensor; for
324 |             [0, 1] inputs, Y lies in [16, 235].
325 |     """
326 | # img_type = img.dtype
327 | # img = _convert_input_type_range(img)
328 | if y_only:
329 | out_img = torch.matmul(img, torch.tensor([24.966, 128.553, 65.481]).to(img.device)) + 16.0
330 | out_img = out_img.unsqueeze(3).permute(0, 3, 1, 2)
331 | else:
332 | out_img = torch.matmul(
333 | img, torch.tensor([[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
334 |                                [65.481, -37.797, 112.0]]).to(img.device)) + torch.tensor([16., 128., 128.]).to(img.device)
335 | out_img = out_img.permute(0, 3, 1, 2)
336 | # out_img = _convert_output_type_range(out_img, img_type)
337 | return out_img
338 |
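    | # bgr2ycbcr sketch: channel-last tensor in, channel-first tensor out.
    | #
    | #     img = torch.rand(1, 32, 32, 3)       # [B, H, W, C], BGR order, [0, 1]
    | #     y = bgr2ycbcr(img, y_only=True)      # [B, 1, H, W], Y in [16, 235]
    | 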
339 | def make_colorwheel():
340 | """
341 | Generates a color wheel for optical flow visualization as presented in:
342 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
343 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
344 |
345 | Code follows the original C++ source code of Daniel Scharstein.
346 |     Code follows the Matlab source code of Deqing Sun.
347 |
348 | Returns:
349 | np.ndarray: Color wheel
350 | """
351 |
352 | RY = 15
353 | YG = 6
354 | GC = 4
355 | CB = 11
356 | BM = 13
357 | MR = 6
358 |
359 | ncols = RY + YG + GC + CB + BM + MR
360 | colorwheel = np.zeros((ncols, 3))
361 | col = 0
362 |
363 | # RY
364 | colorwheel[0:RY, 0] = 255
365 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
366 | col = col+RY
367 | # YG
368 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
369 | colorwheel[col:col+YG, 1] = 255
370 | col = col+YG
371 | # GC
372 | colorwheel[col:col+GC, 1] = 255
373 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
374 | col = col+GC
375 | # CB
376 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
377 | colorwheel[col:col+CB, 2] = 255
378 | col = col+CB
379 | # BM
380 | colorwheel[col:col+BM, 2] = 255
381 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
382 | col = col+BM
383 | # MR
384 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
385 | colorwheel[col:col+MR, 0] = 255
386 | return colorwheel
387 |
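    | # Sanity check: the wheel has RY + YG + GC + CB + BM + MR = 55 hues, one RGB row each:
    | #
    | #     assert make_colorwheel().shape == (55, 3)
    | 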
388 |
389 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
390 | """
391 | Applies the flow color wheel to (possibly clipped) flow components u and v.
392 |
393 |     Follows the C++ source code of Daniel Scharstein and the Matlab source
394 |     code of Deqing Sun.
395 |
396 | Args:
397 | u (np.ndarray): Input horizontal flow of shape [H,W]
398 | v (np.ndarray): Input vertical flow of shape [H,W]
399 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
400 |
401 | Returns:
402 | np.ndarray: Flow visualization image of shape [H,W,3]
403 | """
404 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
405 | colorwheel = make_colorwheel() # shape [55x3]
406 | ncols = colorwheel.shape[0]
407 | rad = np.sqrt(np.square(u) + np.square(v))
408 | a = np.arctan2(-v, -u)/np.pi
409 | fk = (a+1) / 2*(ncols-1)
410 | k0 = np.floor(fk).astype(np.int32)
411 | k1 = k0 + 1
412 | k1[k1 == ncols] = 0
413 | f = fk - k0
414 | for i in range(colorwheel.shape[1]):
415 | tmp = colorwheel[:,i]
416 | col0 = tmp[k0] / 255.0
417 | col1 = tmp[k1] / 255.0
418 | col = (1-f)*col0 + f*col1
419 | idx = (rad <= 1)
420 | col[idx] = 1 - rad[idx] * (1-col[idx])
421 | col[~idx] = col[~idx] * 0.75 # out of range
422 | # Note the 2-i => BGR instead of RGB
423 | ch_idx = 2-i if convert_to_bgr else i
424 | flow_image[:,:,ch_idx] = np.floor(255 * col)
425 | return flow_image
426 |
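    | # Note: flow_uv_to_colors expects u and v pre-normalized to at most unit
    | # magnitude (flow_to_color below divides by rad_max); hue encodes flow
    | # direction and saturation encodes magnitude.
    | 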
427 |
428 | def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
429 | """
430 |     Expects a two-dimensional flow field of shape [H,W,2].
431 |
432 | Args:
433 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
434 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
435 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
436 |
437 | Returns:
438 | np.ndarray: Flow visualization image of shape [H,W,3]
439 | """
440 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
441 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
442 | if clip_flow is not None:
443 | flow_uv = np.clip(flow_uv, 0, clip_flow)
444 | u = flow_uv[:,:,0]
445 | v = flow_uv[:,:,1]
446 | rad = np.sqrt(np.square(u) + np.square(v))
447 | rad_max = np.max(rad)
448 | epsilon = 1e-5
449 | u = u / (rad_max + epsilon)
450 | v = v / (rad_max + epsilon)
451 | return flow_uv_to_colors(u, v, convert_to_bgr)
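    | 
    | # Visualization sketch (hypothetical constant flow field; not called by the
    | # training code):
    | #
    | #     flow = np.dstack([np.ones((64, 64)), np.zeros((64, 64))]).astype(np.float32)
    | #     rgb = flow_to_color(flow)                 # uint8 [64, 64, 3]
    | #     cv2.imwrite('flow.png', rgb[..., ::-1])   # imwrite expects BGR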
--------------------------------------------------------------------------------