├── README.md ├── backup ├── dataset ├── dataloader.py ├── reds.py └── vimeo7.py ├── demo ├── sigma10.gif ├── sigma100.gif └── sigma50.gif ├── eval.sh ├── gen_img.py ├── gen_video.py ├── gif_combine.py ├── loss └── loss.py ├── main.py ├── model ├── CRFP.py ├── CRFP_runtime.py ├── CRFP_test.py └── LTE.py ├── option.py ├── overview.png ├── png2mp4.py ├── pretrained_models ├── EGVSR_iter420000.pth ├── fnet.pth └── spynet_20210409-c6c1bd09.pth ├── pytorch_ssim └── __init__.py ├── test.sh ├── test_img_coor.py ├── test_runtime.py ├── test_video.py ├── test_video_quality.sh ├── train.sh ├── trainer.py ├── untar_models.sh └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Cross-Resolution-Flow-Propagation-for-Foveated-Video-Super-Resolution 2 | 3 | Official implementation of Cross-Resolution Flow Propagation for Foveated Video Super-Resolution (CRFP), accepted at WACV 2023. 4 | 5 | 6 | 7 | ## Demo 8 | 9 | A demonstration of how CRFP deals with various values of $\sigma^T$, the noise induced by the movement of the eye tracker during pupil detection. Note that regions beyond the foveated region are resolved under $8\times$ super-resolution. 10 | 11 | $\sigma^T=10$ 12 | 13 | 14 | 15 | $\sigma^T=50$ 16 | 17 | 18 | 19 | $\sigma^T=100$ 20 | 21 | 22 | 23 | 24 | 25 | ## Training and evaluation 26 | To train the model, you need to install DCNv2 first from https://github.com/jinfagang/DCNv2_latest 27 | 28 | Run the following to start training: 29 | ``` 30 | bash train.sh 31 | ``` 32 | 33 | To evaluate, run: 34 | ``` 35 | bash eval.sh 36 | ``` 37 | 38 | To test, run: 39 | ``` 40 | bash test.sh 41 | ``` 42 | 43 | ## References 44 | Most of the code is adapted from 45 | 46 | 1. TTSR: https://github.com/researchmm/TTSR 47 | 2. BasicVSR: https://github.com/open-mmlab/mmediting 48 | 49 | 50 | -------------------------------------------------------------------------------- /backup: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | tar --exclude='train/*' --exclude='env/*' --exclude='eval/*' --exclude='test/*' --exclude='test_png/*' --exclude='test_video/*' --exclude='test_gif/*' --exclude='old_tree_x1/*' --exclude='old_tree_x8/*' -cvf CRFP_x8.tar . 4 | -------------------------------------------------------------------------------- /dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from importlib import import_module 3 | 4 | 5 | def get_dataloader(args): 6 | ### import module 7 | m = import_module('dataset.' 
+ args.dataset.lower()) 8 | 9 | if (args.dataset == 'Vimeo7'): 10 | print('Processing Vimeo7 dataset...') 11 | data_train = getattr(m, 'TrainSet')(args) 12 | dataloader_train = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 13 | data_eval = getattr(m, 'EvalSet')(args) 14 | dataloader_eval = DataLoader(data_eval, batch_size=1, shuffle=True, num_workers=1) 15 | data_test = getattr(m, 'TestSet')(args) 16 | dataloader_test = DataLoader(data_test, batch_size=1, shuffle=False, num_workers=1) 17 | dataloader = {'train': dataloader_train, 'eval': dataloader_eval, 'test': dataloader_test} 18 | elif (args.dataset == 'Reds'): 19 | print('Processing Reds dataset...') 20 | data_train = getattr(m, 'TrainSet')(args) 21 | dataloader_train = DataLoader(data_train, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers) 22 | data_eval = getattr(m, 'EvalSet')(args) 23 | dataloader_eval = DataLoader(data_eval, batch_size=1, shuffle=False, num_workers=1) 24 | data_test = getattr(m, 'TestSet')(args) 25 | dataloader_test = DataLoader(data_test, batch_size=1, shuffle=False, num_workers=1) 26 | dataloader = {'train': dataloader_train, 'eval': dataloader_eval, 'test': dataloader_test} 27 | else: 28 | raise SystemExit('Error: no such type of dataset!') 29 | 30 | return dataloader -------------------------------------------------------------------------------- /dataset/reds.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import pickle 5 | import logging 6 | import numpy as np 7 | import PIL 8 | import pdb 9 | import math 10 | 11 | import torch 12 | import torch.utils.data as data 13 | import torchvision.transforms.functional as F 14 | 15 | logger = logging.getLogger('base') 16 | 17 | def fovea_generator(GT_imgs, method='Rscan', step=0.1, FV_HW=(32, 32)): 18 | 19 | len_sp = len(GT_imgs) 20 | if torch.is_tensor(GT_imgs): 21 | _, GT_H, GT_W = GT_imgs[0].size() 22 | else: 23 | GT_H, GT_W, _ = GT_imgs[0].shape 24 | 25 | FV_H, FV_W = FV_HW 26 | if method == 'Cscan' or method == 'Zscan': 27 | SP = 0.1 28 | CP = 0.5 29 | EP = 0.9 30 | else: 31 | SP = 0.1 32 | CP = 0.5 33 | EP = 0.9 34 | 35 | #### shift according to center point 36 | CP_H = (GT_H*CP - FV_H//2)/GT_H 37 | CP_W = (GT_W*CP - FV_W//2)/GT_W 38 | EP_H = (GT_H*EP - FV_H)/GT_H 39 | EP_W = (GT_W*EP - FV_W)/GT_W 40 | 41 | #### finetune step size 42 | if method == 'Cscan' or method == 'Zscan': 43 | if SP + math.ceil(math.sqrt(len_sp)) * step > EP_H or SP + math.ceil(math.sqrt(len_sp)) * step > EP_W: 44 | step = min((EP_H - SP) / math.ceil(math.sqrt(len_sp)), (EP_W - SP) / math.ceil(math.sqrt(len_sp))) 45 | SP = int(SP * 100) 46 | step = int(step * 100) 47 | EP = int(SP + math.ceil(math.sqrt(len_sp) - 1) * step) 48 | elif method == 'Hscan': 49 | if SP + len_sp * step > EP_W: 50 | step = (EP_W - SP) / len_sp 51 | SP = int(SP * 100) 52 | step = int(step * 100) 53 | EP = int((SP + len_sp * step)) 54 | elif method == 'Vscan': 55 | if SP + len_sp * step > EP_H: 56 | step = (EP_H - SP) / len_sp 57 | SP = int(SP * 100) 58 | step = int(step * 100) 59 | EP = int((SP + len_sp * step)) 60 | else: 61 | if SP + len_sp * step > EP_H or SP + len_sp * step > EP_W: 62 | step = min((EP_H - SP) / len_sp, (EP_W - SP) / len_sp) 63 | SP = int(SP * 100) 64 | step = int(step * 100) 65 | EP = int((SP + len_sp * step)) 66 | 67 | #### fovea scan simulation 68 | if method == 'Hscan': 69 | fv_sp = [[int(CP_H * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, 
EP, step)]] 70 | elif method == 'Vscan': 71 | fv_sp = [[int((v / 100) * GT_H), int(CP_W * GT_W)] for v in [*range(SP, EP, step)]] 72 | elif method == 'DRscan': # Not done 73 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]] 74 | elif method == 'DLscan': # Not done 75 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]] 76 | elif method == 'Cscan': 77 | fv_sp = [] 78 | v, h = (SP, SP) 79 | v_step, h_step = (step, step) 80 | for t in range(len_sp): 81 | fv_sp.append([int((v / 100) * GT_H), int((h / 100) * GT_W)]) 82 | if h == EP and h_step > 0: 83 | h_step = -h_step 84 | v += v_step 85 | elif h == SP and h_step < 0: 86 | h_step = -h_step 87 | v += v_step 88 | else: 89 | h += h_step 90 | elif method == 'Zscan': 91 | fv_sp = [] 92 | v, h = (SP, SP) 93 | v_step, h_step = (step, step) 94 | for t in range(len_sp): 95 | fv_sp.append([int((v / 100) * GT_H), int((h / 100) * GT_W)]) 96 | if h == EP and v_step < 0: 97 | v_step = -v_step 98 | v += v_step 99 | h_step = -abs(h_step) 100 | elif v == SP and h_step > 0: 101 | h += h_step 102 | h_step = -h_step 103 | v_step = abs(v_step) 104 | elif v == EP and h_step < 0: 105 | h_step = -h_step 106 | h += h_step 107 | v_step = -abs(v_step) 108 | elif h == SP and v_step > 0: 109 | v += v_step 110 | v_step = -v_step 111 | h_step = abs(h_step) 112 | else: 113 | h += h_step 114 | v += v_step 115 | elif method == 'Rscan': 116 | sigma = 0.05 117 | rand_h = np.random.normal(CP_H, sigma, len_sp).clip(0, EP_H) 118 | rand_w = np.random.normal(CP_W, sigma, len_sp).clip(0, EP_W) 119 | fv_sp = [[int(rh * GT_H), int(rw * GT_W)] for rh, rw in zip(rand_h, rand_w)] 120 | elif method == 'Nanascan': # Not done 121 | # SP_H = 0 122 | # EP_H = (GT_H - FV_H - 1) / GT_H 123 | # Q1_H = (0.25 - (FV_H / GT_H)/2) if (0.25 - (FV_H / GT_H)/2) > 0 else SP_H 124 | # Q2_H = (0.50 - (FV_H / GT_H)/2) 125 | # Q3_H = (0.75 - (FV_H / GT_H)/2) if (0.75 + (FV_H / GT_H)/2) < 1 else EP_H 126 | # T1_H = (0.33 - (FV_H / GT_H)/2) if (0.33 - (FV_H / GT_H)/2) > 0 else SP_H 127 | # T2_H = (0.66 - (FV_H / GT_H)/2) if (0.66 + (FV_H / GT_H)/2) < 1 else EP_H 128 | 129 | ratio_H = FV_H / GT_H 130 | SP_H = 0 + (ratio_H / 2) 131 | EP_H = 1 - (ratio_H / 2) 132 | Q1_H = SP_H + ((EP_H - SP_H) * 0.25) 133 | Q2_H = SP_H + ((EP_H - SP_H) * 0.50) 134 | Q3_H = SP_H + ((EP_H - SP_H) * 0.75) 135 | T1_H = SP_H + ((EP_H - SP_H) * 0.33) 136 | T2_H = SP_H + ((EP_H - SP_H) * 0.66) 137 | 138 | ratio_W = FV_W / GT_W 139 | SP_W = 0 + (ratio_W / 2) 140 | EP_W = 1 - (ratio_W / 2) 141 | Q1_W = SP_W + ((EP_W - SP_W) * 0.25) 142 | Q2_W = SP_W + ((EP_W - SP_W) * 0.50) 143 | Q3_W = SP_W + ((EP_W - SP_W) * 0.75) 144 | T1_W = SP_W + ((EP_W - SP_W) * 0.33) 145 | T2_W = SP_W + ((EP_W - SP_W) * 0.66) 146 | 147 | locs = [[SP_H, SP_W], [SP_H, T1_W], [SP_H, T2_W], [SP_H, EP_W], 148 | [T1_H, SP_W], [T1_H, T1_W], [T1_H, T2_W], [T1_H, EP_W], 149 | [T2_H, SP_W], [T2_H, T1_W], [T2_H, T2_W], [T2_H, EP_W], 150 | [EP_H, SP_W], [EP_H, T1_W], [EP_H, T2_W], [EP_H, EP_W]] 151 | locs = [(y - (ratio_H / 2), x - (ratio_W / 2)) for y, x in locs]  # center -> top-left corner; the horizontal offset uses ratio_W 152 | # locs = [[Q1_H, T1_W], [Q1_H, T2_W], [Q2_H, Q1_W], [Q2_H, Q2_W], [Q2_H, Q3_W], [Q3_H, T1_W], [Q3_H, T2_W]] 153 | # locs = [[Q2_H, Q2_W]] 154 | 155 | fv_sp = random.choices(locs, k=len_sp) 156 | fv_sp = [[min(int(v[0] * GT_H), GT_H-FV_H), min(int(v[1] * GT_W), GT_W-FV_W)] for v in fv_sp] 157 | random.shuffle(fv_sp) 158 | elif method == 'Evenscan': # Not done 159 | idx = 20 160 | N_H = GT_H // FV_H 161 | N_W = GT_W // FV_W 162 | SP_H = 
GT_H / N_H 163 | SP_W = GT_W / N_W 164 | fv_sp = [] 165 | for i in range(idx, idx + len_sp): 166 | x_i = i % N_W 167 | y_i = (i // N_W) % N_H 168 | fv_sp.append([int((1+y_i)*SP_H - (SP_H + FV_H)/2), int((1+x_i)*SP_W - (SP_W + FV_W)/2)]) 169 | elif method == 'DemoHscan': # Not done 170 | SP_H = 0 171 | EP_H = 1 172 | SP_W = 0 173 | EP_W = 1 174 | fv_sp = [] 175 | direction = -1 176 | scan_step = 8 177 | accm_step = GT_W - scan_step 178 | for _ in range(len_sp): 179 | fv_sp.append([0, accm_step]) 180 | accm_step += direction * scan_step 181 | if accm_step < 0: 182 | direction *= -1 183 | accm_step += direction * scan_step 184 | elif accm_step >= GT_W: 185 | direction *= -1 186 | accm_step += direction * scan_step 187 | else: 188 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]] 189 | 190 | fv_sp = torch.tensor(fv_sp) 191 | if torch.is_tensor(GT_imgs): 192 | FV_imgs = [] 193 | Ref_sps = [] 194 | for t in range(len(GT_imgs)): 195 | #### With padding #### 196 | Ref_sp = torch.zeros_like(GT_imgs[t]) 197 | if method == 'DemoHscan': 198 | Ref_sp[:, fv_sp[t][0]:, fv_sp[t][1]:] = 1 199 | else: 200 | Ref_sp[:, fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W] = 1 201 | Ref = GT_imgs[t] * Ref_sp 202 | FV_imgs.append(Ref) 203 | Ref_sps.append(Ref_sp) 204 | #### Without padding #### 205 | # FV_imgs.append(GT_imgs[t][:, fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W].clone()) 206 | # FV_imgs = torch.stack(FV_imgs, dim=0) 207 | else: 208 | FV_imgs = [] 209 | Ref_sps = [] 210 | for t in range(len(GT_imgs)): 211 | #### With padding #### 212 | # Ref_sp = np.zeros_like(GT_imgs[t]) 213 | H, W, C = GT_imgs[t].shape 214 | Ref_sp = np.zeros((H, W, 1)) 215 | if method == 'DemoHscan': 216 | Ref_sp[fv_sp[t][0]:, fv_sp[t][1]:, :] = 1 217 | else: 218 | Ref_sp[fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W, :] = 1 219 | Ref = GT_imgs[t] * Ref_sp 220 | FV_imgs.append(Ref) 221 | Ref_sps.append(Ref_sp) 222 | #### Without padding #### 223 | # FV_imgs.append(GT_imgs[t][fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W, :].copy()) 224 | # FV_imgs = np.stack(FV_imgs, dim=0) 225 | 226 | return FV_imgs, Ref_sps, fv_sp 227 | 228 | class TrainSet(data.Dataset): 229 | ''' 230 | Reading the REDS training dataset 231 | key example: train/train/train_sharp/000/00000000.png 232 | GT: Ground-Truth; 233 | LQ: Low-Quality, e.g., low-resolution frames 234 | supports reading N HR frames, N = 3, 5, 7 235 | ''' 236 | def __init__(self, args): 237 | super(TrainSet, self).__init__() 238 | self.args = args 239 | 240 | scale = self.args.scale 241 | if scale == 8: 242 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI_x8') 243 | elif scale == 4: 244 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI') 245 | self.GT_dir_list = sorted([os.path.join(args.dataset_dir, 'train/train/train_sharp', name) for name in 246 | os.listdir(os.path.join(args.dataset_dir, 'train/train/train_sharp')) if name not in ['000', '011', '015', '020']]) + \ 247 | sorted([os.path.join(args.dataset_dir, 'val/val/val_sharp', name) for name in 248 | os.listdir(os.path.join(args.dataset_dir, 'val/val/val_sharp')) if name not in ['000', '001', '006', '017']]) 249 | self.LR_dir_list = sorted([os.path.join(LR_root, 'train/train/train_sharp', name) for name in 250 | os.listdir(os.path.join(LR_root, 'train/train/train_sharp')) if name not in ['000', '011', '015', '020']]) + \ 251 | sorted([os.path.join(LR_root, 'val/val/val_sharp', name) for name in 252 | 
os.listdir(os.path.join(LR_root, 'val/val/val_sharp')) if name not in ['000', '001', '006', '017']]) 253 | N_frames = self.args.N_frames 254 | self.GT_imgfiles = [] 255 | self.LR_imgfiles = [] 256 | for idx in range(len(self.GT_dir_list)): 257 | GT_imgfiles_cur = sorted(os.listdir(self.GT_dir_list[idx])) 258 | for img_idx in range(0, len(GT_imgfiles_cur) - N_frames + 1): 259 | self.GT_imgfiles.append([os.path.join(self.GT_dir_list[idx], img_f) for img_f in GT_imgfiles_cur[img_idx:img_idx + N_frames]]) 260 | for idx in range(len(self.LR_dir_list)): 261 | LR_imgfiles_cur = sorted(os.listdir(self.LR_dir_list[idx])) 262 | for img_idx in range(0, len(LR_imgfiles_cur) - N_frames + 1): 263 | self.LR_imgfiles.append([os.path.join(self.LR_dir_list[idx], img_f) for img_f in LR_imgfiles_cur[img_idx:img_idx + N_frames]]) 264 | 265 | def __getitem__(self, index): 266 | #### Configs 267 | scale = self.args.scale 268 | GT_size = self.args.GT_size 269 | LR_size = GT_size // scale 270 | FV_size = self.args.FV_size 271 | 272 | ### GT 273 | GT_imgfiles = self.GT_imgfiles[index] 274 | GT_imgs = [np.array(PIL.Image.open(img)) for img in GT_imgfiles] 275 | 276 | #### Bicubic downsampling 277 | ### LR and LR_sr 278 | H_, W_, _ = GT_imgs[0].shape 279 | LR_imgfiles = self.LR_imgfiles[index] 280 | LR_imgs = [np.array(PIL.Image.open(img)) for img in LR_imgfiles] 281 | LR_sr_imgs = [np.array(PIL.Image.fromarray(img).resize((W_, H_), PIL.Image.BICUBIC)) for img in LR_imgs] 282 | 283 | ### Random cropping 284 | H, W, C = LR_imgs[0].shape 285 | rnd_h = random.randint(0, max(0, H - LR_size)) 286 | rnd_w = random.randint(0, max(0, W - LR_size)) 287 | LR_imgs = [v[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :] for v in LR_imgs] 288 | 289 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) 290 | GT_imgs = [v[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] for v in GT_imgs] 291 | LR_sr_imgs = [v[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] for v in LR_sr_imgs] 292 | 293 | ### Ref(FV) and Ref_sr 294 | Ref, Ref_sp, _ = fovea_generator(GT_imgs, method='Nanascan', FV_HW=(FV_size, FV_size)) 295 | 296 | #### Stacking 297 | GT_imgs = np.stack(GT_imgs, axis=0) 298 | LR_imgs = np.stack(LR_imgs, axis=0) 299 | LR_sr_imgs = np.stack(LR_sr_imgs, axis=0) 300 | Ref = np.stack(Ref, axis=0) 301 | Ref_sp = np.stack(Ref_sp, axis=0) 302 | 303 | #### Scaling 304 | GT_imgs = GT_imgs.astype(np.float32) / 255. 305 | LR_imgs = LR_imgs.astype(np.float32) / 255. 306 | LR_sr_imgs = LR_sr_imgs.astype(np.float32) / 255. 307 | Ref = Ref.astype(np.float32) / 255. 
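# Shapes at this point (T = N_frames, image values scaled to [0, 1]):
#   GT_imgs:    (T, GT_size, GT_size, 3)   ground-truth crops
#   LR_imgs:    (T, LR_size, LR_size, 3)   pre-downsampled inputs
#   LR_sr_imgs: (T, GT_size, GT_size, 3)   bicubic-upsampled LR frames
#   Ref/Ref_sp: (T, GT_size, GT_size, 3/1) foveated patch and its 0/1 mask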
308 | Ref_sp = Ref_sp.astype(np.bool_) 309 | 310 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float() 311 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float() 312 | LR_sr_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_sr_imgs, (0, 3, 1, 2)))).float() 313 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float() 314 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2)))) 315 | 316 | if torch.rand(1) < 0.5: 317 | GT_imgs = F.hflip(GT_imgs) 318 | LR_imgs = F.hflip(LR_imgs) 319 | LR_sr_imgs = F.hflip(LR_sr_imgs) 320 | Ref = F.hflip(Ref) 321 | Ref_sp = F.hflip(Ref_sp) 322 | 323 | if torch.rand(1) < 0.5: 324 | GT_imgs = F.vflip(GT_imgs) 325 | LR_imgs = F.vflip(LR_imgs) 326 | LR_sr_imgs = F.vflip(LR_sr_imgs) 327 | Ref = F.vflip(Ref) 328 | Ref_sp = F.vflip(Ref_sp) 329 | 330 | return {'LR': LR_imgs, 331 | 'LR_sr': LR_sr_imgs, 332 | 'HR': GT_imgs, 333 | 'Ref': Ref, 334 | 'Ref_sp': Ref_sp} 335 | 336 | def __len__(self): 337 | return len(self.GT_imgfiles) 338 | 339 | class EvalSet(data.Dataset): 340 | ''' 341 | Reading the REDS evaluation dataset 342 | key example: val/val/val_sharp/000/00000000.png 343 | GT: Ground-Truth; 344 | LQ: Low-Quality, e.g., low-resolution frames 345 | supports reading N HR frames, N = 3, 5, 7 346 | ''' 347 | def __init__(self, args): 348 | super(EvalSet, self).__init__() 349 | self.args = args 350 | 351 | scale = self.args.scale 352 | if scale == 8: 353 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI_x8') 354 | elif scale == 4: 355 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI') 356 | self.GT_dir_list = sorted([os.path.join(args.dataset_dir, 'val/val/val_sharp', name) for name in 357 | ['000', '001', '006', '017']]) 358 | self.LR_dir_list = sorted([os.path.join(LR_root, 'val/val/val_sharp', name) for name in 359 | ['000', '001', '006', '017']]) 360 | # self.GT_dir_list = sorted([os.path.join(args.dataset_dir, 'train/train/train_sharp', name) for name in 361 | # ['000', '011', '015', '020']]) 362 | # self.LR_dir_list = sorted([os.path.join(LR_root, 'train/train/train_sharp', name) for name in 363 | # ['000', '011', '015', '020']]) 364 | N_frames = self.args.N_frames 365 | 366 | self.GT_imgfiles = [] 367 | self.LR_imgfiles = [] 368 | for idx in range(len(self.GT_dir_list)): 369 | GT_imgfiles_cur = sorted(os.listdir(self.GT_dir_list[idx])) 370 | for img_idx in range(0, len(GT_imgfiles_cur) - N_frames + 1): 371 | self.GT_imgfiles.append([os.path.join(self.GT_dir_list[idx], img_f) for img_f in GT_imgfiles_cur[img_idx:img_idx + N_frames]]) 372 | for idx in range(len(self.LR_dir_list)): 373 | LR_imgfiles_cur = sorted(os.listdir(self.LR_dir_list[idx])) 374 | for img_idx in range(0, len(LR_imgfiles_cur) - N_frames + 1): 375 | self.LR_imgfiles.append([os.path.join(self.LR_dir_list[idx], img_f) for img_f in LR_imgfiles_cur[img_idx:img_idx + N_frames]]) 376 | 377 | def __getitem__(self, index): 378 | #### Configs 379 | scale = self.args.scale 380 | GT_size = self.args.GT_size 381 | LR_size = GT_size // scale 382 | FV_size = self.args.FV_size 383 | 384 | ### GT 385 | GT_imgfiles = self.GT_imgfiles[index] 386 | GT_imgs = [np.array(PIL.Image.open(img)) for img in GT_imgfiles] 387 | 388 | #### Bicubic downsampling 389 | ### LR and LR_sr 390 | H_, W_, _ = GT_imgs[0].shape 391 | LR_imgfiles = self.LR_imgfiles[index] 392 | LR_imgs = [np.array(PIL.Image.open(img)) for img in LR_imgfiles] 393 | # LR_imgs = 
[np.array(PIL.Image.fromarray(img).resize((W_ // 8, H_ // 8), PIL.Image.BILINEAR)) for img in GT_imgs] 394 | LR_sr_imgs = [np.array(PIL.Image.fromarray(img).resize((W_, H_), PIL.Image.BICUBIC)) for img in LR_imgs] 395 | 396 | ### Ref(FV) and Ref_sr 397 | Ref, Ref_sp, fv_sp = fovea_generator(GT_imgs, method='Evenscan', FV_HW=(FV_size, FV_size)) 398 | 399 | #### Stacking 400 | GT_imgs = np.stack(GT_imgs, axis=0) 401 | LR_imgs = np.stack(LR_imgs, axis=0) 402 | LR_sr_imgs = np.stack(LR_sr_imgs, axis=0) 403 | Ref = np.stack(Ref, axis=0) 404 | Ref_sp = np.stack(Ref_sp, axis=0) 405 | 406 | #### Scaling 407 | GT_imgs = GT_imgs.astype(np.float32) / 255. 408 | LR_imgs = LR_imgs.astype(np.float32) / 255. 409 | LR_sr_imgs = LR_sr_imgs.astype(np.float32) / 255. 410 | Ref = Ref.astype(np.float32) / 255. 411 | Ref_sp = Ref_sp.astype(np.bool_) 412 | 413 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float() 414 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float() 415 | LR_sr_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_sr_imgs, (0, 3, 1, 2)))).float() 416 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float() 417 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2)))) 418 | 419 | return {'LR': LR_imgs, 420 | 'LR_sr': LR_sr_imgs, 421 | 'HR': GT_imgs, 422 | 'Ref': Ref, 423 | 'Ref_sp': Ref_sp, 424 | 'FV_sp': fv_sp} 425 | 426 | def __len__(self): 427 | return len(self.GT_imgfiles) 428 | 429 | class TestSet(data.Dataset): 430 | ''' 431 | Reading the REDS test sequences (clips held out from the training split) 432 | key example: train/train/train_sharp/000/00000000.png 433 | GT: Ground-Truth; 434 | LQ: Low-Quality, e.g., low-resolution frames 435 | supports reading N HR frames, N = 3, 5, 7 436 | ''' 437 | def __init__(self, args): 438 | super(TestSet, self).__init__() 439 | self.args = args 440 | scale = self.args.scale 441 | if scale == 8: 442 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI_x8') 443 | elif scale == 4: 444 | LR_root = args.dataset_dir.replace('_sharp', '_sharp_BI') 445 | 446 | self.GT_dir_list = sorted([os.path.join(args.dataset_dir, 'train/train/train_sharp', name) for name in 447 | ['000', '011', '015', '020']]) 448 | self.LR_dir_list = sorted([os.path.join(LR_root, 'train/train/train_sharp', name) for name in 449 | ['000', '011', '015', '020']]) 450 | 451 | N_frames = self.args.N_frames 452 | 453 | self.GT_imgfiles = [] 454 | self.LR_imgfiles = [] 455 | for idx in range(len(self.GT_dir_list)): 456 | GT_imgfiles_cur = sorted(os.listdir(self.GT_dir_list[idx])) 457 | for img_idx in range(0, len(GT_imgfiles_cur) - N_frames + 1): 458 | self.GT_imgfiles.append([os.path.join(self.GT_dir_list[idx], img_f) for img_f in GT_imgfiles_cur[img_idx:img_idx + N_frames]]) 459 | for idx in range(len(self.LR_dir_list)): 460 | LR_imgfiles_cur = sorted(os.listdir(self.LR_dir_list[idx])) 461 | for img_idx in range(0, len(LR_imgfiles_cur) - N_frames + 1): 462 | self.LR_imgfiles.append([os.path.join(self.LR_dir_list[idx], img_f) for img_f in LR_imgfiles_cur[img_idx:img_idx + N_frames]]) 463 | 464 | def __getitem__(self, index): 465 | #### Configs 466 | scale = self.args.scale 467 | GT_size = self.args.GT_size 468 | LR_size = GT_size // scale 469 | FV_size = self.args.FV_size 470 | 471 | ### GT 472 | GT_imgfiles = self.GT_imgfiles[index] 473 | GT_imgs = [np.array(PIL.Image.open(img)) for img in GT_imgfiles] 474 | 475 | #### Bicubic downsampling 476 | ### LR and LR_sr 477 | H_, W_, _ = 
GT_imgs[0].shape 478 | LR_imgfiles = self.LR_imgfiles[index] 479 | LR_imgs = [np.array(PIL.Image.open(img)) for img in LR_imgfiles] 480 | LR_sr_imgs = [np.array(PIL.Image.fromarray(img).resize((W_, H_), PIL.Image.BICUBIC)) for img in LR_imgs] 481 | 482 | ### Ref(FV) and Ref_sr 483 | Ref, Ref_sp, fv_sp = fovea_generator(GT_imgs, method='Evenscan', FV_HW=(FV_size, FV_size)) 484 | 485 | #### Stacking 486 | GT_imgs = np.stack(GT_imgs, axis=0) 487 | LR_imgs = np.stack(LR_imgs, axis=0) 488 | LR_sr_imgs = np.stack(LR_sr_imgs, axis=0) 489 | Ref = np.stack(Ref, axis=0) 490 | Ref_sp = np.stack(Ref_sp, axis=0) 491 | 492 | #### Scaling 493 | GT_imgs = GT_imgs.astype(np.float32) / 255. 494 | LR_imgs = LR_imgs.astype(np.float32) / 255. 495 | LR_sr_imgs = LR_sr_imgs.astype(np.float32) / 255. 496 | Ref = Ref.astype(np.float32) / 255. 497 | Ref_sp = Ref_sp.astype(np.bool_) 498 | 499 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float() 500 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float() 501 | LR_sr_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_sr_imgs, (0, 3, 1, 2)))).float() 502 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float() 503 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2)))) 504 | 505 | return {'LR': LR_imgs, 506 | 'LR_sr': LR_sr_imgs, 507 | 'HR': GT_imgs, 508 | 'Ref': Ref, 509 | 'Ref_sp': Ref_sp, 510 | 'FV_sp': fv_sp} 511 | 512 | def __len__(self): 513 | return len(self.GT_imgfiles) -------------------------------------------------------------------------------- /dataset/vimeo7.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import pickle 5 | import logging 6 | import numpy as np 7 | import PIL 8 | import pdb 9 | import math 10 | 11 | import torch 12 | import torch.utils.data as data 13 | import torchvision.transforms.functional as F 14 | import torch.nn.functional as nnF 15 | from torchvision.transforms import Compose, ToTensor 16 | 17 | logger = logging.getLogger('base') 18 | 19 | def gaussian_downsample(x, scale=4): 20 | """Downsampling with the Gaussian kernel used in the official DUF code 21 | Args: 22 | x (Tensor, [C, T, H, W]): frames to be downsampled. 23 | scale (int): downsampling factor: 2 | 3 | 4. 
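Returns:
    Tensor, [T, C, H//scale, W//scale]: the downsampled frames (note that the
    output is permuted from [C, T, H, W] to [T, C, H', W']).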
24 | """ 25 | 26 | assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale) 27 | 28 | def gkern(kernlen=13, nsig=1.6): 29 | import scipy.ndimage.filters as fi 30 | inp = np.zeros((kernlen, kernlen)) 31 | # set element at the middle to one, a dirac delta 32 | inp[kernlen // 2, kernlen // 2] = 1 33 | # gaussian-smooth the dirac, resulting in a gaussian filter mask 34 | return fi.gaussian_filter(inp, nsig) 35 | 36 | if scale == 2: 37 | h = gkern(13, 0.8) # 13 and 0.8 for x2 38 | elif scale == 3: 39 | h = gkern(13, 1.2) # 13 and 1.2 for x3 40 | elif scale == 4: 41 | h = gkern(13, 1.6) # 13 and 1.6 for x4 42 | else: 43 | print('Invalid upscaling factor: {} (Must be one of 2, 3, 4)'.format(R)) 44 | exit(1) 45 | 46 | C, T, H, W = x.size() 47 | x = x.contiguous().view(-1, 1, H, W) # depth convolution (channel-wise convolution) 48 | pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter 49 | r_h, r_w = 0, 0 50 | 51 | if scale == 3: 52 | r_h = 3 - (H % 3) 53 | r_w = 3 - (W % 3) 54 | 55 | x = nnF.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], mode='reflect') 56 | gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0) 57 | x = nnF.conv2d(x, gaussian_filter, stride=scale) 58 | # please keep the operation same as training. 59 | # if downsample to 32 on training time, use the below code. 60 | x = x[:, :, 2:-2, 2:-2] 61 | # if downsample to 28 on training time, use the below code. 62 | #x = x[:,:,scale:-scale,scale:-scale] 63 | x = x.view(C, T, x.size(2), x.size(3)).permute(1, 0 ,2, 3) 64 | return x 65 | 66 | def fovea_generator(GT_imgs, method='Rscan', step=0.1, FV_HW=(32, 32)): 67 | 68 | len_sp = len(GT_imgs) 69 | if torch.is_tensor(GT_imgs): 70 | _, GT_H, GT_W = GT_imgs[0].size() 71 | else: 72 | GT_H, GT_W, _ = GT_imgs[0].shape 73 | 74 | FV_H, FV_W = FV_HW 75 | if method == 'Cscan' or method == 'Zscan': 76 | SP = 0.1 77 | CP = 0.5 78 | EP = 0.9 79 | else: 80 | SP = 0.1 81 | CP = 0.5 82 | EP = 0.9 83 | 84 | #### shift according to center point 85 | CP_H = (GT_H*CP - FV_H//2)/GT_H 86 | CP_W = (GT_W*CP - FV_W//2)/GT_W 87 | EP_H = (GT_H*EP - FV_H)/GT_H 88 | EP_W = (GT_W*EP - FV_W)/GT_W 89 | 90 | #### finetune step size 91 | if method == 'Cscan' or method == 'Zscan': 92 | if SP + math.ceil(math.sqrt(len_sp)) * step > EP_H or SP + math.ceil(math.sqrt(len_sp)) * step > EP_W: 93 | step = min((EP_H - SP) / math.ceil(math.sqrt(len_sp)), (EP_W - SP) / math.ceil(math.sqrt(len_sp))) 94 | SP = int(SP * 100) 95 | step = int(step * 100) 96 | EP = int(SP + math.ceil(math.sqrt(len_sp) - 1) * step) 97 | elif method == 'Hscan': 98 | if SP + len_sp * step > EP_W: 99 | step = (EP_W - SP) / len_sp 100 | SP = int(SP * 100) 101 | step = int(step * 100) 102 | EP = int((SP + len_sp * step)) 103 | elif method == 'Vscan': 104 | if SP + len_sp * step > EP_H: 105 | step = (EP_H - SP) / len_sp 106 | SP = int(SP * 100) 107 | step = int(step * 100) 108 | EP = int((SP + len_sp * step)) 109 | else: 110 | if SP + len_sp * step > EP_H or SP + len_sp * step > EP_W: 111 | step = min((EP_H - SP) / len_sp, (EP_W - SP) / len_sp) 112 | SP = int(SP * 100) 113 | step = int(step * 100) 114 | EP = int((SP + len_sp * step)) 115 | 116 | #### fovea scan simulation 117 | if method == 'Hscan': 118 | fv_sp = [[int(CP_H * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]] 119 | elif method == 'Vscan': 120 | fv_sp = [[int((v / 100) * GT_H), int(CP_W * GT_W)] for v in [*range(SP, EP, step)]] 121 | elif method == 'DRscan': # Not done 122 | fv_sp = [[int((v / 100) * 
GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]] 123 | elif method == 'DLscan': # Not done 124 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]] 125 | elif method == 'Cscan': 126 | fv_sp = [] 127 | v, h = (SP, SP) 128 | v_step, h_step = (step, step) 129 | for t in range(len_sp): 130 | fv_sp.append([int((v / 100) * GT_H), int((h / 100) * GT_W)]) 131 | if h == EP and h_step > 0: 132 | h_step = -h_step 133 | v += v_step 134 | elif h == SP and h_step < 0: 135 | h_step = -h_step 136 | v += v_step 137 | else: 138 | h += h_step 139 | elif method == 'Zscan': 140 | fv_sp = [] 141 | v, h = (SP, SP) 142 | v_step, h_step = (step, step) 143 | for t in range(len_sp): 144 | fv_sp.append([int((v / 100) * GT_H), int((h / 100) * GT_W)]) 145 | if h == EP and v_step < 0: 146 | v_step = -v_step 147 | v += v_step 148 | h_step = -abs(h_step) 149 | elif v == SP and h_step > 0: 150 | h += h_step 151 | h_step = -h_step 152 | v_step = abs(v_step) 153 | elif v == EP and h_step < 0: 154 | h_step = -h_step 155 | h += h_step 156 | v_step = -abs(v_step) 157 | elif h == SP and v_step > 0: 158 | v += v_step 159 | v_step = -v_step 160 | h_step = abs(h_step) 161 | else: 162 | h += h_step 163 | v += v_step 164 | elif method == 'Rscan': 165 | sigma = 0.05 166 | rand_h = np.random.normal(CP_H, sigma, len_sp).clip(0, EP_H) 167 | rand_w = np.random.normal(CP_W, sigma, len_sp).clip(0, EP_W) 168 | fv_sp = [[int(rh * GT_H), int(rw * GT_W)] for rh, rw in zip(rand_h, rand_w)] 169 | elif method == 'Nanascan': # Not done 170 | SP_H = 0 171 | EP_H = (GT_H - FV_H - 1) / GT_H 172 | Q1_H = (0.25 - (FV_H / GT_H)/2) if (0.25 - (FV_H / GT_H)/2) > 0 else SP_H 173 | Q2_H = (0.50 - (FV_H / GT_H)/2) 174 | Q3_H = (0.75 - (FV_H / GT_H)/2) if (0.75 + (FV_H / GT_H)/2) <= 1 else EP_H 175 | T1_H = (0.33 - (FV_H / GT_H)/2) if (0.33 - (FV_H / GT_H)/2) > 0 else SP_H 176 | T2_H = (0.66 - (FV_H / GT_H)/2) if (0.66 + (FV_H / GT_H)/2) <= 1 else EP_H 177 | SP_W = 0 178 | EP_W = (GT_W - FV_W - 1) / GT_W 179 | Q1_W = (0.25 - (FV_W / GT_W)/2) if (0.25 - (FV_W / GT_W)/2) > 0 else SP_W 180 | Q2_W = (0.50 - (FV_W / GT_W)/2) 181 | Q3_W = (0.75 - (FV_W / GT_W)/2) if (0.75 + (FV_W / GT_W)/2) <= 1 else EP_W 182 | T1_W = (0.33 - (FV_W / GT_W)/2) if (0.33 - (FV_W / GT_W)/2) > 0 else SP_W 183 | T2_W = (0.66 - (FV_W / GT_W)/2) if (0.66 + (FV_W / GT_W)/2) <= 1 else EP_W 184 | 185 | fv_sp = [[Q1_H, T1_W], [Q1_H, T2_W], [Q2_H, Q1_W], [Q2_H, Q2_W], [Q2_H, Q3_W], [Q3_H, T1_W], [Q3_H, T2_W]] 186 | fv_sp = [[int(v[0] * GT_H), int(v[1] * GT_W)] for v in fv_sp] 187 | random.shuffle(fv_sp) 188 | else: 189 | fv_sp = [[int((v / 100) * GT_H), int((v / 100) * GT_W)] for v in [*range(SP, EP, step)]] 190 | 191 | fv_sp = torch.tensor(fv_sp) 192 | if torch.is_tensor(GT_imgs): 193 | FV_imgs = [] 194 | Ref_sps = [] 195 | for t in range(len(GT_imgs)): 196 | #### With padding #### 197 | Ref_sp = torch.zeros_like(GT_imgs[t]) 198 | Ref_sp[:, fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W] = 1 199 | Ref = GT_imgs[t] * Ref_sp 200 | FV_imgs.append(Ref) 201 | Ref_sps.append(Ref_sp) 202 | #### Without padding #### 203 | # FV_imgs.append(GT_imgs[t][:, fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W].clone()) 204 | # FV_imgs = torch.stack(FV_imgs, dim=0) 205 | else: 206 | FV_imgs = [] 207 | Ref_sps = [] 208 | for t in range(len(GT_imgs)): 209 | #### With padding #### 210 | # Ref_sp = np.zeros_like(GT_imgs[t]) 211 | H, W, C = GT_imgs[t].shape 212 | Ref_sp = np.zeros((H, W, 1)) 213 | Ref_sp[fv_sp[t][0]:fv_sp[t][0] + FV_H, 
fv_sp[t][1]:fv_sp[t][1] + FV_W, :] = 1 214 | Ref = GT_imgs[t] * Ref_sp 215 | FV_imgs.append(Ref) 216 | Ref_sps.append(Ref_sp) 217 | #### Without padding #### 218 | # FV_imgs.append(GT_imgs[t][fv_sp[t][0]:fv_sp[t][0] + FV_H, fv_sp[t][1]:fv_sp[t][1] + FV_W, :].copy()) 219 | # FV_imgs = np.stack(FV_imgs, dim=0) 220 | 221 | return FV_imgs, Ref_sps, fv_sp 222 | # return FV_imgs, fv_sp 223 | 224 | class TrainSet(data.Dataset): 225 | ''' 226 | Reading the training Vimeo dataset 227 | key example: train/00001/0001/im1.png 228 | GT: Ground-Truth; 229 | LQ: Low-Quality, e.g., low-resolution frames 230 | support reading N HR frames, N = 3, 5, 7 231 | ''' 232 | def __init__(self, args): 233 | super(TrainSet, self).__init__() 234 | self.args = args 235 | 236 | self.LR_dir_list = [] 237 | self.GT_dir_list = [] 238 | GT_list = open(os.path.join(args.dataset_dir, 'sep_trainlist.txt'), 'r') 239 | LR_root = args.dataset_dir.replace('90k', '90k_BD') 240 | for line in GT_list.readlines(): 241 | self.GT_dir_list.append(os.path.join(args.dataset_dir, 'sequences', line.strip())) 242 | self.LR_dir_list.append(os.path.join(LR_root, 'sequences', line.strip())) 243 | 244 | self.transform = Compose([ToTensor()]) 245 | 246 | def __getitem__(self, index): 247 | #### Configs 248 | scale = self.args.scale 249 | GT_size = self.args.GT_size 250 | LR_size = GT_size // scale 251 | FV_size = self.args.FV_size 252 | 253 | #### Bicubic downsampling 254 | ### GT 255 | GT_imgfiles = sorted(os.listdir(self.GT_dir_list[index])) 256 | GT_imgs = [np.array(PIL.Image.open(os.path.join(self.GT_dir_list[index], img))) for img in GT_imgfiles] 257 | 258 | ### LR and LR_sr 259 | H, W, C = GT_imgs[0].shape 260 | LR_imgs = [np.array(PIL.Image.fromarray(img).resize((W // scale, H // scale), PIL.Image.BICUBIC)) for img in GT_imgs] 261 | 262 | #### Random cropping 263 | H, W, C = LR_imgs[0].shape 264 | rnd_h = random.randint(0, max(0, H - LR_size)) 265 | rnd_w = random.randint(0, max(0, W - LR_size)) 266 | LR_imgs = [v[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :] for v in LR_imgs] 267 | 268 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) 269 | GT_imgs = [v[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] for v in GT_imgs] 270 | 271 | ### Ref(FV) and Ref_sr 272 | Ref, Ref_sp, _ = fovea_generator(GT_imgs, method='Nanascan', FV_HW=(FV_size, FV_size)) 273 | 274 | #### Stacking 275 | GT_imgs = np.stack(GT_imgs, axis=0) 276 | LR_imgs = np.stack(LR_imgs, axis=0) 277 | Ref = np.stack(Ref, axis=0) 278 | Ref_sp = np.stack(Ref_sp, axis=0) 279 | 280 | GT_imgs = GT_imgs.astype(np.float32) / 255. 281 | LR_imgs = LR_imgs.astype(np.float32) / 255. 282 | Ref = Ref.astype(np.float32) / 255. 
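# The random horizontal/vertical flips below draw one coin flip per sample and
# apply it jointly to the frames, the foveated patch and its mask, so all
# tensors stay spatially aligned.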
283 | Ref_sp = Ref_sp.astype(np.bool_) 284 | 285 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float() 286 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float() 287 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float() 288 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2)))).float() 289 | 290 | if torch.rand(1) < 0.5: 291 | GT_imgs = F.hflip(GT_imgs) 292 | LR_imgs = F.hflip(LR_imgs) 293 | Ref = F.hflip(Ref) 294 | Ref_sp = F.hflip(Ref_sp) 295 | 296 | if torch.rand(1) < 0.5: 297 | GT_imgs = F.vflip(GT_imgs) 298 | LR_imgs = F.vflip(LR_imgs) 299 | Ref = F.vflip(Ref) 300 | Ref_sp = F.vflip(Ref_sp) 301 | 302 | return {'LR': LR_imgs, 303 | 'HR': GT_imgs, 304 | 'Ref': Ref, 305 | 'Ref_sp': Ref_sp} 306 | 307 | def __len__(self): 308 | return len(self.GT_dir_list) 309 | 310 | class EvalSet(data.Dataset): 311 | ''' 312 | Reading the evaluation Vimeo dataset 313 | key example: train/00001/0001/im1.png 314 | GT: Ground-Truth; 315 | LQ: Low-Quality, e.g., low-resolution frames 316 | supports reading N HR frames, N = 3, 5, 7 317 | ''' 318 | def __init__(self, args): 319 | super(EvalSet, self).__init__() 320 | self.args = args 321 | 322 | self.LR_dir_list = [] 323 | self.GT_dir_list = [] 324 | GT_list = open(os.path.join(args.dataset_dir, 'sep_testlist.txt'), 'r') 325 | LR_root = args.dataset_dir.replace('90k', '90k_BD') 326 | for line in GT_list.readlines(): 327 | self.GT_dir_list.append(os.path.join(args.dataset_dir, 'sequences', line.strip())) 328 | self.LR_dir_list.append(os.path.join(LR_root, 'sequences', line.strip())) 329 | 330 | self.transform = Compose([ToTensor()]) 331 | 332 | def __getitem__(self, index): 333 | #### Configs 334 | scale = self.args.scale 335 | GT_size = self.args.GT_size 336 | LR_size = GT_size // scale 337 | FV_size = self.args.FV_size 338 | 339 | #### Bicubic downsampling 340 | ### GT 341 | GT_imgfiles = sorted(os.listdir(self.GT_dir_list[index])) 342 | GT_imgs = [np.array(PIL.Image.open(os.path.join(self.GT_dir_list[index], img))) for img in GT_imgfiles] 343 | 344 | ### LR and LR_sr 345 | H, W, C = GT_imgs[0].shape 346 | LR_imgs = [np.array(PIL.Image.fromarray(img).resize((W // scale, H // scale), PIL.Image.BICUBIC)) for img in GT_imgs] 347 | 348 | ### Ref(FV) and Ref_sr 349 | Ref, Ref_sp, _ = fovea_generator(GT_imgs, method='Nanascan', FV_HW=(FV_size, FV_size)) 350 | 351 | #### Stacking 352 | GT_imgs = np.stack(GT_imgs, axis=0) 353 | LR_imgs = np.stack(LR_imgs, axis=0) 354 | Ref = np.stack(Ref, axis=0) 355 | Ref_sp = np.stack(Ref_sp, axis=0) 356 | 357 | #### Scaling 358 | GT_imgs = GT_imgs.astype(np.float32) / 255. 359 | LR_imgs = LR_imgs.astype(np.float32) / 255. 360 | Ref = Ref.astype(np.float32) / 255. 
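# Unlike TrainSet, EvalSet keeps the full frames: there is no random cropping
# and no flip augmentation, so evaluation runs on the original geometry.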
361 | Ref_sp = Ref_sp.astype(np.bool_) 362 | 363 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float() 364 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float() 365 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float() 366 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2)))) 367 | 368 | return {'LR': LR_imgs, 369 | 'HR': GT_imgs, 370 | 'Ref': Ref, 371 | 'Ref_sp': Ref_sp} 372 | 373 | def __len__(self): 374 | return len(self.GT_dir_list) 375 | 376 | class TestSet(data.Dataset): 377 | ''' 378 | Reading the test Vimeo dataset 379 | key example: train/00001/0001/im1.png 380 | GT: Ground-Truth; 381 | LQ: Low-Quality, e.g., low-resolution frames 382 | supports reading N HR frames, N = 3, 5, 7 383 | ''' 384 | def __init__(self, args): 385 | super(TestSet, self).__init__() 386 | self.args = args 387 | 388 | self.LR_dir_list = [] 389 | self.GT_dir_list = [] 390 | GT_list = open(os.path.join(args.dataset_dir, 'slow_testset.txt'), 'r') 391 | LR_root = args.dataset_dir.replace('90k', '90k_LR') 392 | for line in GT_list.readlines(): 393 | self.GT_dir_list.append(os.path.join(args.dataset_dir, 'sequences', line.strip())) 394 | self.LR_dir_list.append(os.path.join(LR_root, 'sequences', line.strip())) 395 | 396 | def __getitem__(self, index): 397 | #### Configs 398 | scale = self.args.scale 399 | 400 | #### Bicubic downsampling 401 | ### GT 402 | GT_imgfiles = sorted(os.listdir(self.GT_dir_list[index])) 403 | GT_imgs = [np.array(PIL.Image.open(os.path.join(self.GT_dir_list[index], img))) for img in GT_imgfiles] 404 | GT_H, GT_W = GT_imgs[0].shape[:2] 405 | LR_H, LR_W = GT_H // scale, GT_W // scale 406 | FV_size = self.args.FV_size 407 | ### LR and LR_sr 408 | LR_imgs = [np.array(PIL.Image.fromarray(img).resize((LR_W, LR_H), PIL.Image.BICUBIC)) for img in GT_imgs] 409 | 410 | ### Ref(FV) and Ref_sr 411 | Ref, Ref_sp, fv_sp = fovea_generator(GT_imgs, method='Hscan', step=0.2, FV_HW=(FV_size, FV_size)) 412 | 413 | #### Stacking 414 | GT_imgs = np.stack(GT_imgs, axis=0) 415 | LR_imgs = np.stack(LR_imgs, axis=0) 416 | Ref = np.stack(Ref, axis=0) 417 | Ref_sp = np.stack(Ref_sp, axis=0) 418 | 419 | #### Scaling 420 | GT_imgs = GT_imgs.astype(np.float32) / 255. 421 | LR_imgs = LR_imgs.astype(np.float32) / 255. 422 | Ref = Ref.astype(np.float32) / 255. 
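# TestSet simulates a deterministic left-to-right eye trajectory ('Hscan' with
# step=0.2) and additionally returns the scan positions ('FV_sp'), e.g. so
# downstream scripts can locate the foveated region in each output frame.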
423 | Ref_sp = Ref_sp.astype(np.bool_) 424 | 425 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float() 426 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float() 427 | Ref = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref, (0, 3, 1, 2)))).float() 428 | Ref_sp = torch.from_numpy(np.ascontiguousarray(np.transpose(Ref_sp, (0, 3, 1, 2)))) 429 | 430 | return {'LR': LR_imgs, 431 | 'HR': GT_imgs, 432 | 'Ref': Ref, 433 | 'Ref_sp': Ref_sp, 434 | 'FV_sp': fv_sp} 435 | 436 | def __len__(self): 437 | return len(self.GT_dir_list) -------------------------------------------------------------------------------- /demo/sigma10.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/demo/sigma10.gif -------------------------------------------------------------------------------- /demo/sigma100.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/demo/sigma100.gif -------------------------------------------------------------------------------- /demo/sigma50.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/demo/sigma50.gif -------------------------------------------------------------------------------- /eval.sh: -------------------------------------------------------------------------------- 1 | ### evaluation 2 | python3 main.py --save_dir ./eval/REDS/FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_cra \ 3 | --reset True \ 4 | --num_gpu 1 \ 5 | --gpu_id 0 \ 6 | --log_file_name eval.log \ 7 | --eval True \ 8 | --eval_save_results True \ 9 | --num_workers 1 \ 10 | --scale 8 \ 11 | --cra true \ 12 | --mrcf true \ 13 | --hr_dcn true \ 14 | --offset_prop true \ 15 | --N_frames 15 \ 16 | --FV_size 96 \ 17 | --dataset Reds \ 18 | --dataset_dir /DATA/REDS_sharp/ \ 19 | --model_path ./train/REDS/FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_cra/model/ \ 20 | --visdom_port 8803 \ 21 | --visdom_view 0301_FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_cra -------------------------------------------------------------------------------- /gen_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | 4 | model_code = 13 5 | hr_dcn = False 6 | offset_prop = False 7 | split_ratio = 3 8 | model_name = 'FVSR_x8_simple_v{}_hrdcn_{}_offsetprop_{}_fnet{}'.format(model_code, 'y' if hr_dcn else 'n', 9 | 'y' if offset_prop else 'n', 10 | '_{}outof4'.format(4-split_ratio) if model_code == 18 else '') 11 | # print('Current model name: {}'.format(model_name)) 12 | 13 | dir_root = 'test_png' 14 | gt_root = '{}/GroundTruth/'.format(dir_root) 15 | model_names = os.listdir(dir_root) 16 | for model_name in model_names: 17 | if model_name == 'GroundTruth' or model_name == 'eval_video' or model_name == 'results': 18 | continue 19 | print('Current model name: {}'.format(model_name)) 20 | img_root = '{}/{}/'.format(dir_root, model_name) 21 | save_root = '{}/{}/'.format(dir_root, 'results') 22 | if not os.path.exists(save_root): 23 | os.makedirs(save_root) 24 | 25 | #### 000 Past Foveated Region 26 | video_num = 0 27 | img_0_num = 66 28 | img_1_num = 75 29 | img_dir_path = os.path.join(gt_root, str(video_num)) 30 | save_path = 
os.path.join(save_root, model_name, 'pastfv') 31 | if not os.path.exists(save_path): 32 | os.makedirs(save_path) 33 | begin_x_0 = 20 34 | begin_y_0 = 350 35 | begin_x_1 = 550 36 | begin_y_1 = 350 37 | begin_x_2 = 5 38 | begin_y_2 = 210 39 | img_w = 350 40 | img_h = 350 41 | img_w_0 = 80 42 | img_h_0 = 80 43 | hr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num)) 44 | hr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num)) 45 | hr_img_0 = cv2.imread(hr_img_0) 46 | hr_img_1 = cv2.imread(hr_img_1) 47 | GT_H, GT_W, _ = hr_img_0.shape 48 | cv2.rectangle(hr_img_0, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), (255, 51, 153), 3) 49 | cv2.rectangle(hr_img_1, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), ( 51, 255, 153), 3) 50 | cv2.rectangle(hr_img_1, (begin_x_1, begin_y_1), (begin_x_1+img_w, begin_y_1+img_h), ( 51, 153, 255), 3) 51 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_0_num)), hr_img_0) 52 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_1_num)), hr_img_1) 53 | 54 | dirs = os.listdir(os.path.join(img_root, str(video_num))) 55 | for dir in dirs: 56 | if dir == 'traj.png': 57 | continue 58 | img_dir_path = os.path.join(img_root, str(video_num), dir) 59 | sr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num)) 60 | sr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num)) 61 | sr_img_0 = cv2.imread(sr_img_0) 62 | sr_img_1 = cv2.imread(sr_img_1) 63 | sr_img_2 = sr_img_1.copy() 64 | if dir == 'results': 65 | sr_img_0 = sr_img_0[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :] 66 | sr_img_1 = sr_img_1[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :] 67 | sr_img_2 = sr_img_2[begin_y_1:begin_y_1+img_h, begin_x_1:begin_x_1+img_w, :] 68 | sr_img_3 = sr_img_1[begin_y_2:begin_y_2+img_h_0, begin_x_2:begin_x_2+img_w_0, :].copy() 69 | sr_img_4 = sr_img_1.copy() 70 | cv2.rectangle(sr_img_4, (begin_x_2, begin_y_2), (begin_x_2+img_w_0, begin_y_2+img_h_0), (255, 51, 153), 3) 71 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_3.png'.format(img_1_num, dir)), sr_img_3) 72 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_1_line.png'.format(img_1_num, dir)), sr_img_4) 73 | else: 74 | H, W, _ = sr_img_0.shape 75 | a = H / GT_H 76 | b = W / GT_W 77 | sr_img_0 = sr_img_0[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :] 78 | sr_img_1 = sr_img_1[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :] 79 | sr_img_2 = sr_img_2[int(begin_y_1*a):int((begin_y_1+img_h)*a), int(begin_x_1*b):int((begin_x_1+img_w)*b), :] 80 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_0_num, dir)), sr_img_0) 81 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_1.png'.format(img_1_num, dir)), sr_img_1) 82 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_2.png'.format(img_1_num, dir)), sr_img_2) 83 | 84 | #### 011 Whole Region 85 | video_num = 11 86 | img_0_num = 30 87 | img_1_num = 36 88 | img_dir_path = os.path.join(gt_root, str(video_num)) 89 | save_path = os.path.join(save_root, model_name, 'whole') 90 | if not os.path.exists(save_path): 91 | os.makedirs(save_path) 92 | begin_x_0 = 200 93 | begin_y_0 = 100 94 | begin_x_1 = 500 95 | begin_y_1 = 100 96 | begin_x_2 = 360 97 | begin_y_2 = 100 98 | begin_x_3 = 60 99 | begin_y_3 = 90 100 | begin_x_4 = 60 101 | begin_y_4 = 40 102 | begin_x_5 = 200 103 | begin_y_5 = 60 104 | img_w = 750 105 | img_h = 450 106 | img_w_ = 360 107 | img_h_ = 216 108 | img_w__ = 100 
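# The begin_x_*/begin_y_* pairs above are hand-picked crop origins in
# ground-truth pixel coordinates; the img_w*/img_h* values here are the
# matching crop sizes used for the zoom-in comparison figures.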
109 | img_h__ = 120 110 | hr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num)) 111 | hr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num)) 112 | hr_img_0 = cv2.imread(hr_img_0) 113 | hr_img_1 = cv2.imread(hr_img_1) 114 | hr_img_2 = hr_img_0.copy() 115 | hr_img_3 = hr_img_1.copy() 116 | GT_H, GT_W, _ = hr_img_0.shape 117 | cv2.rectangle(hr_img_0, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), (255, 51, 153), 3) 118 | cv2.rectangle(hr_img_1, (begin_x_1, begin_y_1), (begin_x_1+img_w, begin_y_1+img_h), ( 51, 153, 255), 3) 119 | cv2.rectangle(hr_img_2, (begin_x_2, begin_y_2), (begin_x_2+img_w, begin_y_2+img_h), (255, 51, 153), 3) 120 | cv2.rectangle(hr_img_3, (begin_x_2, begin_y_2), (begin_x_2+img_w, begin_y_2+img_h), ( 51, 153, 255), 3) 121 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_0_num)), hr_img_0) 122 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_1_num)), hr_img_1) 123 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line_2.png'.format(img_0_num)), hr_img_2) 124 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line_2.png'.format(img_1_num)), hr_img_3) 125 | 126 | dirs = os.listdir(os.path.join(img_root, str(video_num))) 127 | for dir in dirs: 128 | if dir == 'traj.png': 129 | continue 130 | img_dir_path = os.path.join(img_root, str(video_num), dir) 131 | sr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num)) 132 | sr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num)) 133 | sr_img_0 = cv2.imread(sr_img_0) 134 | sr_img_1 = cv2.imread(sr_img_1) 135 | sr_img_2 = sr_img_0.copy() 136 | sr_img_3 = sr_img_1.copy() 137 | if dir == 'results': 138 | sr_img_0 = sr_img_0[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :] 139 | sr_img_1 = sr_img_1[begin_y_1:begin_y_1+img_h, begin_x_1:begin_x_1+img_w, :] 140 | sr_img_2 = sr_img_2[begin_y_2:begin_y_2+img_h, begin_x_2:begin_x_2+img_w, :] 141 | sr_img_3 = sr_img_3[begin_y_2:begin_y_2+img_h, begin_x_2:begin_x_2+img_w, :] 142 | cv2.rectangle(sr_img_2, (begin_x_3, begin_y_3), (begin_x_3+img_w_, begin_y_3+img_h_), ( 51, 153, 255), 3) 143 | cv2.rectangle(sr_img_3, (begin_x_3, begin_y_3), (begin_x_3+img_w_, begin_y_3+img_h_), ( 51, 153, 255), 3) 144 | sr_img_4 = sr_img_2[begin_y_3:begin_y_3+img_h_, begin_x_3:begin_x_3+img_w_, :] 145 | sr_img_5 = sr_img_3[begin_y_3:begin_y_3+img_h_, begin_x_3:begin_x_3+img_w_, :] 146 | sr_img_6 = sr_img_5[begin_y_4:begin_y_4+img_h__, begin_x_4:begin_x_4+img_w__, :].copy() 147 | sr_img_7 = sr_img_5[begin_y_5:begin_y_5+img_h__, begin_x_5:begin_x_5+img_w__, :].copy() 148 | cv2.rectangle(sr_img_5, (begin_x_4, begin_y_4), (begin_x_4+img_w__, begin_y_4+img_h__), ( 51, 153, 255), 3) 149 | cv2.rectangle(sr_img_5, (begin_x_5, begin_y_5), (begin_x_5+img_w__, begin_y_5+img_h__), ( 51, 153, 255), 3) 150 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_3.png'.format(img_0_num, dir)), sr_img_4) 151 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_3.png'.format(img_1_num, dir)), sr_img_5) 152 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_4.png'.format(img_1_num, dir)), sr_img_6) 153 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_5.png'.format(img_1_num, dir)), sr_img_7) 154 | else: 155 | H, W, _ = sr_img_0.shape 156 | a = H / GT_H 157 | b = W / GT_W 158 | sr_img_0 = sr_img_0[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :] 159 | sr_img_1 = sr_img_1[int(begin_y_1*a):int((begin_y_1+img_h)*a), int(begin_x_1*b):int((begin_x_1+img_w)*b), :] 160 | sr_img_2 = 
sr_img_2[int(begin_y_2*a):int((begin_y_2+img_h)*a), int(begin_x_2*b):int((begin_x_2+img_w)*b), :] 161 | sr_img_3 = sr_img_3[int(begin_y_2*a):int((begin_y_2+img_h)*a), int(begin_x_2*b):int((begin_x_2+img_w)*b), :] 162 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_0_num, dir)), sr_img_0) 163 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_1_num, dir)), sr_img_1) 164 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_2.png'.format(img_0_num, dir)), sr_img_2) 165 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}_2.png'.format(img_1_num, dir)), sr_img_3) 166 | 167 | #### 015 Title 168 | video_num = 15 169 | img_0_num = 31 170 | img_1_num = 36 171 | img_2_num = 43 172 | img_dir_path = os.path.join(gt_root, str(video_num)) 173 | save_path = os.path.join(save_root, model_name, 'title') 174 | if not os.path.exists(save_path): 175 | os.makedirs(save_path) 176 | begin_x_0 = 350 177 | begin_y_0 = 140 178 | begin_x_1 = 700 179 | begin_y_1 = 140 180 | img_w = 500 181 | img_h = 300 182 | hr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num)) 183 | hr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num)) 184 | hr_img_2 = os.path.join(img_dir_path, '{:03d}.png'.format(img_2_num)) 185 | hr_img_0 = cv2.imread(hr_img_0) 186 | hr_img_1 = cv2.imread(hr_img_1) 187 | hr_img_2 = cv2.imread(hr_img_2) 188 | GT_H, GT_W, _ = hr_img_0.shape 189 | cv2.rectangle(hr_img_0, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), ( 255, 51, 153), 3) 190 | cv2.rectangle(hr_img_1, (begin_x_1, begin_y_1), (begin_x_1+img_w, begin_y_1+img_h), ( 51, 153, 255), 3) 191 | cv2.rectangle(hr_img_2, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), ( 51, 153, 255), 3) 192 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_0_num)), hr_img_0) 193 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_1_num)), hr_img_1) 194 | cv2.imwrite(os.path.join(save_path, '{:03d}_hr_line.png'.format(img_2_num)), hr_img_2) 195 | 196 | dirs = os.listdir(os.path.join(img_root, str(video_num))) 197 | for dir in dirs: 198 | if dir == 'traj.png': 199 | continue 200 | img_dir_path = os.path.join(img_root, str(video_num), dir) 201 | sr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num)) 202 | sr_img_1 = os.path.join(img_dir_path, '{:03d}.png'.format(img_1_num)) 203 | sr_img_2 = os.path.join(img_dir_path, '{:03d}.png'.format(img_2_num)) 204 | sr_img_0 = cv2.imread(sr_img_0) 205 | sr_img_1 = cv2.imread(sr_img_1) 206 | sr_img_2 = cv2.imread(sr_img_2) 207 | if dir == 'results': 208 | sr_img_0 = sr_img_0[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :] 209 | sr_img_1 = sr_img_1[begin_y_1:begin_y_1+img_h, begin_x_1:begin_x_1+img_w, :] 210 | sr_img_2 = sr_img_2[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :] 211 | else: 212 | H, W, _ = sr_img_0.shape 213 | a = H / GT_H 214 | b = W / GT_W 215 | sr_img_0 = sr_img_0[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :] 216 | sr_img_1 = sr_img_1[int(begin_y_1*a):int((begin_y_1+img_h)*a), int(begin_x_1*b):int((begin_x_1+img_w)*b), :] 217 | sr_img_2 = sr_img_2[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :] 218 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_0_num, dir)), sr_img_0) 219 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_1_num, dir)), sr_img_1) 220 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_2_num, dir)), sr_img_2) 221 | 222 | #### 020 
Gaussian 223 | video_num = 20 224 | img_0_num = 99 225 | img_dir_path = os.path.join(img_root, str(video_num), 'results') 226 | save_path = os.path.join(save_root, model_name, 'title') 227 | if not os.path.exists(save_path): 228 | os.makedirs(save_path) 229 | begin_x_0 = 320 230 | begin_y_0 = 180 231 | img_w = 640 232 | img_h = 360 233 | hr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num)) 234 | hr_img_0 = cv2.imread(hr_img_0) 235 | GT_H, GT_W, _ = hr_img_0.shape 236 | cv2.rectangle(hr_img_0, (begin_x_0, begin_y_0), (begin_x_0+img_w, begin_y_0+img_h), ( 255, 51, 153), 3) 237 | cv2.imwrite(os.path.join(save_path, '{:03d}_sr_line.png'.format(img_0_num)), hr_img_0) 238 | 239 | dirs = os.listdir(os.path.join(img_root, str(video_num))) 240 | for dir in dirs: 241 | if dir == 'traj.png': 242 | continue 243 | img_dir_path = os.path.join(img_root, str(video_num), dir) 244 | sr_img_0 = os.path.join(img_dir_path, '{:03d}.png'.format(img_0_num)) 245 | sr_img_0 = cv2.imread(sr_img_0) 246 | if dir == 'results': 247 | sr_img_0 = sr_img_0[begin_y_0:begin_y_0+img_h, begin_x_0:begin_x_0+img_w, :] 248 | else: 249 | H, W, _ = sr_img_0.shape 250 | a = H / GT_H 251 | b = W / GT_W 252 | sr_img_0 = sr_img_0[int(begin_y_0*a):int((begin_y_0+img_h)*a), int(begin_x_0*b):int((begin_x_0+img_w)*b), :] 253 | cv2.imwrite(os.path.join(save_path, '{:03d}_{}.png'.format(img_0_num, dir)), sr_img_0) -------------------------------------------------------------------------------- /gen_video.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import cv2 6 | 7 | if __name__ == '__main__': 8 | # fourcc = cv2.VideoWriter_fourcc(*'MP4V') 9 | # out = cv2.VideoWriter('test_video_arcane.mp4', fourcc, 20.0, (1920, 1080)) 10 | save_dir = 'old_tree_x1' 11 | if not os.path.exists(save_dir): 12 | os.mkdir(save_dir) 13 | cap = cv2.VideoCapture('old_tree.mp4') 14 | gen_frames = [] 15 | n = 0 16 | scale = 8 17 | while cap.isOpened(): 18 | ret, frame = cap.read() 19 | if not isinstance(frame, np.ndarray): 20 | break 21 | H, W, C = frame.shape 22 | n += 1 23 | # gen_frames.append(frame) 24 | # if H != 1080 or W != 1920: 25 | # frame = cv2.resize(frame, (1920, 1080), interpolation=cv2.INTER_CUBIC) 26 | frame = frame[(H - 1080)//2:(H + 1080)//2, (W - 1920)//2:(W + 1920)//2, :] 27 | # frame = cv2.resize(frame, (W//scale, H//scale), interpolation=cv2.INTER_CUBIC) 28 | # cv2.putText(frame, str(n), (10, 120), cv2.FONT_HERSHEY_DUPLEX, 1, (0, 255, 255), 1, cv2.LINE_AA) 29 | # cv2.imshow('frame', frame) 30 | # if n <= 800: 31 | # continue 32 | # if n > 1200: 33 | # break 34 | # if cv2.waitKey(1) == ord('q'): 35 | # break 36 | print('{}\r'.format(n), end='') 37 | cv2.imwrite(os.path.join(save_dir, '{:08d}.png'.format(n)), frame) 38 | 39 | cap.release() 40 | cv2.destroyAllWindows() 41 | 42 | # print(n) 43 | # with get_writer('fungtsun_confuse.gif', mode="I", fps=20) as writer: 44 | # for i in range(n): 45 | # if i > 4: 46 | # out.write(gen_frames[i]) 47 | # writer.append_data(gen_frames[i][:,:,::-1]) 48 | -------------------------------------------------------------------------------- /gif_combine.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image,ImageSequence 3 | import numpy as np 4 | import cv2 5 | from imageio import imread, imsave, get_writer 6 | 7 | im_1 = Image.open('DCN_3and1.gif') 8 | im_2 = Image.open('DCN_4.gif') 9 | 10 | def iter_frames(im): 11 | try: 
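# GIF frames after the first may carry only a partial (or no) local palette;
# capture the palette from frame 0 and re-apply it so every decoded frame
# renders with consistent colors.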
12 | i = 0 13 | while 1: 14 | im.seek(i) 15 | imframe = im.copy() 16 | if i == 0: 17 | palette = imframe.getpalette() 18 | else: 19 | imframe.putpalette(palette) 20 | yield imframe 21 | i += 1 22 | except EOFError: 23 | pass 24 | 25 | gif_1 = [] 26 | gif_2 = [] 27 | n_frames = 0 28 | for i,frame in enumerate(ImageSequence.Iterator(im_1),1): 29 | frame = frame.convert('RGB') 30 | # frame.save(os.path.join('test.png')) 31 | # gif_1.append(np.array(Image.open('test.png'))) 32 | gif_1.append(np.array(frame)) 33 | n_frames += 1 34 | # cv2.imshow('image', gif_1[-1][:,:,::-1]) 35 | # cv2.waitKey(10) 36 | 37 | for i,frame in enumerate(ImageSequence.Iterator(im_2),1): 38 | frame = frame.convert('RGB') 39 | frame.save('test.png') 40 | # gif_2.append(np.array(Image.open('test.png'))) 41 | gif_2.append(np.array(frame)) 42 | # cv2.imshow('image', gif_2[-1][:,:,::-1]) 43 | # cv2.waitKey(10) 44 | 45 | with get_writer('test_1.gif', mode="I", fps=7) as writer: 46 | for n in range(n_frames - 10): 47 | writer.append_data(np.concatenate((gif_2[n][260:620,270:590,:],gif_1[n][260:620,270:590,:]), axis=1)) 48 | cv2.imshow('image', np.concatenate((gif_2[n][260:620,270:590,:],gif_1[n][260:620,270:590,:]), axis=1)) 49 | cv2.waitKey(100) 50 | 51 | # with get_writer('test_1.gif', mode="I", fps=7) as writer: 52 | # for n in range(n_frames): 53 | # writer.append_data(gif_1[n][260:620,220:640,:]) 54 | -------------------------------------------------------------------------------- /loss/loss.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | 7 | def reduce_loss(loss, reduction): 8 | """Reduce loss as specified. 9 | 10 | Args: 11 | loss (Tensor): Elementwise loss tensor. 12 | reduction (str): Options are "none", "mean" and "sum". 13 | 14 | Returns: 15 | Tensor: Reduced loss tensor. 16 | """ 17 | reduction_enum = F._Reduction.get_enum(reduction) 18 | # none: 0, elementwise_mean:1, sum: 2 19 | if reduction_enum == 0: 20 | return loss 21 | if reduction_enum == 1: 22 | return loss.mean() 23 | 24 | return loss.sum() 25 | 26 | def mask_reduce_loss(loss, weight=None, reduction='mean', sample_wise=False): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. Default: None. 32 | reduction (str): Same as built-in losses of PyTorch. Options are 33 | "none", "mean" and "sum". Default: 'mean'. 34 | sample_wise (bool): Whether to calculate the loss sample-wise. This 35 | argument only takes effect when `reduction` is 'mean' and `weight` 36 | (argument of `forward()`) is not None. If True, the loss is first 37 | reduced with 'mean' per sample and then averaged over all samples. 38 | Default: False. 39 | 40 | Returns: 41 | Tensor: Processed loss values.
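`gif_combine.py` builds its comparison GIF by cropping the same window from two decoded frame lists and stacking them left-right before writing. The core pattern, reduced to a sketch (file name, crop box and fps are placeholders):

```
import numpy as np
from imageio import get_writer

def write_side_by_side(frames_a, frames_b, out_path='side_by_side.gif', fps=7):
    # frames_a / frames_b: equal-length lists of HxWx3 uint8 RGB arrays.
    y0, y1, x0, x1 = 260, 620, 270, 590   # crop box used above
    with get_writer(out_path, mode="I", fps=fps) as writer:
        for fa, fb in zip(frames_a, frames_b):
            writer.append_data(np.concatenate((fa[y0:y1, x0:x1], fb[y0:y1, x0:x1]), axis=1))
```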
42 | """ 43 | # if weight is specified, apply element-wise weight 44 | if weight is not None: 45 | assert weight.dim() == loss.dim() 46 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 47 | loss = loss * weight 48 | 49 | # if weight is not specified or reduction is sum, just reduce the loss 50 | if weight is None or reduction == 'sum': 51 | loss = reduce_loss(loss, reduction) 52 | # if reduction is mean, then compute mean over masked region 53 | elif reduction == 'mean': 54 | # expand weight from N1HW to NCHW 55 | if weight.size(1) == 1: 56 | weight = weight.expand_as(loss) 57 | # small value to prevent division by zero 58 | eps = 1e-12 59 | 60 | # perform sample-wise mean 61 | if sample_wise: 62 | weight = weight.sum(dim=[1, 2, 3], keepdim=True) # NCHW to N111 63 | loss = (loss / (weight + eps)).sum() / weight.size(0) 64 | # perform pixel-wise mean 65 | else: 66 | loss = loss.sum() / (weight.sum() + eps) 67 | 68 | return loss 69 | 70 | def masked_loss(loss_func): 71 | """Create a masked version of a given loss function. 72 | 73 | To use this decorator, the loss function must have the signature like 74 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 75 | element-wise loss without any reduction. This decorator will add weight 76 | and reduction arguments to the function. The decorated function will have 77 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 78 | sample_wise=False, **kwargs)`. 79 | 80 | :Example: 81 | 82 | >>> import torch 83 | >>> @masked_loss 84 | >>> def l1_loss(pred, target): 85 | >>> return (pred - target).abs() 86 | 87 | >>> pred = torch.Tensor([0, 2, 3]) 88 | >>> target = torch.Tensor([1, 1, 1]) 89 | >>> weight = torch.Tensor([1, 0, 1]) 90 | 91 | >>> l1_loss(pred, target) 92 | tensor(1.3333) 93 | >>> l1_loss(pred, target, weight) 94 | tensor(1.5000) 95 | >>> l1_loss(pred, target, reduction='none') 96 | tensor([1., 1., 2.]) 97 | >>> l1_loss(pred, target, weight, reduction='sum') 98 | tensor(3.) 99 | """ 100 | 101 | @functools.wraps(loss_func) 102 | def wrapper(pred, 103 | target, 104 | weight=None, 105 | reduction='mean', 106 | sample_wise=False, 107 | **kwargs): 108 | # get element-wise loss 109 | loss = loss_func(pred, target, **kwargs) 110 | loss = mask_reduce_loss(loss, weight, reduction, sample_wise) 111 | return loss 112 | 113 | return wrapper 114 | 115 | @masked_loss 116 | def charbonnier_loss(pred, target, eps=1e-12): 117 | """Charbonnier loss. 118 | Args: 119 | pred (Tensor): Prediction Tensor with shape (n, c, h, w). 120 | target (Tensor): Target Tensor with shape (n, c, h, w). 121 | Returns: 122 | Tensor: Calculated Charbonnier loss. 123 | """ 124 | return torch.sqrt((pred - target)**2 + eps) 125 | 126 | class CharbonnierLoss(nn.Module): 127 | """Charbonnier loss (one variant of Robust L1Loss, a differentiable 128 | variant of L1Loss). 129 | Described in "Deep Laplacian Pyramid Networks for Fast and Accurate 130 | Super-Resolution". 131 | Args: 132 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 133 | reduction (str): Specifies the reduction to apply to the output. 134 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 135 | sample_wise (bool): Whether to calculate the loss sample-wise. This 136 | argument only takes effect when `reduction` is 'mean' and `weight` 137 | (argument of `forward()`) is not None. If True, the loss is first 138 | reduced with 'mean' per sample and then averaged over all samples. 139 | Default: False.
140 | eps (float): A value used to control the curvature near zero. 141 | Default: 1e-12. 142 | """ 143 | 144 | def __init__(self, 145 | loss_weight=1.0, 146 | reduction='mean', 147 | sample_wise=False, 148 | eps=1e-12): 149 | super().__init__() 150 | if reduction not in ['none', 'mean', 'sum']: 151 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 152 | f"Supported ones are: ['none', 'mean', 'sum']") 153 | 154 | self.loss_weight = loss_weight 155 | self.reduction = reduction 156 | self.sample_wise = sample_wise 157 | self.eps = eps 158 | 159 | def forward(self, pred, target, weight=None, **kwargs): 160 | """Forward Function. 161 | Args: 162 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 163 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 164 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 165 | weights. Default: None. 166 | """ 167 | return self.loss_weight * charbonnier_loss( 168 | pred, 169 | target, 170 | weight, 171 | eps=self.eps, 172 | reduction=self.reduction, 173 | sample_wise=self.sample_wise) 174 | 175 | 176 | 177 | def get_loss_dict(args, logger): 178 | loss = {} 179 | if (abs(args.rec_w - 0) <= 1e-8): 180 | raise SystemExit('NotImplementedError: ReconstructionLoss must exist!') 181 | else: 182 | loss['cb_loss'] = CharbonnierLoss() 183 | 184 | return loss -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #### take reference from 2 | 3 | from option import args 4 | from utils import mkExpDir 5 | from dataset import dataloader 6 | from model import CRFP 7 | from loss.loss import get_loss_dict 8 | from trainer import Trainer 9 | 10 | import math 11 | import os 12 | import time 13 | import torch 14 | import torch.nn as nn 15 | import warnings 16 | from tqdm import tqdm 17 | warnings.filterwarnings('ignore') 18 | 19 | if __name__ == '__main__': 20 | 21 | ### make save_dir 22 | _logger = mkExpDir(args) 23 | 24 | ### device and model 25 | if args.num_gpu == 1: 26 | device = torch.device('cpu') if args.cpu else torch.device('cuda:{}'.format(args.gpu_id)) 27 | else: 28 | device = torch.device('cpu') if args.cpu else torch.device('cuda') 29 | 30 | # _model = CRFP.BasicFVSR(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 31 | # _model = CRFP.CRFP_simple_noDCN(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 32 | # _model = CRFP.CRFP_simple(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 33 | # _model = CRFP.CRFP(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 34 | _model = CRFP.CRFP_DSV(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 35 | # _model = CRFP.CRFP_DSV_CRA(mid_channels=32, y_only=args.y_only, hr_dcn=args.hr_dcn, offset_prop=args.offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 36 | 37 | if ((not args.cpu) and (args.num_gpu > 1)): 38 | _model = nn.DataParallel(_model, list(range(args.num_gpu))) 39 | 40 | ###
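`get_loss_dict` exposes only `cb_loss`, so masked Charbonnier reconstruction is the loss the trainer optimizes. A minimal usage sketch (the shapes and the fovea box are illustrative assumptions, not values from the training code):

```
import torch
from loss.loss import CharbonnierLoss

pred = torch.rand(2, 3, 64, 64, requires_grad=True)  # stand-in SR output
target = torch.rand(2, 3, 64, 64)                    # stand-in ground truth
weight = torch.zeros(2, 1, 64, 64)                   # N1HW fovea mask
weight[:, :, 16:48, 16:48] = 1

criterion = CharbonnierLoss(loss_weight=1.0, reduction='mean')
# mask_reduce_loss divides by weight.sum(), so this is the mean
# Charbonnier error over the masked pixels only.
loss = criterion(pred, target, weight=weight)
loss.backward()
```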
dataloader of training set and testing set 41 | _dataloader = dataloader.get_dataloader(args) 42 | 43 | ### loss 44 | _loss_all = get_loss_dict(args, _logger) 45 | 46 | ### trainer 47 | t = Trainer(args, _logger, _dataloader, _model, _loss_all) 48 | t.before_run() 49 | ### test / eval / train 50 | if (args.test): 51 | t.load(model_path=args.model_path) 52 | t.test_basicvsr() 53 | elif (args.eval): 54 | model_list = sorted(os.listdir(args.model_path)) 55 | # model_list = model_list[::-1] 56 | for idx, m in enumerate(model_list): 57 | t.load(model_path=os.path.join(args.model_path, m)) 58 | t.eval_basicvsr(idx) 59 | t.vis_plot_metric('eval') 60 | else: 61 | # t.load(model_path=args.model_path) 62 | for epoch in range(1, args.num_epochs+1): 63 | t.train_basicvsr(current_epoch=epoch) 64 | t.vis_plot_metric('train') 65 | # if (epoch % args.val_every == 0): 66 | # t.eval_basicvsr(current_epoch=epoch) 67 | # t.vis_plot_metric('eval') 68 | torch.cuda.empty_cache() -------------------------------------------------------------------------------- /model/LTE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def pixel_unshuffle(input, downscale_factor): 6 | ''' 7 | input: batchSize * c * k*w * k*h 8 | downscale_factor: k 9 | batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h 10 | ''' 11 | c = input.shape[1] 12 | 13 | kernel = torch.zeros(size=[downscale_factor * downscale_factor * c, 14 | 1, downscale_factor, downscale_factor], 15 | device=input.device) 16 | for y in range(downscale_factor): 17 | for x in range(downscale_factor): 18 | kernel[x + y * downscale_factor::downscale_factor*downscale_factor, 0, y, x] = 1 19 | return F.conv2d(input, kernel, stride=downscale_factor, groups=c) 20 | 21 | class PixelUnshuffle(nn.Module): 22 | def __init__(self, downscale_factor): 23 | super(PixelUnshuffle, self).__init__() 24 | self.downscale_factor = downscale_factor 25 | def forward(self, input): 26 | ''' 27 | input: batchSize * c * k*w * k*h 28 | downscale_factor: k 29 | batchSize * c * k*w * k*h -> batchSize * k*k*c * w * h 30 | ''' 31 | 32 | return pixel_unshuffle(input, self.downscale_factor) 33 | 34 | class LTE_simple_lr(torch.nn.Module): 35 | def __init__(self, mid_channels): 36 | super(LTE_simple_lr, self).__init__() 37 | 38 | ### use vgg19 weights to initialize 39 | self.slice1 = torch.nn.Sequential( 40 | nn.Conv2d(3, mid_channels, 3, 1, 1), 41 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 42 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1), 43 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 44 | ) 45 | 46 | def forward(self, x, islr=False): 47 | x = self.slice1(x) 48 | # x_lv3 = x 49 | # x_lv2 = x 50 | # x_lv1 = x 51 | return None, None, x 52 | 53 | class LTE_simple_hr(torch.nn.Module): 54 | def __init__(self, mid_channels): 55 | super(LTE_simple_hr, self).__init__() 56 | 57 | ### use vgg19 weights to initialize 58 | self.slice1 = torch.nn.Sequential( 59 | nn.Conv2d(6, mid_channels, 3, 1, 1), 60 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 61 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1), 62 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 63 | ) 64 | self.slice2 = torch.nn.Sequential( 65 | nn.MaxPool2d(2, 2), 66 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1), 67 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 68 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1), 69 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 70 | ) 71 | self.slice3 =
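`pixel_unshuffle` above realizes space-to-depth as a grouped convolution with stride `k`, where each one-hot kernel picks out one sub-pixel position. On PyTorch versions that ship `F.pixel_unshuffle` (1.8+), the two should agree, which makes for a quick sanity check (a sketch, assuming the repo layout for the import):

```
import torch
import torch.nn.functional as F
from model.LTE import pixel_unshuffle

x = torch.rand(1, 3, 16, 16)
# Both map (N, C, k*H, k*W) -> (N, k*k*C, H, W) with the same channel order.
assert torch.allclose(pixel_unshuffle(x, 4), F.pixel_unshuffle(x, 4), atol=1e-6)
```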
torch.nn.Sequential( 72 | nn.MaxPool2d(2, 2), 73 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1), 74 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 75 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1), 76 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 77 | ) 78 | 79 | self.conv_lv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) 80 | self.conv_lv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) 81 | self.conv_lv3 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) 82 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 83 | 84 | def forward(self, x, islr=False): 85 | if islr: 86 | x = self.slice1(x) 87 | x_lv3 = self.lrelu(self.conv_lv3(x)) 88 | x = self.slice2(x) 89 | x_lv2 = self.lrelu(self.conv_lv2(x)) 90 | x = self.slice3(x) 91 | x_lv1 = self.lrelu(self.conv_lv1(x)) 92 | else: 93 | x_lv3 = x 94 | x = self.slice2(x) 95 | x_lv2 = x 96 | x = self.slice3(x) 97 | x_lv1 = x 98 | return x_lv1, x_lv2, x_lv3 99 | 100 | class LTE_simple_hr_single(torch.nn.Module): 101 | def __init__(self, mid_channels): 102 | super(LTE_simple_hr_single, self).__init__() 103 | 104 | ### use vgg19 weights to initialize 105 | self.slice1 = torch.nn.Sequential( 106 | nn.Conv2d(6, mid_channels, 3, 1, 1), 107 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 108 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1), 109 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 110 | ) 111 | 112 | def forward(self, x, islr=False): 113 | x = self.slice1(x) 114 | # x_lv3 = x 115 | # x_lv2 = x 116 | # x_lv1 = x 117 | return None, None, x 118 | 119 | class LTE_simple_hr_ps(torch.nn.Module): 120 | def __init__(self, mid_channels): 121 | super(LTE_simple_hr_ps, self).__init__() 122 | 123 | ### use vgg19 weights to initialize 124 | self.slice1 = torch.nn.Sequential( 125 | nn.Conv2d(6, mid_channels, 3, 1, 1), 126 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 127 | nn.Conv2d(mid_channels, mid_channels, 3, 1, 1), 128 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 129 | ) 130 | self.slice2 = torch.nn.Sequential( 131 | PixelUnshuffle(4), 132 | nn.Conv2d(mid_channels*16, mid_channels*4, 3, 1, 1), 133 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 134 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1), 135 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 136 | ) 137 | self.slice3 = torch.nn.Sequential( 138 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1), 139 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 140 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1), 141 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 142 | ) 143 | self.slice4 = torch.nn.Sequential( 144 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1), 145 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 146 | nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1), 147 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 148 | ) 149 | 150 | self.conv_lv0 = nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1) 151 | self.conv_lv1 = nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1) 152 | self.conv_lv2 = nn.Conv2d(mid_channels*4, mid_channels*4, 3, 1, 1) 153 | self.conv_lv3 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) 154 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 155 | 156 | def forward(self, x): 157 | x = self.slice1(x) 158 | x_lv3 = self.lrelu(self.conv_lv3(x)) 159 | x = self.slice2(x) 160 | x_lv2 = self.lrelu(self.conv_lv2(x)) 161 | x = self.slice3(x) 162 | x_lv1 = self.lrelu(self.conv_lv1(x)) 163 | x = self.slice4(x) 164 | x_lv0 = self.lrelu(self.conv_lv0(x)) 165 | 166 | return x_lv0, x_lv1, x_lv2, x_lv3 167 | 168 | class 
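`LTE_simple_hr` returns a three-level feature pyramid: `x_lv3` at input resolution and `x_lv2`/`x_lv1` halved once and twice by the max-pooling stages. A shape walkthrough under assumed sizes (the 6-channel input corresponds to two stacked RGB tensors):

```
import torch
from model.LTE import LTE_simple_hr

enc = LTE_simple_hr(mid_channels=32)
x = torch.rand(1, 6, 96, 96)
x_lv1, x_lv2, x_lv3 = enc(x, islr=True)
print(x_lv3.shape)  # torch.Size([1, 32, 96, 96])
print(x_lv2.shape)  # torch.Size([1, 32, 48, 48])
print(x_lv1.shape)  # torch.Size([1, 32, 24, 24])
```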
LTE_simple_hr_v1(torch.nn.Module): 169 | def __init__(self, mid_channels): 170 | super(LTE_simple_hr_v1, self).__init__() 171 | 172 | ### use vgg19 weights to initialize 173 | self.slice1 = torch.nn.Sequential( 174 | nn.Conv2d(6, mid_channels//4, 3, 1, 1), 175 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 176 | nn.Conv2d(mid_channels//4, mid_channels//4, 3, 1, 1), 177 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 178 | ) 179 | self.slice2 = torch.nn.Sequential( 180 | nn.MaxPool2d(2, 2), 181 | nn.Conv2d(mid_channels//4, mid_channels//2, 3, 1, 1), 182 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 183 | nn.Conv2d(mid_channels//2, mid_channels//2, 3, 1, 1), 184 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 185 | ) 186 | self.slice3 = torch.nn.Sequential( 187 | nn.MaxPool2d(2, 2), 188 | nn.Conv2d(mid_channels//2, mid_channels//1, 3, 1, 1), 189 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 190 | nn.Conv2d(mid_channels//1, mid_channels//1, 3, 1, 1), 191 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 192 | ) 193 | 194 | self.conv_lv3 = nn.Conv2d(mid_channels//4, mid_channels//4, 3, 1, 1) 195 | self.conv_lv2 = nn.Conv2d(mid_channels//2, mid_channels//2, 3, 1, 1) 196 | self.conv_lv1 = nn.Conv2d(mid_channels//1, mid_channels//1, 3, 1, 1) 197 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 198 | 199 | def forward(self, x, islr=False): 200 | if islr: 201 | x = self.slice1(x) 202 | x_lv3 = self.lrelu(self.conv_lv3(x)) 203 | x = self.slice2(x) 204 | x_lv2 = self.lrelu(self.conv_lv2(x)) 205 | x = self.slice3(x) 206 | x_lv1 = self.lrelu(self.conv_lv1(x)) 207 | else: 208 | x_lv3 = x 209 | x = self.slice2(x) 210 | x_lv2 = x 211 | x = self.slice3(x) 212 | x_lv1 = x 213 | return x_lv1, x_lv2, x_lv3 214 | 215 | class LTE_simple_hr_x8(torch.nn.Module): 216 | def __init__(self, mid_channels): 217 | super(LTE_simple_hr_x8, self).__init__() 218 | 219 | ### use vgg19 weights to initialize 220 | self.slice1 = torch.nn.Sequential( 221 | nn.Conv2d(6, 64, 3, 1, 1), 222 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 223 | nn.Conv2d(64, 64, 3, 1, 1), 224 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 225 | ) 226 | self.slice2 = torch.nn.Sequential( 227 | nn.MaxPool2d(2, 2), 228 | nn.Conv2d(64, 64, 3, 1, 1), 229 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 230 | nn.Conv2d(64, 64, 3, 1, 1), 231 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 232 | ) 233 | self.slice3 = torch.nn.Sequential( 234 | nn.MaxPool2d(2, 2), 235 | nn.Conv2d(64, 64, 3, 1, 1), 236 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 237 | nn.Conv2d(64, 64, 3, 1, 1), 238 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 239 | ) 240 | self.slice4 = torch.nn.Sequential( 241 | nn.MaxPool2d(2, 2), 242 | nn.Conv2d(64, 64, 3, 1, 1), 243 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 244 | nn.Conv2d(64, 64, 3, 1, 1), 245 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 246 | ) 247 | 248 | self.conv_lv0 = nn.Conv2d(64, 64, 3, 1, 1) 249 | self.conv_lv1 = nn.Conv2d(64, 64, 3, 1, 1) 250 | self.conv_lv2 = nn.Conv2d(64, 64, 3, 1, 1) 251 | self.conv_lv3 = nn.Conv2d(64, 64, 3, 1, 1) 252 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 253 | # self.maxpool = nn.MaxPool2d(2, 2) 254 | 255 | def forward(self, x, islr=False): 256 | if islr: 257 | x = self.slice1(x) 258 | x_lv3 = self.lrelu(self.conv_lv3(x)) 259 | x = self.slice2(x) 260 | x_lv2 = self.lrelu(self.conv_lv2(x)) 261 | x = self.slice3(x) 262 | x_lv1 = self.lrelu(self.conv_lv1(x)) 263 | x = self.slice4(x) 264 | x_lv0 = self.lrelu(self.conv_lv0(x)) 
265 | else: 266 | x_lv3 = x 267 | x = self.slice2(x) 268 | x_lv2 = x 269 | x = self.slice3(x) 270 | x_lv1 = x 271 | x = self.slice4(x) 272 | x_lv0 = x 273 | return x_lv0, x_lv1, x_lv2, x_lv3 274 | -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def str2bool(v): 4 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 5 | return True 6 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 7 | return False 8 | else: 9 | raise argparse.ArgumentTypeError('Boolean value expected.') 10 | 11 | parser = argparse.ArgumentParser(description='MRCF') 12 | 13 | ### visdom setting 14 | parser.add_argument('--visdom_port', type=int, default=8801, 15 | help='Visdom execution port') 16 | parser.add_argument('--visdom_view', type=str, default='MRCF', 17 | help='Visdom execution view') 18 | 19 | ### log setting 20 | parser.add_argument('--save_dir', type=str, default='save_dir', 21 | help='Directory to save log, arguments, models and images') 22 | parser.add_argument('--reset', type=str2bool, default=False, 23 | help='Delete save_dir to create a new one') 24 | parser.add_argument('--log_file_name', type=str, default='MRCF.log', 25 | help='Log file name') 26 | parser.add_argument('--logger_name', type=str, default='MRCF', 27 | help='Logger name') 28 | 29 | ### device setting 30 | parser.add_argument('--cpu', type=str2bool, default=False, 31 | help='Use CPU to run code') 32 | parser.add_argument('--num_gpu', type=int, default=1, 33 | help='The number of GPU used in training') 34 | parser.add_argument('--gpu_id', type=int, default=0, 35 | help='The id of GPU used in training') 36 | 37 | ### dataset setting 38 | parser.add_argument('--dataset', type=str, default='REDS', 39 | help='Which dataset to train and test') 40 | parser.add_argument('--dataset_dir', type=str, default='/Data/REDS_sharp/', 41 | help='Directory of dataset') 42 | 43 | ### dataloader setting 44 | parser.add_argument('--num_workers', type=int, default=4, 45 | help='The number of workers when loading data') 46 | 47 | ### model setting 48 | parser.add_argument('--num_res_blocks', type=str, default='4+4+4+4', 49 | help='The number of residual blocks in each stage') 50 | parser.add_argument('--n_feats', type=int, default=64, 51 | help='The number of channels in network') 52 | parser.add_argument('--res_scale', type=float, default=1., 53 | help='Residual scale') 54 | parser.add_argument('--cra', type=str2bool, default=True, 55 | help='Use cra module or not') 56 | parser.add_argument('--mrcf', type=str2bool, default=True, 57 | help='MRCF or SRCF') 58 | parser.add_argument('--y_only', type=str2bool, default=False, 59 | help='Output Y-channel only or RGB-channels') 60 | parser.add_argument('--hr_dcn', type=str2bool, default=True, 61 | help='Move on DCN to highest spatial dimension or not') 62 | parser.add_argument('--offset_prop', type=str2bool, default=True, 63 | help='Propagate offset feature among each DCN or not') 64 | 65 | ### loss setting 66 | parser.add_argument('--rec_w', type=float, default=1., 67 | help='The weight of reconstruction loss') 68 | 69 | ### optimizer setting 70 | parser.add_argument('--beta1', type=float, default=0.9, 71 | help='The beta1 in Adam optimizer') 72 | parser.add_argument('--beta2', type=float, default=0.999, 73 | help='The beta2 in Adam optimizer') 74 | parser.add_argument('--eps', type=float, default=1e-12, 75 | help='The eps in Adam optimizer') 76 | 
parser.add_argument('--lr_rate', type=float, default=1e-4, 77 | help='Learning rate') 78 | parser.add_argument('--lr_rate_flow', type=float, default=2.5e-5, 79 | help='Learning rate of optical flow estimator') 80 | parser.add_argument('--decay', type=float, default=999999, 81 | help='Learning rate decay type') 82 | parser.add_argument('--gamma', type=float, default=0.5, 83 | help='Learning rate decay factor for step decay') 84 | 85 | ### training setting 86 | parser.add_argument('--batch_size', type=int, default=8, 87 | help='Training batch size') 88 | parser.add_argument('--GT_size', type=int, default=256, 89 | help='Training GT size') 90 | parser.add_argument('--FV_size', type=int, default=80, 91 | help='Training FV size') 92 | parser.add_argument('--scale', type=int, default=4, 93 | help='Training upsampling scale') 94 | parser.add_argument('--N_frames', type=int, default=15, 95 | help='Training number of frames') 96 | parser.add_argument('--train_crop_size', type=int, default=40, 97 | help='Training data crop size') 98 | parser.add_argument('--num_init_epochs', type=int, default=2, 99 | help='The number of init epochs which are trained with only reconstruction loss') 100 | parser.add_argument('--num_epochs', type=int, default=1, 101 | help='The number of training epochs') 102 | parser.add_argument('--print_every', type=int, default=1, 103 | help='Print period') 104 | parser.add_argument('--save_every', type=int, default=999999, 105 | help='Save period') 106 | parser.add_argument('--val_every', type=int, default=999999, 107 | help='Validation period') 108 | 109 | ### evaluate / test / finetune setting 110 | parser.add_argument('--eval', type=str2bool, default=False, 111 | help='Evaluation mode') 112 | parser.add_argument('--eval_save_results', type=str2bool, default=False, 113 | help='Save each image during evaluation') 114 | parser.add_argument('--model_path', type=str, default=None, 115 | help='The path of the model to evaluate') 116 | parser.add_argument('--test', type=str2bool, default=False, 117 | help='Test mode') 118 | 119 | args = parser.parse_args() 120 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/overview.png -------------------------------------------------------------------------------- /png2mp4.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import cv2 6 | 7 | 8 | if __name__ == '__main__': 9 | fourcc = cv2.VideoWriter_fourcc(*'MP4V') 10 | model_code = 15 11 | hr_dcn = True 12 | offset_prop = True 13 | split_ratio = 3 14 | model_name = 'FVSR_x8_simple_v{}_hrdcn_{}_offsetprop_{}_fnet{}_gaussian'.format(model_code, 'y' if hr_dcn else 'n', 15 | 'y' if offset_prop else 'n', 16 | '_{}outof4'.format(4-split_ratio) if model_code == 18 else '') 17 | # model_name = 'Bicubic' 18 | 19 | video_nums = [0, 11, 15, 20] 20 | for video_num in video_nums: 21 | sr_png_dir = 'test_png/eval_video/{}/{}/results'.format(model_name, video_num) 22 | gt_png_dir = 'test_png/eval_video/GroundTruth/{}/'.format(video_num) 23 | 24 | save_dir = 'test_video/{}/{}'.format(model_name, video_num) 25 | if not os.path.exists(save_dir): 26 | os.makedirs(save_dir) 27 | 28 | gt_imgs = [] 29 | sr_imgs = [] 30 | files = os.listdir(sr_png_dir) 31 | files =
sorted(files) 32 | for f in files: 33 | if '.gif' in f: 34 | continue 35 | img = cv2.imread(os.path.join(sr_png_dir, f)) 36 | sr_imgs.append(img) 37 | files = os.listdir(gt_png_dir) 38 | files = sorted(files) 39 | for f in files: 40 | img = cv2.imread(os.path.join(gt_png_dir, f)) 41 | gt_imgs.append(img) 42 | 43 | H, W, C = sr_imgs[0].shape 44 | out = cv2.VideoWriter(os.path.join(save_dir, 'sr.mp4'), fourcc, 20.0, (W, H)) 45 | for i in range(len(sr_imgs)): 46 | out.write(sr_imgs[i]) 47 | 48 | H, W, C = gt_imgs[0].shape 49 | out = cv2.VideoWriter(os.path.join(save_dir, 'gt.mp4'), fourcc, 20.0, (W, H)) 50 | for i in range(len(gt_imgs)): 51 | out.write(gt_imgs[i]) -------------------------------------------------------------------------------- /pretrained_models/EGVSR_iter420000.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/pretrained_models/EGVSR_iter420000.pth -------------------------------------------------------------------------------- /pretrained_models/fnet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/pretrained_models/fnet.pth -------------------------------------------------------------------------------- /pretrained_models/spynet_20210409-c6c1bd09.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eugenelet/CRFP/c7d0b82735514ba182c14f188f29fdec390d6a6f/pretrained_models/spynet_20210409-c6c1bd09.pth -------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True, batch_avg = False): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if batch_avg and ssim_map.dim() == 4: 35 | B, C, H, W = ssim_map.size() 36 | return ssim_map.view(B, -1).mean(1) 37 | elif size_average: 38 | return ssim_map.mean() 39 | else: 40 | return ssim_map.mean(1).mean(1).mean(1) 41 | 42 | class SSIM(torch.nn.Module): 43 | def __init__(self, 
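`png2mp4.py` depends on `cv2.VideoWriter` receiving the frame size as `(W, H)` even though `img.shape` is `(H, W, C)`; a mismatch can silently produce an unplayable file. The pattern in isolation (paths are placeholders; note the script above never calls `release()`, which is worth adding when wrapping it):

```
import cv2

def frames_to_mp4(frames, out_path, fps=20.0):
    # frames: list of HxWx3 BGR uint8 arrays, as returned by cv2.imread.
    H, W, _ = frames[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
    out = cv2.VideoWriter(out_path, fourcc, fps, (W, H))  # (W, H) order
    for f in frames:
        out.write(f)
    out.release()  # flush and finalize the container
```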
window_size = 11, size_average = True): 44 | super(SSIM, self).__init__() 45 | self.window_size = window_size 46 | self.size_average = size_average 47 | self.channel = 1 48 | self.window = create_window(window_size, self.channel) 49 | 50 | def forward(self, img1, img2): 51 | (_, channel, _, _) = img1.size() 52 | 53 | if channel == self.channel and self.window.data.type() == img1.data.type(): 54 | window = self.window 55 | else: 56 | window = create_window(self.window_size, channel) 57 | 58 | if img1.is_cuda: 59 | window = window.cuda(img1.get_device()) 60 | window = window.type_as(img1) 61 | 62 | self.window = window 63 | self.channel = channel 64 | 65 | 66 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 67 | 68 | def ssim(img1, img2, window_size = 11, size_average = True, batch_avg = False): 69 | (_, channel, _, _) = img1.size() 70 | window = create_window(window_size, channel) 71 | 72 | if img1.is_cuda: 73 | window = window.cuda(img1.get_device()) 74 | window = window.type_as(img1) 75 | 76 | return _ssim(img1, img2, window, window_size, channel, size_average, batch_avg) 77 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | ### test 2 | 3 | ### test_video_with_reds 4 | python3 main.py --save_dir ./test/demo/FVSR_x4_mrcf/output/ \ 5 | --model_path ./train/model_00001_002900.pt \ 6 | --log_file_name test.log \ 7 | --reset True \ 8 | --test True \ 9 | --num_workers 1 \ 10 | --scale 4 \ 11 | --cra true \ 12 | --mrcf true \ 13 | --N_frames 15 \ 14 | --FV_size 96 \ 15 | --dataset Reds \ 16 | --dataset_dir /DATA/REDS_sharp/ \ 17 | --visdom_port 8803 \ 18 | --visdom_view FVSR_x4_mrcf 19 | -------------------------------------------------------------------------------- /test_img_coor.py: -------------------------------------------------------------------------------- 1 | # importing the module 2 | import cv2 3 | 4 | # function to display the coordinates of 5 | # of the points clicked on the image 6 | def click_event(event, x, y, flags, params): 7 | 8 | # checking for left mouse clicks 9 | if event == cv2.EVENT_LBUTTONDOWN: 10 | 11 | # displaying the coordinates 12 | # on the Shell 13 | print(x, ' ', y) 14 | 15 | # displaying the coordinates 16 | # on the image window 17 | font = cv2.FONT_HERSHEY_SIMPLEX 18 | cv2.putText(img, str(x) + ',' + 19 | str(y), (x,y), font, 20 | 1, (255, 0, 0), 2) 21 | cv2.imshow('image', img) 22 | 23 | # checking for right mouse clicks 24 | if event==cv2.EVENT_RBUTTONDOWN: 25 | 26 | # displaying the coordinates 27 | # on the Shell 28 | print(x, ' ', y) 29 | 30 | # displaying the coordinates 31 | # on the image window 32 | font = cv2.FONT_HERSHEY_SIMPLEX 33 | b = img[y, x, 0] 34 | g = img[y, x, 1] 35 | r = img[y, x, 2] 36 | cv2.putText(img, str(b) + ',' + 37 | str(g) + ',' + str(r), 38 | (x,y), font, 1, 39 | (255, 255, 0), 2) 40 | cv2.imshow('image', img) 41 | 42 | # driver function 43 | if __name__=="__main__": 44 | 45 | # reading the image 46 | # img_path = '/DATA/REDS_sharp/train/train/train_sharp/011/00000036.png' 47 | # img_path = '/DATA/REDS_sharp/train/train/train_sharp/020/00000099.png' 48 | img_path = '/home/si2/TTSR/env/test/test_png/results/FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_cra/pastfv/075_results_1.png' 49 | img = cv2.imread(img_path, 1) 50 | 51 | # displaying the image 52 | cv2.imshow('image', img) 53 | 54 | # setting mouse handler for the image 55 | # and calling the click_event() 
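The `pytorch_ssim` module scores structural similarity with an 11×11 Gaussian window (σ = 1.5) and the usual C1/C2 stabilizers, which assume inputs in [0, 1]. A minimal usage sketch:

```
import torch
import pytorch_ssim

a = torch.rand(1, 3, 64, 64)
print(pytorch_ssim.ssim(a, a.clone()).item())  # ~1.0 for identical images
# batch_avg=True returns one score per image instead of one scalar overall
scores = pytorch_ssim.ssim(a.repeat(2, 1, 1, 1), torch.rand(2, 3, 64, 64), batch_avg=True)
print(scores.shape)  # torch.Size([2])
```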
function 56 | cv2.setMouseCallback('image', click_event) 57 | 58 | # wait for a key to be pressed to exit 59 | cv2.waitKey(0) 60 | 61 | # close the window 62 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /test_runtime.py: -------------------------------------------------------------------------------- 1 | from model import MRCF_runtime as MRCF 2 | 3 | import os 4 | import time 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from torch.profiler import profile, record_function, ProfilerActivity 9 | from pytorch_memlab import LineProfiler 10 | from pytorch_memlab import MemReporter 11 | from dcn_v2 import DCNv2 12 | 13 | def conv_identify(weight, bias): 14 | weight.data.zero_() 15 | bias.data.zero_() 16 | o, i, h, w = weight.shape 17 | y = h//2 18 | x = w//2 19 | for p in range(i): 20 | for q in range(o): 21 | if p == q: 22 | weight.data[q, p, y, x] = 1.0 23 | 24 | if __name__ == '__main__': 25 | device = torch.device('cuda') 26 | 27 | mid_channels = 32 28 | start = torch.cuda.Event(enable_timing=True) 29 | end = torch.cuda.Event(enable_timing=True) 30 | # print(torch.cuda.memory_allocated(0), '-3') 31 | # print(torch.cuda.max_memory_allocated(0), '-3') 32 | # torch.cuda.reset_max_memory_allocated(0) 33 | y_only = False 34 | hr_dcn = True 35 | offset_prop = True 36 | split_ratio = 3 37 | # model = MRCF.MRCF_simple_v0(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 38 | # model = MRCF.MRCF_simple_v13(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 39 | # model = MRCF.MRCF_simple_v13_nodcn(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 40 | # model = MRCF.MRCF_simple_v15(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 41 | model = MRCF.MRCF_simple_v18(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, split_ratio=split_ratio, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 42 | # model = MRCF.MRCF_simple_v18_nofv(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, split_ratio=split_ratio, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 43 | 44 | # model = MRCF.MRCF_simple(mid_channels=mid_channels, num_blocks=1, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device) 45 | # model = MRCF.MRCF_simple_v1_dcn2_v4_kai(mid_channels=32, y_only=y_only, spynet_pretrained='pretrained_models/250.pth', device=device).to(device) 46 | # model = MRCF.MRCF_simple_v1_dcn2_v4_kai(mid_channels=32, y_only=y_only, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device) 47 | # model = MRCF.MRCF_simple_v4(mid_channels=mid_channels, y_only=False, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device) 48 | # model = MRCF.MRCF_CRA_x8(mid_channels=mid_channels, num_blocks=1, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device) 49 | # model = MRCF.MRCF_CRA_x8_v1(mid_channels=mid_channels, num_blocks=1, spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device) 50 | # model_spy 
= model.spynet 51 | # model_en_hr = model.encoder_hr 52 | # model_en_lr = model.encoder_lr 53 | # model = nn.Upsample(scale_factor=8, mode='bicubic', align_corners=False) 54 | # for k, v in model.named_parameters(): 55 | # v.requires_grad_(False) 56 | 57 | # model = MRCF.ResidualBlocksWithInputConv(mid_channels * 2, mid_channels, 3).to(device) 58 | # model = nn.Sequential( 59 | # nn.Conv2d(mid_channels, mid_channels, 3, 1, 1), 60 | # nn.LeakyReLU(0.1, inplace=True), 61 | # nn.Conv2d(mid_channels, 3, 3, 1 ,1)).to(device) 62 | 63 | # group = 1 64 | # model = nn.Sequential( 65 | # nn.Conv2d(mid_channels*2+2, mid_channels, 3, 1, 1, bias=True), 66 | # nn.LeakyReLU(0.1, inplace=True), 67 | # nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True), 68 | # nn.LeakyReLU(0.1, inplace=True), 69 | # nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True), 70 | # nn.LeakyReLU(0.1, inplace=True) 71 | # ).to(device) 72 | # dcn_offset = nn.Conv2d(mid_channels, group*2*3*3, 3, 1, 1).to(device) 73 | # dcn_mask = nn.Conv2d(mid_channels, group*1*3*3, 3, 1, 1).to(device) 74 | # dcn = DCNv2(mid_channels, mid_channels, 3, stride=1, padding=1, dilation=1, deformable_groups=group).to(device) 75 | # dcn_offset.weight.data.zero_() 76 | # dcn_offset.bias.data.zero_() 77 | # dcn_mask.weight.data.zero_() 78 | # dcn_mask.bias.data.zero_() 79 | # conv_identify(dcn.weight, dcn.bias) 80 | 81 | scale = 1 82 | # HR_h = 720 83 | # HR_w = 1280 84 | HR_h = 1080 85 | HR_w = 1920 86 | # HR_h = 512 87 | # HR_w = 512 88 | LR_h = HR_h // 8 89 | LR_w = HR_w // 8 90 | FV_h = 96 91 | FV_w = 96 92 | WP_h = 720 93 | WP_w = 720 94 | # WP_h = 1080 95 | # WP_w = 1920 96 | 97 | t = 5 98 | repeat_time = 30 99 | warm_up = 10 100 | infer_time = 0 101 | 102 | model.eval() 103 | # dcn_offset.eval() 104 | # dcn_mask.eval() 105 | # dcn.eval() 106 | 107 | with torch.no_grad(): 108 | # print(torch.cuda.memory_allocated(0), '-2') 109 | # print(torch.cuda.max_memory_allocated(0), '-2') 110 | # torch.cuda.reset_max_memory_allocated(0) 111 | 112 | # x = torch.rand(1, mid_channels, HR_h, HR_w).cuda() 113 | # i = torch.rand(1, mid_channels * 2, HR_h//scale, HR_w//scale).cuda() 114 | 115 | # f = torch.rand(1, 2, HR_h//scale, HR_w//scale).cuda() 116 | # i = torch.rand(1, mid_channels*2+2, HR_h//scale, HR_w//scale).cuda() 117 | 118 | # f = torch.rand(1, 2, HR_h//scale, HR_w//scale).cuda() 119 | # x = torch.rand(1, mid_channels, HR_h//scale, HR_w//scale).cuda() 120 | 121 | # i = torch.rand(1, 3, HR_h//scale, HR_w//scale).cuda() 122 | # f = model(i) 123 | # o = dcn_offset(f) 124 | # o = 10. 
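The commented-out DCN setup above zero-initializes the offset and mask heads and calls `conv_identify` so the deformable convolution starts out as a (mask-scaled) identity over the input features. The effect of `conv_identify` on an ordinary convolution, as a sketch (assumes `conv_identify` from this file is in scope):

```
import torch
import torch.nn as nn

conv = nn.Conv2d(32, 32, 3, padding=1)
conv_identify(conv.weight, conv.bias)  # 1.0 at each channel's center tap
x = torch.rand(1, 32, 8, 8)
assert torch.allclose(conv(x), x, atol=1e-6)  # identity until training updates it
```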
* torch.tanh(o) 125 | # m = dcn_mask(f) 126 | # m = torch.sigmoid(m) 127 | 128 | lr = torch.rand(1, t, 3, LR_h, LR_w).cuda() 129 | fv = torch.rand(1, t, 3, FV_h, FV_w).cuda() 130 | # mk = torch.ones(1, t, 1, HR_h, HR_w).cuda() 131 | 132 | # ref = torch.rand(1, 3, LR_h, LR_w).cuda() 133 | # sup = torch.rand(1, 3, LR_h, LR_w).cuda() 134 | 135 | # x_lr = torch.rand(1, 3, LR_h, LR_w).cuda() 136 | # x_fv = torch.rand(1, 6, FV_h, FV_w).cuda() 137 | 138 | # print(torch.cuda.memory_allocated(0), '-1') 139 | # print(torch.cuda.max_memory_allocated(0), '-1') 140 | # torch.cuda.reset_max_memory_allocated(0) 141 | 142 | for idx in range(repeat_time): 143 | if idx < warm_up: 144 | infer_time = 0 145 | torch.cuda.synchronize() 146 | # start_time = time.time() 147 | start.record() 148 | 149 | y = model(lr, fv, warp_size=(WP_h, WP_w)) 150 | # y = model(lr, fv) 151 | # y = model(lr, fv, mk) 152 | # y = model_spy(ref, sup) 153 | 154 | # y0, y1, y2 = model_en_lr(x_lr, islr=True) 155 | # y0, y1, y2 = model_en_hr(x_fv, islr=True) 156 | 157 | # y = MRCF.flow_warp(x, f.permute(0, 2, 3, 1)) 158 | 159 | # y = dcn(f, o, m) 160 | 161 | # y = model(i) 162 | # o = dcn_offset(y) 163 | # o = 10. * torch.tanh(o) 164 | # f = torch.cat((f[:, 1:2, :, :], f[:, 0:1, :, :]), dim=1) 165 | # f = f.repeat(1, o.size(1) // 2, 1, 1) 166 | # o = o + f 167 | # m = dcn_mask(y) 168 | # m = torch.sigmoid(m) 169 | 170 | # y = model(f) 171 | # y = model(x) 172 | # y = model(x_lr) 173 | 174 | # print(torch.cuda.memory_allocated(0), '0') 175 | # print(torch.cuda.max_memory_allocated(0), '0') 176 | # torch.cuda.reset_max_memory_allocated(0) 177 | end.record() 178 | torch.cuda.synchronize() 179 | # infer_time += (time.time() - start_time) 180 | infer_time += (start.elapsed_time(end)/1000) 181 | 182 | # with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof: 183 | # with record_function("model_inference"): 184 | # y = model(lr, fv, mk) 185 | 186 | print(y.shape, infer_time / (repeat_time - warm_up + 1) / t) 187 | # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=100)) 188 | # prof.export_chrome_trace('./mrcf_profile.json') 189 | # reporter.report(verbose=True) -------------------------------------------------------------------------------- /test_video.py: -------------------------------------------------------------------------------- 1 | from torch import save 2 | from model import MRCF_test as MRCF 3 | from model import LTE 4 | from utils import flow_to_color 5 | from dataset import dataloader 6 | from utils import calc_psnr_and_ssim_cuda, bgr2ycbcr 7 | 8 | import os 9 | import numpy as np 10 | import PIL 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.utils.data import DataLoader 15 | import visdom 16 | from imageio import imread, imsave, get_writer 17 | from PIL import Image 18 | import cv2 19 | # from ptflops import get_model_complexity_info 20 | import warnings 21 | warnings.filterwarnings("ignore", category=UserWarning) 22 | 23 | def foveated_metric(LR, LR_fv, HR, mn, hw, crop, kernel_size, stride_size, eval_mode=False): 24 | m, n = mn 25 | h, w = hw 26 | crop_h, crop_w = crop 27 | HR_fold = F.unfold(HR.unsqueeze(0), kernel_size=(kernel_size, kernel_size), stride=stride_size) # [N, 3*11*11, Hr*Wr] 28 | LR_fv_fold = F.unfold(LR_fv.unsqueeze(0), kernel_size=(kernel_size, kernel_size), stride=stride_size) # [N, 3*11*11, Hr*Wr] 29 | B, C, N = HR_fold.size() 30 | HR_fold = HR_fold.permute(0, 2, 1).view(B*N , 
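The timing loop above follows the standard CUDA measurement pattern: discard warm-up iterations, bracket the region with `torch.cuda.Event(enable_timing=True)` records, and `torch.cuda.synchronize()` before reading the elapsed time, because kernel launches are asynchronous. Its skeleton (the callable is a placeholder; requires a CUDA device):

```
import torch

def time_cuda(fn, warm_up=10, repeat=30):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    total = 0.0
    for idx in range(repeat):
        if idx < warm_up:
            total = 0.0  # keep resetting until warm-up is done
        torch.cuda.synchronize()
        start.record()
        fn()
        end.record()
        torch.cuda.synchronize()
        total += start.elapsed_time(end) / 1000  # ms -> s
    return total / (repeat - warm_up + 1)  # same divisor as above
```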
3, kernel_size, kernel_size) # [N, 3*11*11, Hr*Wr] 31 | LR_fv_fold = LR_fv_fold.permute(0, 2, 1).view(B*N , 3, kernel_size, kernel_size) 32 | Hr = (h - kernel_size) // stride_size + 1 33 | Wr = (w - kernel_size) // stride_size + 1 34 | 35 | B, C, H, W = HR_fold.size() 36 | mask = torch.ones((B, 1, H, W)).float() 37 | psnr_score, ssim_score = calc_psnr_and_ssim_cuda(HR_fold, LR_fv_fold, mask, is_tensor=False, batch_avg=True) 38 | psnr_score = psnr_score.view(Hr, Wr) 39 | ssim_score = ssim_score.view(Hr, Wr) 40 | 41 | # psnr_y_idx = (torch.argmax(psnr_score) // Wr) * stride_size 42 | # psnr_x_idx = (torch.argmax(psnr_score) % Wr) * stride_size 43 | # ssim_y_idx = (torch.argmax(ssim_score) // Wr) * stride_size 44 | # ssim_x_idx = (torch.argmax(ssim_score) % Wr) * stride_size 45 | if not eval_mode: 46 | HR[:, m:m+crop_h, n] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_h)) 47 | HR[:, m:m+crop_h, n+crop_w-1] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_h)) 48 | HR[:, m, n:n+crop_w] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_w)) 49 | HR[:, m+crop_h-1, n:n+crop_w] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_w)) 50 | 51 | LR_fv[:, m:m+crop_h, n] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_h)) 52 | LR_fv[:, m:m+crop_h, n+crop_w-1] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_h)) 53 | LR_fv[:, m, n:n+crop_w] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_w)) 54 | LR_fv[:, m+crop_h-1, n:n+crop_w] = torch.tensor([0., 0., 255.]).unsqueeze(1).repeat((1,crop_w)) 55 | 56 | psnr_min = psnr_score.min() 57 | psnr_max = psnr_score.max() 58 | ssim_min = ssim_score.min() 59 | ssim_max = ssim_score.max() 60 | # psnr_score = (psnr_score - psnr_min) / (psnr_max - psnr_min) 61 | # ssim_score = (ssim_score - ssim_min) / (ssim_max - ssim_min) 62 | psnr_score = psnr_score / 100 63 | ssim_score = (ssim_score.clip(0, 1) - 0.7) / 0.3 64 | 65 | # psnr_score_discrete = torch.zeros_like(psnr_score) 66 | # ssim_score_discrete = torch.zeros_like(ssim_score) 67 | 68 | # psnr_score_discrete[psnr_score <= 1.0] = 1.0 69 | # psnr_score_discrete[psnr_score <= 0.9] = 0.9 70 | # psnr_score_discrete[psnr_score <= 0.8] = 0.8 71 | # psnr_score_discrete[psnr_score <= 0.7] = 0.7 72 | # psnr_score_discrete[psnr_score <= 0.6] = 0.6 73 | # psnr_score_discrete[psnr_score <= 0.5] = 0.5 74 | # psnr_score_discrete[psnr_score <= 0.4] = 0.4 75 | # psnr_score_discrete[psnr_score <= 0.3] = 0.3 76 | # psnr_score_discrete[psnr_score <= 0.2] = 0.2 77 | # psnr_score_discrete[psnr_score <= 0.1] = 0.1 78 | 79 | # ssim_score_discrete[ssim_score <= 1.0] = 1.0 80 | # ssim_score_discrete[ssim_score <= 0.9] = 0.9 81 | # ssim_score_discrete[ssim_score <= 0.8] = 0.8 82 | # ssim_score_discrete[ssim_score <= 0.7] = 0.7 83 | # ssim_score_discrete[ssim_score <= 0.6] = 0.6 84 | # ssim_score_discrete[ssim_score <= 0.5] = 0.5 85 | # ssim_score_discrete[ssim_score <= 0.4] = 0.4 86 | # ssim_score_discrete[ssim_score <= 0.3] = 0.3 87 | # ssim_score_discrete[ssim_score <= 0.2] = 0.2 88 | # ssim_score_discrete[ssim_score <= 0.1] = 0.1 89 | 90 | # self.viz.viz.image(HR.cpu().numpy(), win='{}'.format('HR'), opts=dict(title='{}, Image size : {}'.format('HR', HR.size()))) 91 | # self.viz.viz.image(LR.cpu().numpy(), win='{}'.format('LR'), opts=dict(title='{}, Image size : {}'.format('LR', LR.size()))) 92 | # self.viz.viz.image(LR_fv.cpu().numpy(), win='{}'.format('FV'), opts=dict(title='{}, Image size : {}'.format('FV', LR_fv.size()))) 93 | # 
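`foveated_metric` turns PSNR/SSIM into dense quality maps by unfolding both images into overlapping patches, scoring all patches in one batched call, and reshaping the per-patch scores back to an `(Hr, Wr)` grid. The patch bookkeeping in isolation (a sketch; `calc_psnr_and_ssim_cuda` is the repo's helper in `utils.py`):

```
import torch
import torch.nn.functional as F

img = torch.rand(3, 360, 640)
k, s = 10, 5
patches = F.unfold(img.unsqueeze(0), kernel_size=(k, k), stride=s)  # [1, 3*k*k, N]
B, C, N = patches.size()
patches = patches.permute(0, 2, 1).view(B * N, 3, k, k)  # one patch per row
Hr = (360 - k) // s + 1
Wr = (640 - k) // s + 1
assert N == Hr * Wr  # per-patch scores reshape to an (Hr, Wr) map
```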
self.viz.viz.image(psnr_score.cpu().numpy(), win='{}'.format('PSNR_score'), opts=dict(title='{}, Image size : {}'.format('PSNR_score', psnr_score.size()))) 94 | # self.viz.viz.image(ssim_score.cpu().numpy(), win='{}'.format('SSIM_score'), opts=dict(title='{}, Image size : {}'.format('SSIM_score', ssim_score.size()))) 95 | # self.viz.viz.image(psnr_score_discrete.cpu().numpy(), win='{}'.format('PSNR_score_discrete'), opts=dict(title='{}, Image size : {}'.format('PSNR_score_discrete', psnr_score_discrete.size()))) 96 | # self.viz.viz.image(ssim_score_discrete.cpu().numpy(), win='{}'.format('SSIM_score_discrete'), opts=dict(title='{}, Image size : {}'.format('SSIM_score_discrete', ssim_score_discrete.size()))) 97 | 98 | return psnr_score, ssim_score, (psnr_min, psnr_max), (ssim_min, ssim_max) 99 | 100 | def rgb2yuv(rgb, y_only=True): 101 | # rgb_ = rgb.permute(0,2,3,1) 102 | # A = torch.tensor([[0.299, -0.14714119,0.61497538], 103 | # [0.587, -0.28886916, -0.51496512], 104 | # [0.114, 0.43601035, -0.10001026]]) 105 | # yuv = torch.tensordot(rgb_,A,1).transpose(0,2) 106 | r = rgb[:, 0, :, :] 107 | g = rgb[:, 1, :, :] 108 | b = rgb[:, 2, :, :] 109 | 110 | y = 0.299 * r + 0.587 * g + 0.114 * b 111 | u = -0.147 * r - 0.289 * g + 0.436 * b 112 | v = 0.615 * r - 0.515 * g - 0.100 * b 113 | yuv = torch.stack([y,u,v], dim=1) 114 | if y_only: 115 | return y.unsqueeze(1) 116 | else: 117 | return yuv 118 | 119 | def yuv2rgb(yuv): 120 | y = yuv[:, 0, :, :] 121 | u = yuv[:, 1, :, :] 122 | v = yuv[:, 2, :, :] 123 | 124 | r = y + 1.14 * v # coefficient for u is 0 125 | g = y + -0.396 * u - 0.581 * v 126 | b = y + 2.029 * u # coefficient for v is 0 127 | rgb = torch.stack([r,g,b], 1) 128 | 129 | return rgb 130 | 131 | if __name__ == '__main__': 132 | 133 | device = torch.device('cuda') 134 | viz = visdom.Visdom(server='140.113.212.214', port=8803, env='Gen_video') 135 | # fourcc = cv2.VideoWriter_fourcc('m','p','4','v') 136 | # out = cv2.VideoWriter('test_arcane_simple.mp4', fourcc, 30.0, (1080, 1920)) 137 | 138 | dataset_name = 'REDS' 139 | # dataset_name = 'old_tree' 140 | # dataset_name = 'arcane' 141 | regional_dcn = False 142 | eval_mode = True 143 | model_code = 15 144 | model_epoch = 99 145 | y_only = False 146 | hr_dcn = True 147 | offset_prop = True 148 | split_ratio = 3 149 | sigma = 50 150 | dcn_size = 720 151 | model_name = 'FVSR_x8_simple_v{}_hrdcn_{}_offsetprop_{}_fnet{}'.format(model_code, 'y' if hr_dcn else 'n', 152 | 'y' if offset_prop else 'n', 153 | '_{}outof4'.format(4-split_ratio) if model_code == 18 else '') 154 | print('Current model name: {}, Epoch: {}'.format(model_name, model_epoch)) 155 | video_num = [ 0, 11, 15, 20] 156 | # video_num = [ 0, 1, 6, 17] 157 | if eval_mode: 158 | fv_st_idx = [0, 0, 0, 0] 159 | else: 160 | fv_st_idx = [66, 30, 31, 0] 161 | # fv_st_idx = [100, 100, 100, 100] 162 | video_set = 'train' 163 | # video_set = 'val' 164 | model_path = 'train/REDS/{}/model/'.format(model_name) 165 | model_saves = os.listdir(model_path) 166 | model_save = [v for v in model_saves if '{:05d}'.format(model_epoch) in v] 167 | assert len(model_save) == 1 168 | model_save = model_save[0] 169 | model_name += '_gaussian' 170 | if eval_mode: 171 | save_dir = 'test_png/eval_video/' 172 | else: 173 | save_dir = 'test_png/' 174 | if not os.path.exists(save_dir): 175 | os.makedirs(save_dir) 176 | if not os.path.exists(os.path.join(save_dir, model_name)): 177 | os.makedirs(os.path.join(save_dir, model_name)) 178 | 179 | # model = MRCF.MRCF_CRA_x8(mid_channels=64, num_blocks=1,
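`rgb2yuv`/`yuv2rgb` above hard-code rounded BT.601-style matrices (in the inverse, U contributes nothing to R and V contributes nothing to B). Because the coefficients are rounded to 2-3 decimals, the round trip is only approximate; a quick check, assuming the two functions above are in scope:

```
import torch

x = torch.rand(2, 3, 32, 32)
err = (yuv2rgb(rgb2yuv(x, y_only=False)) - x).abs().max()
print(err)  # small but nonzero, due to the rounded coefficients
```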
spynet_pretrained='pretrained_models/spynet_20210409-c6c1bd09.pth', device=device).to(device) 180 | if model_code == 13: 181 | model = MRCF.MRCF_simple_v13(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 182 | # model = MRCF.MRCF_simple_v13_nodcn(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 183 | elif model_code == 15: 184 | model = MRCF.MRCF_simple_v15(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 185 | elif model_code == 18: 186 | model = MRCF.MRCF_simple_v18(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, split_ratio=split_ratio, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 187 | # model = MRCF.MRCF_simple_v18_cra(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, split_ratio=split_ratio, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 188 | elif model_code == 0: 189 | model = MRCF.MRCF_simple_v0(mid_channels=32, y_only=y_only, hr_dcn=hr_dcn, offset_prop=offset_prop, spynet_pretrained='pretrained_models/fnet.pth', device=device).to(device) 190 | 191 | model_state_dict = model.state_dict() 192 | model_state_dict_save = {k.replace('basic_', 'basic_module.'):v for k,v in torch.load(os.path.join(model_path, model_save)).items() if k.replace('basic_', 'basic_module.') in model_state_dict} 193 | # model_state_dict_save = {k.replace('module.',''):v for k,v in torch.load(model_path).items() if k.replace('module.','') in model_state_dict} 194 | # for k in model_state_dict.keys(): 195 | # print(k) 196 | # print('-----') 197 | # for k in model_state_dict_save.keys(): 198 | # print(k) 199 | # print(model_save) 200 | # for k,v in torch.load(model_path).items(): 201 | # print(k) 202 | model_state_dict.update(model_state_dict_save) 203 | model.load_state_dict(model_state_dict, strict=True) 204 | 205 | psnr_whole_list = [] 206 | ssim_whole_list = [] 207 | psnr_outskirt_list = [] 208 | ssim_outskirt_list = [] 209 | psnr_past_list = [] 210 | ssim_past_list = [] 211 | psnr_fovea_list = [] 212 | ssim_fovea_list = [] 213 | 214 | for v_idx, v in enumerate(video_num): 215 | if dataset_name == 'REDS': 216 | GT_img_dir = '/DATA/REDS_sharp/{}/{}/{}_sharp/{:03d}/'.format(video_set, video_set, video_set, v) 217 | LR_img_dir = '/DATA/REDS_sharp_BI_x8/{}/{}/{}_sharp/{:03d}/'.format(video_set, video_set, video_set, v) 218 | else: 219 | GT_img_dir = '{}_x1'.format(dataset_name) 220 | LR_img_dir = '{}_x8'.format(dataset_name) 221 | print('Data location: {}'.format(GT_img_dir)) 222 | 223 | lr_frames = [] 224 | hr_frames = [] 225 | GT_imgs = [] 226 | LR_imgs = [] 227 | LRSR_imgs = [] 228 | GT_files = os.listdir(GT_img_dir) 229 | LR_files = os.listdir(LR_img_dir) 230 | GT_files = sorted(GT_files) 231 | LR_files = sorted(LR_files) 232 | for file in GT_files: 233 | img = cv2.imread(os.path.join(GT_img_dir, file)) 234 | GT_imgs.append(img[:1072, :1920, :]) 235 | # img = cv2.cvtColor(img.copy(), cv2.COLOR_RGB2BGR) 236 | # hr_frames.append(img) 237 | H_, W_, _ = GT_imgs[0].shape 238 | for file in LR_files: 239 | img = cv2.imread(os.path.join(LR_img_dir, file)) 240 | LR_imgs.append(img[:134, :240, :]) 241 | LRSR_imgs.append(np.array(PIL.Image.fromarray(img).resize((W_, H_), PIL.Image.BICUBIC))) 242 | # img = cv2.cvtColor(img.copy(), 
cv2.COLOR_RGB2BGR) 243 | # lr_frames.append(img) 244 | 245 | gen_frames = [] 246 | gt_frames = [] 247 | lr_frames = [] 248 | psnr_score_list = [] 249 | ssim_score_list = [] 250 | psnr_score_bicubic_list = [] 251 | ssim_score_bicubic_list = [] 252 | # fv_size = 144 253 | fv_size = 96 254 | dx = 0 255 | dy = 0 256 | psnr_min = 1000 257 | psnr_max = 0 258 | ssim_min = 1000 259 | ssim_max = 0 260 | 261 | if dataset_name == 'arcane': 262 | st_x = 760 263 | st_y = 300 264 | ed_x = 1160 265 | ed_y = 500 266 | else: 267 | st_x = 360 + dx 268 | st_y = 300 + dy 269 | ed_x = 720 + dx 270 | ed_y = 500 + dy 271 | 272 | cur_x = ed_x 273 | cur_y = ed_y 274 | step_x = 20 275 | step_y = 0 276 | n_frames = 100 277 | 278 | bd_length = 10 279 | rg_w = dcn_size 280 | rg_h = dcn_size 281 | 282 | GT_imgs = GT_imgs[:n_frames] 283 | LR_imgs = LR_imgs[:n_frames] 284 | LRSR_imgs = LRSR_imgs[:n_frames] 285 | # with get_writer(os.path.join(save_dir,'test_{}_{}_{:03}_{}_bicubic.gif'.format(model_name, dataset_name, model_epoch, int(y_only))), mode="I", fps=7) as writer: 286 | # for n in range(n_frames): 287 | # writer.append_data(LRSR_imgs[n][:,:,::-1]) 288 | # with get_writer(os.path.join(save_dir,'test_{}_{}_{:03}_{}_gt.gif'.format(model_name, dataset_name, model_epoch, int(y_only))), mode="I", fps=7) as writer: 289 | # for n in range(n_frames): 290 | # writer.append_data(GT_imgs[n][:,:,::-1]) 291 | 292 | GT_imgs = np.stack(GT_imgs, axis=0) 293 | LR_imgs = np.stack(LR_imgs, axis=0) 294 | LRSR_imgs = np.stack(LRSR_imgs, axis=0) 295 | GT_imgs = GT_imgs.astype(np.float32) / 255. 296 | LR_imgs = LR_imgs.astype(np.float32) / 255. 297 | LRSR_imgs = LRSR_imgs.astype(np.float32) / 255. 298 | GT_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(GT_imgs, (0, 3, 1, 2)))).float() 299 | LR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LR_imgs, (0, 3, 1, 2)))).float() 300 | LRSR_imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(LRSR_imgs, (0, 3, 1, 2)))).float() 301 | N, C, H, W = GT_imgs.size() 302 | 303 | kernel = np.array([ [1, 1, 1], 304 | [1, 1, 1], 305 | [1, 1, 1] ], dtype=np.float32) 306 | kernel_tensor = torch.Tensor(np.expand_dims(np.expand_dims(kernel, 0), 0)) # size: (1, 1, 3, 3) 307 | mk_list = [] 308 | mk_one = torch.ones((1, 1, H, W)).to(device) 309 | x_array = sigma * np.random.randn(N) + (W / 2) 310 | y_array = sigma * np.random.randn(N) + (H / 2) 311 | white_paper = np.ones((H,W,3), np.uint8) * 255 312 | traj_list = [] 313 | 314 | model.eval() 315 | with torch.no_grad(): 316 | for n in range(N): 317 | print(n, '\r', end='') 318 | lr = LR_imgs[n:n+1].unsqueeze(0).to(device) 319 | lrsr = LRSR_imgs[n:n+1].unsqueeze(0).to(device).clone() 320 | gt = GT_imgs[n:n+1].unsqueeze(0).to(device).clone() 321 | fv = torch.zeros_like(gt).to(device) 322 | mk = torch.zeros((1, 1, 1, H, W)).to(device) 323 | fg = torch.zeros((1, 1, 1, H, W)).to(device) 324 | 325 | #### Raster scan 326 | # N_H = H // fv_size 327 | # N_W = W // fv_size 328 | # SP_H = H / N_H 329 | # SP_W = W / N_W 330 | # fv_sp = [] 331 | # x_i = n % N_W 332 | # y_i = (n // N_W) % N_H 333 | # cur_y = int((1+y_i)*SP_H - (SP_H + fv_size)//2) 334 | # cur_x = int((1+x_i)*SP_W - (SP_W + fv_size)//2) 335 | 336 | #### Gaussian span 337 | cur_y = int(y_array[n]) - fv_size//2 338 | cur_x = int(x_array[n]) - fv_size//2 339 | traj_list.append((cur_y, cur_x)) 340 | 341 | if n >= fv_st_idx[v_idx]: 342 | fv[:, :, :, cur_y:cur_y+fv_size, cur_x:cur_x+fv_size] = gt[:, :, :, cur_y:cur_y+fv_size, cur_x:cur_x+fv_size] 343 | mk[:, :, :, cur_y:cur_y+fv_size, 
cur_x:cur_x+fv_size] = 1 344 | 345 | mk_fv = mk.clone() 346 | mk_fv[:, :, :, cur_y:cur_y+fv_size, cur_x:cur_x+fv_size] = 1 347 | mk_out = mk_fv.clone().squeeze(0) 348 | for _ in range(10): 349 | mk_out = torch.clamp(F.conv2d(mk_out, kernel_tensor.to(mk_out.device), padding=(1, 1)), 0, 1) 350 | mk_out = torch.logical_and(torch.logical_not(mk), mk_out) 351 | 352 | st_rg_x = max(cur_x+(fv_size//2)-(rg_w//2), 0) 353 | ed_rg_x = min(cur_x+(fv_size//2)+(rg_w//2), 1920) 354 | st_rg_y = max(cur_y+(fv_size//2)-(rg_h//2), 0) 355 | ed_rg_y = min(cur_y+(fv_size//2)+(rg_h//2), 1080) 356 | if regional_dcn: 357 | fg[:, :, :, st_rg_y:ed_rg_y, st_rg_x:ed_rg_x] = 1 358 | else: 359 | fg = torch.ones((1, 1, 1, H, W)).to(device) 360 | 361 | sr = model(lrs=lr, fvs=fv, mks=mk, fgs=fg) 362 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.squeeze(0), gt.squeeze(0), mk_one) 363 | psnr_whole_list.append(psnr) 364 | ssim_whole_list.append(ssim) 365 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.squeeze(0), gt.squeeze(0), mk_fv) 366 | psnr_fovea_list.append(psnr) 367 | ssim_fovea_list.append(ssim) 368 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.squeeze(0), gt.squeeze(0), mk_out) 369 | psnr_outskirt_list.append(psnr) 370 | ssim_outskirt_list.append(ssim) 371 | if n > 0: 372 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.squeeze(0), gt.squeeze(0), mk_past) 373 | psnr_past_list.append(psnr) 374 | ssim_past_list.append(ssim) 375 | 376 | mk_list.append(mk_out.squeeze(0)) 377 | if len(mk_list) > 3: 378 | mk_list.pop(0) 379 | mk_past = torch.sum(torch.cat(mk_list, dim=1), dim=1, keepdim=True).clip(0, 1) 380 | 381 | psnr_score, ssim_score, psnr, ssim = foveated_metric(lr[0,0], sr[0,0], gt[0,0].clone(), (cur_y, cur_x), (H, W), (fv_size, fv_size), kernel_size=10, stride_size=5, eval_mode=eval_mode) 382 | psnr_score_list.append((psnr_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy().astype(np.uint8)) 383 | ssim_score_list.append((ssim_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy().astype(np.uint8)) 384 | psnr_score, ssim_score, psnr, ssim = foveated_metric(lr[0,0], lrsr[0,0], gt[0,0], (cur_y, cur_x), (H, W), (fv_size, fv_size), kernel_size=10, stride_size=5, eval_mode=eval_mode) 385 | if psnr[0] < psnr_min: 386 | psnr_min = psnr[0] 387 | if psnr[1] > psnr_max: 388 | psnr_max = psnr[1] 389 | if ssim[0] < ssim_min: 390 | ssim_min = ssim[0] 391 | if ssim[1] > ssim_max: 392 | ssim_max = ssim[1] 393 | psnr_score_bicubic_list.append((psnr_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy().astype(np.uint8)) 394 | ssim_score_bicubic_list.append((ssim_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy().astype(np.uint8)) 395 | 396 | if y_only: 397 | B, N, C, H, W = lrsr.size() 398 | lrsr = lrsr.view(B*N, C, H, W) 399 | B, N, C, H, W = sr.size() 400 | sr = sr.view(B*N, C, H, W) 401 | lrsr = rgb2yuv(lrsr, y_only=False) 402 | sr = yuv2rgb(torch.cat((sr[:,0:1,:,:], lrsr[:,1:3,:,:]), dim=1)) 403 | 404 | sr = (sr * 255.).clip(0., 255.) 405 | lr = (lrsr * 255.).clip(0., 255.) 406 | gt = (gt * 255.).clip(0., 255.) 
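# NOTE: the y_only branch above keeps only the network's Y (luma) output and
# borrows the chroma channels from the bicubic upsample (lrsr) before mapping
# back to RGB; the lines below then squeeze the batch/frame dims, round to
# uint8, reorder CHW -> HWC, and swap the channel order with cv2.cvtColor for
# OpenCV-based saving.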
407 | sr = np.transpose(sr.squeeze().clone().detach().round().cpu().numpy().astype(np.uint8), (1, 2, 0)) 408 | lr = np.transpose(lr.squeeze().clone().detach().round().cpu().numpy().astype(np.uint8), (1, 2, 0)) 409 | gt = np.transpose(gt.squeeze().clone().detach().round().cpu().numpy().astype(np.uint8), (1, 2, 0)) 410 | sr = cv2.cvtColor(sr, cv2.COLOR_RGB2BGR) 411 | H, W, C = sr.shape 412 | # lr = cv2.resize(cv2.cvtColor(lr, cv2.COLOR_RGB2BGR), (W, H), cv2.INTER_CUBIC) 413 | lr = cv2.cvtColor(lr, cv2.COLOR_RGB2BGR) 414 | gt = cv2.cvtColor(gt, cv2.COLOR_RGB2BGR) 415 | 416 | sr_copy = sr.copy() 417 | # cv2.rectangle(sr, (cur_x, cur_y), (cur_x+fv_size, cur_y+fv_size), (51, 51, 255), 3) 418 | # cv2.rectangle(sr, (st_rg_x, st_rg_y), (ed_rg_x, ed_rg_y), (255, 51, 51), 3) 419 | # cv2.line(sr, (cur_x+fv_size//2-5, cur_y+fv_size//2), (cur_x+fv_size//2+5, cur_y+fv_size//2), (51, 51, 255), 3) 420 | # cv2.line(sr, (cur_x+fv_size//2, cur_y+fv_size//2-5), (cur_x+fv_size//2, cur_y+fv_size//2+5), (51, 51, 255), 3) 421 | 422 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x, :] 423 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x-1, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x-1, :] 424 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+1, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+1, :] 425 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x-2, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x-2, :] 426 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+2, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+2, :] 427 | 428 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size, :] 429 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size-1, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size-1, :] 430 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size+1, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size+1, :] 431 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size-2, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size-2, :] 432 | # sr[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size+2, :] = sr_copy[cur_y+bd_length:cur_y+fv_size-bd_length, cur_x+fv_size+2, :] 433 | 434 | # sr[cur_y, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y, cur_x+bd_length:cur_x+fv_size-bd_length, :] 435 | # sr[cur_y-1, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y-1, cur_x+bd_length:cur_x+fv_size-bd_length, :] 436 | # sr[cur_y+1, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+1, cur_x+bd_length:cur_x+fv_size-bd_length, :] 437 | # sr[cur_y-2, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y-2, cur_x+bd_length:cur_x+fv_size-bd_length, :] 438 | # sr[cur_y+2, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+2, cur_x+bd_length:cur_x+fv_size-bd_length, :] 439 | 440 | # sr[cur_y+fv_size, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size, cur_x+bd_length:cur_x+fv_size-bd_length, :] 441 | # sr[cur_y+fv_size-1, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size-1, cur_x+bd_length:cur_x+fv_size-bd_length, :] 442 | # sr[cur_y+fv_size+1, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size+1, cur_x+bd_length:cur_x+fv_size-bd_length, :] 443 | # sr[cur_y+fv_size-2, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size-2, 
cur_x+bd_length:cur_x+fv_size-bd_length, :] 444 | # sr[cur_y+fv_size+2, cur_x+bd_length:cur_x+fv_size-bd_length, :] = sr_copy[cur_y+fv_size+2, cur_x+bd_length:cur_x+fv_size-bd_length, :] 445 | 446 | # sr[st_rg_y+bd_length:ed_rg_y-bd_length, st_rg_x-2:st_rg_x+3, :] = sr_copy[st_rg_y+bd_length:ed_rg_y-bd_length, st_rg_x-2:st_rg_x+3, :] 447 | # sr[st_rg_y+bd_length:ed_rg_y-bd_length, ed_rg_x-2:ed_rg_x+3, :] = sr_copy[st_rg_y+bd_length:ed_rg_y-bd_length, ed_rg_x-2:ed_rg_x+3, :] 448 | # sr[st_rg_y-2:st_rg_y+3, st_rg_x+bd_length:ed_rg_x-bd_length, :] = sr_copy[st_rg_y-2:st_rg_y+3, st_rg_x+bd_length:ed_rg_x-bd_length, :] 449 | # sr[ed_rg_y-2:ed_rg_y+3, st_rg_x+bd_length:ed_rg_x-bd_length, :] = sr_copy[ed_rg_y-2:ed_rg_y+3, st_rg_x+bd_length:ed_rg_x-bd_length, :] 450 | # cv2.rectangle(sr, (0, 100), (0+fv_size, 100+fv_size), (51, 51, 255), 3) 451 | # sr = cv2.cvtColor(sr, cv2.COLOR_BGR2RGB) 452 | gen_frames.append(sr.copy()) 453 | lr_frames.append(lr.copy()) 454 | gt_frames.append(gt.copy()) 455 | cur_x += step_x 456 | cur_y += step_y 457 | # viz.image(sr.transpose(2, 0, 1), win='{}'.format('sr'), opts=dict(title='{}, Image size : {}'.format('sr', sr.shape))) 458 | # viz.image(lr.transpose(2, 0, 1), win='{}'.format('lr'), opts=dict(title='{}, Image size : {}'.format('lr', lr.shape))) 459 | # viz.image(gt.transpose(2, 0, 1), win='{}'.format('gt'), opts=dict(title='{}, Image size : {}'.format('gt', gt.shape))) 460 | 461 | if cur_x >= ed_x and cur_y <= st_y: 462 | step_x = 0 463 | step_y = 20 464 | elif cur_x >= ed_x and cur_y >= ed_y: 465 | step_x = -20 466 | step_y = 0 467 | elif cur_x <= st_x and cur_y >= ed_y: 468 | step_x = 0 469 | step_y = -20 470 | elif cur_x <= st_x and cur_y <= st_y: 471 | step_x = 20 472 | step_y = 0 473 | 474 | # for (idy, idx) in traj_list: 475 | # white_paper = cv2.circle(white_paper, (idx,idy), radius=10, color=(0, 0, 255), thickness=-1) 476 | 477 | model.clear_states() 478 | if dataset_name == 'REDS': 479 | #### Reconstructed results 480 | if not os.path.exists(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'results')): 481 | os.makedirs(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'results')) 482 | for i in range(len(gen_frames)): 483 | cv2.imwrite(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'results', '{:03d}.png'.format(i)), gen_frames[i][:,:,::-1]) 484 | with get_writer(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'results', 'results.gif'), mode="I", fps=7) as writer: 485 | for n in range(len(gen_frames)): 486 | writer.append_data(gen_frames[n]) 487 | if not os.path.exists(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'psnr')): 488 | os.makedirs(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'psnr')) 489 | for i in range(len(gen_frames)): 490 | cv2.imwrite(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'psnr', '{:03d}.png'.format(i)), psnr_score_list[i]) 491 | if not os.path.exists(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'ssim')): 492 | os.makedirs(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'ssim')) 493 | for i in range(len(gen_frames)): 494 | cv2.imwrite(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'ssim', '{:03d}.png'.format(i)), ssim_score_list[i]) 495 | # cv2.imwrite(os.path.join(save_dir, model_name, str(video_num[v_idx]), 'traj.png'), white_paper) 496 | #### Bicubic upsample results 497 | if not os.path.exists(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]))): 498 | os.makedirs(os.path.join(save_dir, 'Bicubic', 
str(video_num[v_idx]), 'results')) 499 | os.makedirs(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'psnr')) 500 | os.makedirs(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'ssim')) 501 | for i in range(len(lr_frames)): 502 | cv2.imwrite(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'results', '{:03d}.png'.format(i)), lr_frames[i][:,:,::-1]) 503 | for i in range(len(lr_frames)): 504 | cv2.imwrite(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'psnr', '{:03d}.png'.format(i)), psnr_score_bicubic_list[i]) 505 | for i in range(len(lr_frames)): 506 | cv2.imwrite(os.path.join(save_dir, 'Bicubic', str(video_num[v_idx]), 'ssim', '{:03d}.png'.format(i)), ssim_score_bicubic_list[i]) 507 | #### GroundTruth 508 | if not os.path.exists(os.path.join(save_dir, 'GroundTruth', str(video_num[v_idx]))): 509 | os.makedirs(os.path.join(save_dir, 'GroundTruth', str(video_num[v_idx]))) 510 | for i in range(len(lr_frames)): 511 | cv2.imwrite(os.path.join(save_dir, 'GroundTruth', str(video_num[v_idx]), '{:03d}.png'.format(i)), gt_frames[i][:,:,::-1]) 512 | else: 513 | if not os.path.exists(os.path.join(save_dir, model_name, dataset_name, 'results')): 514 | os.makedirs(os.path.join(save_dir, model_name, dataset_name, 'results')) 515 | for i in range(len(gen_frames)): 516 | cv2.imwrite(os.path.join(save_dir, model_name, dataset_name, 'results', '{:03d}.png'.format(i)), gen_frames[i][:,:,::-1]) 517 | if not os.path.exists(os.path.join(save_dir, model_name, dataset_name, 'psnr')): 518 | os.makedirs(os.path.join(save_dir, model_name, dataset_name, 'psnr')) 519 | for i in range(len(gen_frames)): 520 | cv2.imwrite(os.path.join(save_dir, model_name, dataset_name, 'psnr', '{:03d}.png'.format(i)), psnr_score_list[i]) 521 | if not os.path.exists(os.path.join(save_dir, model_name, dataset_name, 'ssim')): 522 | os.makedirs(os.path.join(save_dir, model_name, dataset_name, 'ssim')) 523 | for i in range(len(gen_frames)): 524 | cv2.imwrite(os.path.join(save_dir, model_name, dataset_name, 'ssim', '{:03d}.png'.format(i)), ssim_score_list[i]) 525 | # cv2.imwrite(os.path.join(save_dir, model_name, dataset_name, 'traj.png'), white_paper) 526 | break 527 | 528 | print('PSNR_MIN: {}, PSNR_MAX: {}'.format(psnr_min, psnr_max)) 529 | print('SSIM_MIN: {}, SSIM_MAX: {}'.format(ssim_min, ssim_max)) 530 | # with get_writer(os.path.join(save_dir,'test_{}_{}_{:03}_{}.gif'.format(model_name, dataset_name, model_epoch, int(y_only))), mode="I", fps=7) as writer: 531 | # for n in range(n_frames): 532 | # # out.write(gen_frames[n][:,:,::-1]) 533 | # writer.append_data(gen_frames[n]) 534 | 535 | # with get_writer('test_output_hr.gif', mode="I", fps=10) as writer: 536 | # for n in range(n_frames): 537 | # writer.append_data(hr_frames[n]) 538 | 539 | # with get_writer('test_output_lr.gif', mode="I", fps=10) as writer: 540 | # for n in range(n_frames): 541 | # writer.append_data(lr_frames[n]) 542 | 543 | print('PSNR_W: {}, SSIM_W: {}'.format(sum(psnr_whole_list)/len(psnr_whole_list), sum(ssim_whole_list)/len(ssim_whole_list))) 544 | print('PSNR_F: {}, SSIM_F: {}'.format(sum(psnr_fovea_list)/len(psnr_fovea_list), sum(ssim_fovea_list)/len(ssim_fovea_list))) 545 | print('PSNR_P: {}, SSIM_P: {}'.format(sum(psnr_past_list)/len(psnr_past_list), sum(ssim_past_list)/len(ssim_past_list))) 546 | print('PSNR_O: {}, SSIM_O: {}'.format(sum(psnr_outskirt_list)/len(psnr_outskirt_list), sum(ssim_outskirt_list)/len(ssim_outskirt_list))) 547 | -------------------------------------------------------------------------------- 
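The whole/fovea/outskirt/past numbers printed above come from masked PSNR and SSIM over four regions of each frame. A minimal sketch of the mask construction, mirroring the evaluation loop (the helper name `region_masks` and the toy sizes are introduced here for illustration only):

```
import torch
import torch.nn.functional as F

def region_masks(mk, iters=10):
    # mk: binary fovea mask of shape (1, 1, H, W).
    # Dilate the fovea `iters` times with a 3x3 all-ones kernel, then drop
    # the fovea itself, leaving the surrounding "outskirt" ring -- the same
    # construction performed with kernel_tensor in the loop above.
    kernel = torch.ones(1, 1, 3, 3)
    dilated = mk.clone()
    for _ in range(iters):
        dilated = torch.clamp(F.conv2d(dilated, kernel, padding=1), 0, 1)
    outskirt = torch.logical_and(torch.logical_not(mk.bool()), dilated.bool())
    return mk.bool(), outskirt

# Toy usage: a 96x96 fovea centred in a 1072x1920 frame.
H, W, fv = 1072, 1920, 96
mk = torch.zeros(1, 1, H, W)
mk[..., H//2 - fv//2:H//2 + fv//2, W//2 - fv//2:W//2 + fv//2] = 1
fovea, ring = region_masks(mk)
print(int(fovea.sum()), int(ring.sum()))
```

The "past" region is then the union of the last three outskirt masks clipped to [0, 1], so it measures quality where the fovea recently was.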
/test_video_quality.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # file_dir="FVSR_x8_simple_v0_hrdcn_n_offsetprop_n_fnet" 4 | # file_dir="FVSR_x8_simple_v13_hrdcn_n_offsetprop_n_fnet" 5 | # file_dir="FVSR_x8_simple_v13_hrdcn_n_offsetprop_n_fnet_nodcn" 6 | # file_dir="FVSR_x8_simple_v15_hrdcn_n_offsetprop_n_fnet" 7 | # file_dir="FVSR_x8_simple_v15_hrdcn_n_offsetprop_y_fnet" 8 | # file_dir="FVSR_x8_simple_v15_hrdcn_y_offsetprop_n_fnet" 9 | # file_dir="FVSR_x8_simple_v15_hrdcn_y_offsetprop_y_fnet" 10 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4" 11 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_nofv" 12 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_fast" 13 | file_dir="FVSR_x8_simple_v15_hrdcn_y_offsetprop_y_fnet_gaussian" 14 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_gaussian" 15 | # file_dir="FVSR_x8_simple_v18_hrdcn_y_offsetprop_y_fnet_1outof4_gaussian_regional" 16 | # file_dir="Bicubic" 17 | video_num=$1 18 | ffmpeg \ 19 | -r 24 -i test_video/$file_dir/$video_num/gt.mp4 \ 20 | -r 24 -i test_video/$file_dir/$video_num/sr.mp4 \ 21 | -lavfi "[0:v]setpts=PTS-STARTPTS[reference]; \ 22 | [1:v]scale=1280:720:flags=bicubic,setpts=PTS-STARTPTS[distorted]; \ 23 | [distorted][reference]libvmaf=log_fmt=xml:log_path=/dev/stdout:model_path=/home/si2/vmaf/model/vmaf_v0.6.1.json" \ 24 | -f null - > test_video/$file_dir/$video_num/eval.log 25 | echo $file_dir -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #### Train BasicVSR with Reds 2 | python3 main.py --save_dir ./train/REDS/FVSR_x8_simple_v1_dcn2_v11 \ 3 | --reset True \ 4 | --log_file_name train.log \ 5 | --num_gpu 4 \ 6 | --gpu_id 0 \ 7 | --num_workers 9 \ 8 | --dataset Reds \ 9 | --dataset_dir /DATA/REDS_sharp/ \ 10 | --model_path ./train/REDS/FVSR_x8_simple_v1_dcn2_v10/model/model_00002_005600.pt \ 11 | --n_feats 64 \ 12 | --lr_rate 2e-4 \ 13 | --lr_rate_flow 2.5e-5 \ 14 | --rec_w 1 \ 15 | --scale 8 \ 16 | --cra true \ 17 | --mrcf true \ 18 | --batch_size 8 \ 19 | --FV_size 128 \ 20 | --GT_size 256 \ 21 | --N_frames 15 \ 22 | --y_only false \ 23 | --num_init_epochs 2 \ 24 | --num_epochs 80 \ 25 | --print_every 200 \ 26 | --save_every 100 \ 27 | --val_every 1 \ 28 | --visdom_port 8803 \ 29 | --visdom_view 1227_FVSR_x8_simple_v1_dcn2_v11 30 | 31 | ### simple_v1 dk=1, fd=32-> 8, lv1, range=10, dcn_gp=16 32 | ### simple_v2 dk=3, fd=32-> 8, lv1, range=10, dcn_gp=16 33 | ### simple_v3 dk=3, fd=64->16, lv1, range=10, dcn_gp=16 34 | 35 | ### simple_v4 dk=1, fd=32-> 8, lv3, range=10, dcn_gp= 1 36 | ### simple_v5 dk=3, fd=32-> 8, lv3, range=10, dcn_gp= 1 37 | ### simple_v6 dk=3, fd=32-> 8, lv3, range=80, dcn_gp= 1 38 | ### simple_v7 dk=3, fd=32-> 8, lv3, range=80, dcn_gp= 4 39 | 40 | ### simple_v8 dk=3, fd=64->16, lv3, range=80, dcn_gp= 4 41 | ### simple_v9 dk=3, fd=64-> 8, lv3, range=80, dcn_gp= 4 42 | ### simple_v10 dk=3, fd=32->16, lv3, range=80, dcn_gp= 4 43 | ### simple_v11 dk=1, fd=32->16, lv3, range=80, dcn_gp=16 44 | 45 | ### simple_duf dk=1, fd=32-> 8, lv1, range=10, dcn_gp=16 46 | 47 | ### simple_v12 dk=1, fd=32->32, lv3, range=80, dcn_gp= 4 48 | ### simple_v13 dk=1, fd=32->32, lv3, range=80, dcn_gp= 4, dcn * 2, res * 2 49 | 50 | ### dcn3_v1 dcn * 3, dk=1 51 | ### dcn3_v2 dcn * 3, dk=3(offset using repeat) 52 | ### dcn3_v3 dcn * 3, dk=3(offset & mask using repeat) 53 | 54 
| ### dcn2_v1 dcn * 2, dk=1 55 | ### dcn2_v2 dcn * 2, dk=1, offset finetune(mean of generated offset) 56 | ### dcn2_v3 dcn * 2, dk=1, upsample * 2 57 | ### dcn2_v4 dcn * 2, dk=1, res_block * 2 58 | ### dcn2_v5 dcn * 2, dk=1, branch out 59 | ### dcn2_v6 dcn * 2, dk=1, deeper downsample layer(conv2d * 2) 60 | ### dcn2_v7 dcn * 2, dk=1, deeper downsample layer(conv2d * 4) 61 | ### dcn2_v8 dcn * 2, dk=1, v4 deeper downsample layer(PS * 2) 62 | ### dcn2_v9 dcn * 2, dk=1, channel dimension = 32 63 | ### dcn2_v10 dcn * 4, dk=1, channel dimension = 32 64 | ### dcn2_v11 pca , dk=1, channel dimension = 32 -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from utils import calc_psnr_and_ssim_cuda, bgr2ycbcr 2 | 3 | import os 4 | import numpy as np 5 | from imageio import imread, imsave, get_writer 6 | from PIL import Image 7 | import random 8 | import time 9 | from math import cos, pi 10 | from tqdm import tqdm 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | import torchvision.utils as utils 17 | import visdom 18 | 19 | def rgb2yuv(rgb, y_only=True): 20 | # rgb_ = rgb.permute(0,2,3,1) 21 | # A = torch.tensor([[0.299, -0.14714119,0.61497538], 22 | # [0.587, -0.28886916, -0.51496512], 23 | # [0.114, 0.43601035, -0.10001026]]) 24 | # yuv = torch.tensordot(rgb_,A,1).transpose(0,2) 25 | r = rgb[:, 0, :, :] 26 | g = rgb[:, 1, :, :] 27 | b = rgb[:, 2, :, :] 28 | 29 | y = 0.299 * r + 0.587 * g + 0.114 * b 30 | u = -0.147 * r - 0.289 * g + 0.436 * b 31 | v = 0.615 * r - 0.515 * g - 0.100 * b 32 | yuv = torch.stack([y,u,v], dim=1) 33 | if y_only: 34 | return y.unsqueeze(1) 35 | else: 36 | return yuv 37 | 38 | def yuv2rgb(yuv): 39 | y = yuv[:, 0, :, :] 40 | u = yuv[:, 1, :, :] 41 | v = yuv[:, 2, :, :] 42 | 43 | r = y + 1.14 * v # coefficient for u is 0 44 | g = y + -0.396 * u - 0.581 * v 45 | b = y + 2.029 * u # coefficient for v is 0 46 | rgb = torch.stack([r,g,b], 1) 47 | 48 | return rgb 49 | 50 | def get_position_from_periods(iteration, cumulative_periods): 51 | """Get the position from a period list. 52 | It will return the index of the right-closest number in the period list. 53 | For example, the cumulative_periods = [100, 200, 300, 400], 54 | if iteration == 50, return 0; 55 | if iteration == 210, return 2; 56 | if iteration == 300, return 3. 57 | Args: 58 | iteration (int): Current iteration. 59 | cumulative_periods (list[int]): Cumulative period list. 60 | Returns: 61 | int: The position of the right-closest number in the period list. 62 | """ 63 | for i, period in enumerate(cumulative_periods): 64 | if iteration < period: 65 | return i 66 | raise ValueError(f'Current iteration {iteration} exceeds ' 67 | f'cumulative_periods {cumulative_periods}') 68 | 69 | 70 | def annealing_cos(start, end, factor, weight=1): 71 | """Calculate annealing cos learning rate. 72 | Cosine anneal from `weight * start + (1 - weight) * end` to `end` as 73 | percentage goes from 0.0 to 1.0. 74 | Args: 75 | start (float): The starting learning rate of the cosine annealing. 76 | end (float): The ending learning rate of the cosine annealing. 77 | factor (float): The coefficient of `pi` when calculating the current 78 | percentage. Range from 0.0 to 1.0. 79 | weight (float, optional): The combination factor of `start` and `end` 80 | when calculating the actual starting learning rate. Default to 1.
81 | """ 82 | cos_out = cos(pi * factor) + 1 83 | return end + 0.5 * weight * (start - end) * cos_out 84 | 85 | class Visdom_exe(object): 86 | def __init__(self, port='8907', env='main'): 87 | self.port = port 88 | self.env = env 89 | self.viz = visdom.Visdom(server='140.113.212.214', port=port, env=env) 90 | 91 | def plot_metric(self, loss=[], psnr=[], ssim=[], psnr_cuda=[], ssim_cuda=[], psnr_y_cuda=[], ssim_y_cuda=[], phase='train'): 92 | if len(loss) != 0: 93 | self.viz.line(X=[*range(len(loss))], Y=loss, win='{}_LOSS'.format(phase), opts={'title':'{}_LOSS'.format(phase)}) 94 | if len(psnr) != 0: 95 | self.viz.line(X=[*range(len(psnr))], Y=psnr, win='{}_PSNR'.format(phase), opts={'title':'{}_PSNR'.format(phase)}) 96 | if len(ssim) != 0: 97 | self.viz.line(X=[*range(len(ssim))], Y=ssim, win='{}_SSIM'.format(phase), opts={'title':'{}_SSIM'.format(phase)}) 98 | if len(psnr_cuda) != 0: 99 | self.viz.line(X=[*range(len(psnr_cuda))], Y=psnr_cuda, win='{}_PSNR_cuda'.format(phase), opts={'title':'{}_PSNR_cuda'.format(phase)}) 100 | if len(ssim_cuda) != 0: 101 | self.viz.line(X=[*range(len(ssim_cuda))], Y=ssim_cuda, win='{}_SSIM_cuda'.format(phase), opts={'title':'{}_SSIM_cuda'.format(phase)}) 102 | if len(psnr_y_cuda) != 0: 103 | self.viz.line(X=[*range(len(psnr_y_cuda))], Y=psnr_y_cuda, win='{}_PSNR_y_cuda'.format(phase), opts={'title':'{}_PSNR_y_cuda'.format(phase)}) 104 | if len(ssim_y_cuda) != 0: 105 | self.viz.line(X=[*range(len(ssim_y_cuda))], Y=ssim_y_cuda, win='{}_SSIM_y_cuda'.format(phase), opts={'title':'{}_SSIM_y_cuda'.format(phase)}) 106 | 107 | class Trainer(): 108 | def __init__(self, args, logger, dataloader, model, loss_all): 109 | self.viz = Visdom_exe(args.visdom_port, args.visdom_view) 110 | self.args = args 111 | self.logger = logger 112 | self.dataloader = dataloader 113 | self.model = model 114 | self.loss_all = loss_all 115 | # self.device = torch.device('cpu') if args.cpu else torch.device('cuda:{}'.format(args.gpu_id)) 116 | self.device = torch.device('cpu') if args.cpu else torch.device('cuda') 117 | 118 | self.cur_clip = 0 119 | 120 | #### CosineAnnealing settings #### 121 | self.cur_iter = 0 122 | self.by_epoch = False 123 | self.periods = [600000] 124 | self.min_lr = 1e-7 125 | self.restart_weights = [1] 126 | self.cumulative_periods = [ 127 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 128 | ] 129 | #### CosineAnnealing settings #### 130 | 131 | self.params = [ 132 | {"params": [p for n, p in (self.model.named_parameters() if 133 | args.num_gpu==1 else self.model.module.named_parameters()) 134 | if ('spynet' not in n)], 135 | "lr": args.lr_rate 136 | }, 137 | {"params": [p for n, p in (self.model.named_parameters() if 138 | args.num_gpu==1 else self.model.module.named_parameters()) 139 | if ('spynet' in n)], 140 | "lr": args.lr_rate_flow 141 | } 142 | ] 143 | 144 | #### TTSR settings #### 145 | # self.optimizer = optim.Adam(self.params, lr=args.lr_rate, betas=(args.beta1, args.beta2), eps=args.eps) 146 | # self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=self.args.decay, gamma=self.args.gamma) 147 | 148 | #### BasicVSR settings #### 149 | self.optimizer = optim.Adam(self.params, lr=args.lr_rate, betas=(args.beta1, args.beta2), eps=args.eps) 150 | 151 | self.max_psnr = 0. 152 | self.max_psnr_epoch = 0 153 | self.max_ssim = 0. 154 | self.max_ssim_epoch = 0 155 | 156 | self.max_y_psnr = 0. 157 | self.max_y_psnr_epoch = 0 158 | self.max_y_ssim = 0. 
159 | self.max_y_ssim_epoch = 0 160 | 161 | self.train_loss_list = [] 162 | self.train_psnr_list = [] 163 | self.train_ssim_list = [] 164 | self.train_psnr_cuda_list = [] 165 | self.train_ssim_cuda_list = [] 166 | self.train_psnr_y_cuda_list = [] 167 | self.train_ssim_y_cuda_list = [] 168 | 169 | self.eval_loss_list = [] 170 | self.eval_psnr_list = [] 171 | self.eval_ssim_list = [] 172 | self.eval_psnr_cuda_list = [] 173 | self.eval_ssim_cuda_list = [] 174 | self.eval_psnr_y_cuda_list = [] 175 | self.eval_ssim_y_cuda_list = [] 176 | 177 | self.test_psnr = [] 178 | self.test_ssim = [] 179 | self.test_lr = [] 180 | self.test_hr = [] 181 | self.test_sr = [] 182 | 183 | self.print_network(self.model) 184 | 185 | def load(self, model_path=None): 186 | if (model_path): 187 | self.logger.info('load_model_path: ' + model_path) 188 | 189 | model_state_dict = self.model.state_dict() 190 | for k in model_state_dict.keys(): 191 | print(k) 192 | print('-----') 193 | model_state_dict_save = {k.replace('basic_', 'basic_module.'):v for k,v in torch.load(model_path).items() if k.replace('basic_', 'basic_module.') in model_state_dict} 194 | # model_state_dict_save = {k.replace('basic_','basic_module.'):v for k,v in torch.load(model_path).items()} 195 | # model_state_dict_save = {k:v for k,v in torch.load(model_path, map_location=self.device)['state_dict'].items()} 196 | for k in model_state_dict_save.keys(): 197 | print(k) 198 | model_state_dict.update(model_state_dict_save) 199 | self.model.load_state_dict(model_state_dict, strict=True) 200 | 201 | def prepare(self, sample_batched): 202 | for key in sample_batched.keys(): 203 | sample_batched[key] = sample_batched[key].to(self.device) 204 | return sample_batched 205 | 206 | def train_basicvsr(self, current_epoch=0): 207 | self.model.train() 208 | loss_list = [] 209 | psnr_cuda_list = [] 210 | ssim_cuda_list = [] 211 | psnr_y_cuda_list = [] 212 | ssim_y_cuda_list = [] 213 | loop = tqdm(enumerate(self.dataloader['train']), total=len(self.dataloader['train']), leave=False) 214 | for i_batch, sample_batched in loop: 215 | sample_batched = self.prepare(sample_batched) 216 | lr = sample_batched['LR'] 217 | lr_sr = sample_batched['LR_sr'] 218 | hr = sample_batched['HR'] 219 | ref = sample_batched['Ref'] 220 | # ref_sr = sample_batched['Ref_sr'] 221 | ref_sp = sample_batched['Ref_sp'] 222 | 223 | if self.cur_iter < 5000: 224 | for k, v in self.model.named_parameters(): 225 | if 'spynet' in k: 226 | v.requires_grad_(False) 227 | elif self.cur_iter == 5000: 228 | #### train all the parameters 229 | self.model.requires_grad_(True) 230 | 231 | self.before_train_iter() 232 | 233 | sr = self.model(lrs=lr, fvs=ref, mks=ref_sp) 234 | B, N, C, H, W = sr.size() 235 | sr = sr.view(B*N, C, H, W) 236 | B, N, C, H, W = hr.size() 237 | hr = hr.view(B*N, C, H, W) 238 | 239 | if self.args.y_only: 240 | B, N, C, H, W = lr_sr.size() 241 | lr_sr = lr_sr.view(B*N, C, H, W) 242 | lr_sr = rgb2yuv(lr_sr, y_only=False) 243 | sr = yuv2rgb(torch.cat((sr[:,0:1,:,:], lr_sr[:,1:3,:,:]), dim=1)) 244 | 245 | rec_loss = self.args.rec_w * self.loss_all['cb_loss'](sr, hr) 246 | loss = rec_loss 247 | 248 | self.optimizer.zero_grad() 249 | loss.backward() 250 | self.optimizer.step() 251 | loss_list.append(loss.cpu().item()) 252 | 253 | ### calculate psnr and ssim 254 | #### RGB domain #### 255 | # B, N, C, H, W = sr.size() 256 | mk_ = torch.ones((B*N, 1, H, W)).to(sr.device) 257 | psnr, ssim = calc_psnr_and_ssim_cuda(sr.detach(), hr.detach(), mk_) 258 | psnr_cuda_list.append(psnr.cpu().item()) 259 | 
ssim_cuda_list.append(ssim.cpu().item()) 260 | 261 | # #### YCbCr domain #### 262 | # B, N, C, H, W = sr.size() 263 | psnr, ssim = calc_psnr_and_ssim_cuda(bgr2ycbcr(sr.permute(0, 2, 3, 1).detach(), y_only=True), \ 264 | bgr2ycbcr(hr.permute(0, 2, 3, 1).detach(), y_only=True), mk_) 265 | # psnr, ssim = calc_psnr_and_ssim_cuda(sr.view(B*N, C, H, W).permute(0, 2, 3, 1).detach(), \ 266 | # hr.view(B*N, C, H, W).permute(0, 2, 3, 1).detach(), mk_) 267 | psnr_y_cuda_list.append(psnr.cpu().item()) 268 | ssim_y_cuda_list.append(ssim.cpu().item()) 269 | 270 | loop.set_description(f"Epoch[{current_epoch}/{self.args.num_epochs}](Train)") 271 | # loop.set_postfix(loss=loss.item(), psnr=psnr.item(), ssim=ssim.item()) 272 | loop.set_postfix(loss=loss.item(), lr=self.optimizer.param_groups[0]['lr'], lr_flow=self.optimizer.param_groups[1]['lr']) 273 | 274 | self.cur_iter += 1 275 | 276 | if self.cur_iter % self.args.save_every == 0: 277 | tmp = self.model.state_dict() 278 | model_state_dict = {key.replace('module.',''): tmp[key] for key in tmp} 279 | model_name = self.args.save_dir.strip('/')+'/model/model_{}_{}.pt'.format(str(current_epoch).zfill(5), str(self.cur_iter).zfill(6)) 280 | torch.save(model_state_dict, model_name) 281 | self.train_loss_list.append(sum(loss_list)/len(loss_list)) 282 | self.train_psnr_cuda_list.append(sum(psnr_cuda_list)/len(psnr_cuda_list)) 283 | self.train_ssim_cuda_list.append(sum(ssim_cuda_list)/len(ssim_cuda_list)) 284 | self.train_psnr_y_cuda_list.append(sum(psnr_y_cuda_list)/len(psnr_y_cuda_list)) 285 | self.train_ssim_y_cuda_list.append(sum(ssim_y_cuda_list)/len(ssim_y_cuda_list)) 286 | loss_list.clear() 287 | psnr_cuda_list.clear() 288 | ssim_cuda_list.clear() 289 | psnr_y_cuda_list.clear() 290 | ssim_y_cuda_list.clear() 291 | 292 | # if self.cur_iter % self.args.print_every == 0: 293 | # self.test_basicvsr(save_img=False) 294 | 295 | def eval_basicvsr(self, current_epoch=0): 296 | self.model.eval() 297 | psnr_cuda_list = [] 298 | ssim_cuda_list = [] 299 | psnr_y_cuda_list = [] 300 | ssim_y_cuda_list = [] 301 | # kernel = np.array([ [1, 1, 1], 302 | # [1, 1, 1], 303 | # [1, 1, 1] ], dtype=np.float32) 304 | # kernel_tensor = torch.Tensor(np.expand_dims(np.expand_dims(kernel, 0), 0)) # size: (1, 1, 3, 3) 305 | loop = tqdm(enumerate(self.dataloader['eval']), total=len(self.dataloader['eval']), leave=False) 306 | with torch.no_grad(): 307 | for i_batch, sample_batched in loop: 308 | sample_batched = self.prepare(sample_batched) 309 | lr = sample_batched['LR'] 310 | lr_sr = sample_batched['LR_sr'] 311 | hr = sample_batched['HR'] 312 | ref = sample_batched['Ref'] 313 | ref_sp = sample_batched['Ref_sp'] 314 | 315 | B, N, C, H, W = ref_sp.size() 316 | mk = torch.split(ref_sp.view(B*N, 1, H, W).clone().detach(), 1, dim=0) 317 | mk_bd = ref_sp.view(B*N, 1, H, W).float().clone().detach() 318 | sr = self.model(lrs=lr, fvs=ref, mks=ref_sp) 319 | # LR_fv = (sr * 255.).clip(0., 255.).squeeze().clone().detach() 320 | # for n in range(N): 321 | # self.viz.viz.image(LR_fv[n , :, :, :].cpu().numpy(), win='{}'.format('FV'), opts=dict(title='{}, Image size : {}'.format('FV', LR_fv.size()))) 322 | # sr = lr_sr 323 | 324 | ### calculate psnr and ssim 325 | B, N, C, H, W = sr.size() 326 | sr = sr.view(B*N, C, H, W) 327 | B, N, C, H, W = hr.size() 328 | hr = hr.view(B*N, C, H, W) 329 | 330 | if self.args.y_only: 331 | B, N, C, H, W = lr_sr.size() 332 | lr_sr = lr_sr.view(B*N, C, H, W) 333 | lr_sr = rgb2yuv(lr_sr, y_only=False) 334 | sr = yuv2rgb(torch.cat((sr[:,0:1,:,:], lr_sr[:,1:3,:,:]), 
dim=1)) 335 | 336 | # LR_fv = (sr * 255.).clip(0., 255.).squeeze().clone().detach() 337 | # for i in range(B*N): 338 | # self.viz.viz.image(LR_fv[i].cpu().numpy(), win='{}'.format('FV'), opts=dict(title='{}, Image size : {}'.format('FV', LR_fv.size()))) 339 | 340 | # mk = torch.split(ref_sp.view(B*N, 1, H, W), 1, dim=0) 341 | # mk_bd = ref_sp.view(B*N, 1, H, W).float() 342 | # for _ in range(10): 343 | # mk_bd = torch.clamp(F.conv2d(mk_bd, kernel_tensor.to(mk_bd.device), padding=(1, 1)), 0, 1) 344 | # mk_bd = torch.split(mk_bd, 1, dim=0) 345 | # past_len = 3 346 | # if i_batch % 50 == 0: 347 | # mk_pre = [mk_bd[0]] 348 | mk_ = torch.ones((B*N, 1, H, W)).to(sr.device) 349 | for idx in range(0, N): 350 | if idx == 0 and i_batch % 50 == 0: 351 | continue 352 | sr_ = sr[idx, :, :, :].unsqueeze(0) 353 | hr_ = hr[idx, :, :, :].unsqueeze(0) 354 | # mk_ = mk[idx] 355 | # mk_ = torch.logical_and(mk_, torch.logical_not(mk[idx])) 356 | # mk_ = torch.logical_or(mk_, mk[idx - 1]) 357 | # mk_ = torch.logical_and(torch.logical_not(mk[idx]), mk_bd[idx]) 358 | # mk_ = torch.sum(torch.cat(mk_pre, dim=1), dim=1, keepdim=True).clip(0, 1) 359 | # psnr, ssim = calc_psnr_and_ssim_cuda(sr.view(B*N, C, H, W).detach(), hr.view(B*N, C, H, W).detach()) 360 | # psnr, ssim = calc_psnr_and_ssim_cuda((sr_*mk_).detach(), (hr_*mk_).detach()) 361 | psnr, ssim = calc_psnr_and_ssim_cuda(sr_.detach(), hr_.detach(), mk_) 362 | psnr_cuda_list.append(psnr.cpu().item()) 363 | ssim_cuda_list.append(ssim.cpu().item()) 364 | psnr, ssim = calc_psnr_and_ssim_cuda(bgr2ycbcr(sr_.permute(0, 2, 3, 1).detach(), y_only=True), \ 365 | bgr2ycbcr(hr_.permute(0, 2, 3, 1).detach(), y_only=True), mk_) 366 | # psnr, ssim = calc_psnr_and_ssim_cuda(sr_.detach(), \ 367 | # hr_.detach(), mk_) 368 | psnr_y_cuda_list.append(psnr.cpu().item()) 369 | ssim_y_cuda_list.append(ssim.cpu().item()) 370 | 371 | # mk_pre.append(mk_bd[idx]) 372 | # if len(mk_pre) > past_len: 373 | # mk_pre.pop(0) 374 | 375 | psnr_ave = sum(psnr_cuda_list)/len(psnr_cuda_list) 376 | ssim_ave = sum(ssim_cuda_list)/len(ssim_cuda_list) 377 | psnr_ave_y = sum(psnr_y_cuda_list)/len(psnr_y_cuda_list) 378 | ssim_ave_y = sum(ssim_y_cuda_list)/len(ssim_y_cuda_list) 379 | loop.set_description(f"Epoch[{current_epoch}/{self.args.num_epochs}](Eval)") 380 | loop.set_postfix(psnr=psnr_ave, ssim=ssim_ave, psnr_y=psnr_ave_y, ssim_y=ssim_ave_y) 381 | 382 | psnr_ave = sum(psnr_cuda_list)/len(psnr_cuda_list) 383 | ssim_ave = sum(ssim_cuda_list)/len(ssim_cuda_list) 384 | self.eval_psnr_cuda_list.append(psnr_ave) 385 | self.eval_ssim_cuda_list.append(ssim_ave) 386 | self.logger.info('Ref PSNR (now): %.3f \t SSIM (now): %.4f' %(psnr_ave, ssim_ave)) 387 | if (psnr_ave > self.max_psnr): 388 | self.max_psnr = psnr_ave 389 | self.max_psnr_epoch = current_epoch 390 | if (ssim_ave > self.max_ssim): 391 | self.max_ssim = ssim_ave 392 | self.max_ssim_epoch = current_epoch 393 | self.logger.info('Ref PSNR (max): %.3f (%d) \t SSIM (max): %.4f (%d)' 394 | %(self.max_psnr, self.max_psnr_epoch, self.max_ssim, self.max_ssim_epoch)) 395 | 396 | psnr_y_ave = sum(psnr_y_cuda_list)/len(psnr_y_cuda_list) 397 | ssim_y_ave = sum(ssim_y_cuda_list)/len(ssim_y_cuda_list) 398 | self.eval_psnr_y_cuda_list.append(psnr_y_ave) 399 | self.eval_ssim_y_cuda_list.append(ssim_y_ave) 400 | self.logger.info('Ref PSNR_Y (now): %.3f \t SSIM_Y (now): %.4f' %(psnr_y_ave, ssim_y_ave)) 401 | if (psnr_y_ave > self.max_y_psnr): 402 | self.max_y_psnr = psnr_y_ave 403 | self.max_y_psnr_epoch = current_epoch 404 | if (ssim_y_ave > self.max_y_ssim): 405 
| self.max_y_ssim = ssim_y_ave 406 | self.max_y_ssim_epoch = current_epoch 407 | self.logger.info('Ref PSNR_Y (max): %.3f (%d) \t SSIM_Y (max): %.4f (%d)' 408 | %(self.max_y_psnr, self.max_y_psnr_epoch, self.max_y_ssim, self.max_y_ssim_epoch)) 409 | 410 | psnr_cuda_list.clear() 411 | ssim_cuda_list.clear() 412 | psnr_y_cuda_list.clear() 413 | ssim_y_cuda_list.clear() 414 | 415 | def test_basicvsr(self, save_img=True): 416 | if save_img: 417 | self.print_network(self.model) 418 | self.logger.info('Test process...') 419 | 420 | crop_h = self.args.FV_size 421 | crop_w = self.args.FV_size 422 | # kernel_size = self.args.FV_size 423 | kernel_size = 10 424 | stride_size = 5 425 | self.model.eval() 426 | with torch.no_grad(): 427 | for i_batch, sample_batched in enumerate(self.dataloader['test']): 428 | sample_batched = self.prepare(sample_batched) 429 | lr = sample_batched['LR'] 430 | lr_sr = sample_batched['LR_sr'] 431 | hr = sample_batched['HR'] 432 | ref = sample_batched['Ref'] 433 | ref_sp = sample_batched['Ref_sp'] 434 | fv_sp = sample_batched['FV_sp'] 435 | B, N, C, H, W = hr.size() 436 | 437 | sr = self.model(lrs=lr, fvs=ref, mks=ref_sp) 438 | if self.args.y_only: 439 | B, N, C, H, W = sr.size() 440 | sr = sr.view(B*N, C, H, W) 441 | B, N, C, H, W = lr_sr.size() 442 | lr_sr = lr_sr.view(B*N, C, H, W) 443 | lr_sr = rgb2yuv(lr_sr, y_only=False) 444 | sr = yuv2rgb(torch.cat((sr[:,0:1,:,:], lr_sr[:,1:3,:,:]), dim=1)) 445 | 446 | LR = (lr * 255.).squeeze().clone().detach() 447 | HR = (hr * 255.).squeeze().clone().detach() 448 | LR_fv = (sr * 255.).clip(0., 255.).squeeze().clone().detach() 449 | Ref_sp = fv_sp.squeeze().clone().detach() 450 | 451 | # direction = -1 452 | # scan_step = 8 453 | # accm_step = W - scan_step 454 | 455 | for n in range(N): 456 | psnr_score, ssim_score = self.foveated_metric(LR[n , :, :, :], LR_fv[n , :, :, :], HR[n , :, :, :], Ref_sp[n], (H, W), (crop_h, crop_w), kernel_size, stride_size) 457 | self.test_psnr.append((psnr_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy()) 458 | self.test_ssim.append((ssim_score.unsqueeze(2).repeat(1, 1, 3) * 255).round().cpu().detach().numpy()) 459 | # self.result_comp(LR_fv[n , :, :, :], accm_step) 460 | # accm_step += direction * scan_step 461 | # if accm_step < 0: 462 | # direction *= -1 463 | # accm_step += direction * scan_step 464 | # elif accm_step >= W: 465 | # direction *= -1 466 | # accm_step += direction * scan_step 467 | time.sleep(0.1) 468 | 469 | print('Process: {:.2f} % ...\r'.format(i_batch*100/len(self.dataloader['test'])), end='') 470 | 471 | for n in range(N): 472 | self.test_lr.append(LR[n, :, :, :].round().cpu().numpy()) 473 | self.test_hr.append(HR[n, :, :, :].round().cpu().numpy()) 474 | self.test_sr.append(LR_fv[n, :, :, :].round().cpu().numpy()) 475 | 476 | if save_img: 477 | if len(self.test_sr) == 100: 478 | N = 100 479 | save_path = os.path.join(self.args.save_dir, 'save_results', '{:05d}'.format(self.cur_clip)) 480 | if not os.path.isdir(save_path): 481 | os.mkdir(save_path) 482 | with get_writer(os.path.join(save_path, '{:05d}.gif'.format(i_batch)), mode="I", fps=5) as writer: 483 | for n in range(N): 484 | imsave(os.path.join(save_path, '{}_sr.png'.format(n)), np.transpose(self.test_sr[n], (1, 2, 0)).astype(np.uint8)) 485 | writer.append_data(np.transpose(self.test_sr[n], (1, 2, 0)).astype(np.uint8)) 486 | with get_writer(os.path.join(save_path, '{:05d}_gt.gif'.format(i_batch)), mode="I", fps=5) as writer: 487 | for n in range(N): 488 | imsave(os.path.join(save_path, 
'{}_hr.png'.format(n)), np.transpose(self.test_hr[n], (1, 2, 0)).astype(np.uint8)) 489 | writer.append_data(np.transpose(self.test_hr[n], (1, 2, 0)).astype(np.uint8)) 490 | with get_writer(os.path.join(save_path, '{:05d}_lr.gif'.format(i_batch)), mode="I", fps=5) as writer: 491 | for n in range(N): 492 | imsave(os.path.join(save_path, '{}_lr.png'.format(n)), np.transpose(self.test_lr[n], (1, 2, 0)).astype(np.uint8)) 493 | writer.append_data(np.transpose(self.test_lr[n], (1, 2, 0)).astype(np.uint8)) 494 | with get_writer(os.path.join(save_path, '{:05d}_psnr.gif'.format(i_batch)), mode="I", fps=5) as writer: 495 | for n in range(N): 496 | imsave(os.path.join(save_path, '{}_psnr.png'.format(n)), self.test_psnr[n].astype(np.uint8)) 497 | writer.append_data(self.test_psnr[n].astype(np.uint8)) 498 | with get_writer(os.path.join(save_path, '{:05d}_ssim.gif'.format(i_batch)), mode="I", fps=5) as writer: 499 | for n in range(N): 500 | imsave(os.path.join(save_path, '{}_ssim.png'.format(n)), self.test_ssim[n].astype(np.uint8)) 501 | writer.append_data(self.test_ssim[n].astype(np.uint8)) 502 | self.test_lr.clear() 503 | self.test_hr.clear() 504 | self.test_sr.clear() 505 | self.test_psnr.clear() 506 | self.test_ssim.clear() 507 | self.cur_clip += 1 508 | else: 509 | break 510 | 511 | if save_img: 512 | self.logger.info('Test over.') 513 | 514 | def test_basicvsr_scan(self, save_img=True): 515 | if save_img: 516 | self.print_network(self.model) 517 | self.logger.info('Test process...') 518 | self.logger.info('lr path: %s' %(self.args.lr_path)) 519 | self.logger.info('ref path: %s' %(self.args.ref_path)) 520 | 521 | self.model.eval() 522 | with torch.no_grad(): 523 | for i_batch, sample_batched in enumerate(self.dataloader['test']): 524 | sample_batched = self.prepare(sample_batched) 525 | lr = sample_batched['LR'] 526 | hr = sample_batched['HR'] 527 | ref = sample_batched['Ref'] 528 | ref_sp = sample_batched['Ref_sp'] 529 | fv_sp = sample_batched['FV_sp'] 530 | B, N, C, H, W = hr.size() 531 | 532 | sr = self.model(lrs=lr, fvs=ref, mks=ref_sp) 533 | #### Save images #### 534 | if save_img: 535 | sr_save = (sr * 255.).clip(0., 255.).squeeze() 536 | SR_t = sr 537 | sr_save = np.transpose(sr_save.round().cpu().detach().numpy(), (0, 2, 3, 1)).astype(np.uint8) 538 | for t in range(N): 539 | save_path = os.path.join(self.args.save_dir, 'save_results', '{:05d}'.format(i_batch)) 540 | if not os.path.isdir(save_path): 541 | os.mkdir(save_path) 542 | imsave(os.path.join(save_path, '{}.png'.format(t)), sr_save[t]) 543 | #### Save images #### 544 | LR = (lr * 255.).squeeze().clone().detach() 545 | HR = (hr * 255.).squeeze().clone().detach() 546 | LR_fv = (sr * 255.).clip(0., 255.).squeeze().clone().detach() 547 | Ref_sp = fv_sp.squeeze().clone().detach() 548 | for n in range(N): 549 | self.result_comp(LR_fv[n , :, :, :]) 550 | time.sleep(0.1) 551 | if save_img: 552 | with get_writer(os.path.join(save_path, '{:05d}.gif'.format(i_batch)), mode="I", fps=5) as writer: 553 | for n in range(N): 554 | writer.append_data(np.transpose(LR_fv[n].round().cpu().detach().numpy(), (1, 2, 0)).astype(np.uint8)) 555 | print('Process: {:.2f} % ...\r'.format(i_batch*100/len(self.dataloader['test'])), end='') 556 | else: 557 | break 558 | 559 | if save_img: 560 | self.logger.info('Test over.') 561 | 562 | def vis_plot_metric(self, phase): 563 | if phase == 'train': 564 | self.viz.plot_metric(loss=self.train_loss_list, psnr=self.train_psnr_list, ssim=self.train_ssim_list, \ 565 | psnr_cuda=self.train_psnr_cuda_list, 
ssim_cuda=self.train_ssim_cuda_list, \ 566 | psnr_y_cuda=self.train_psnr_y_cuda_list, ssim_y_cuda=self.train_ssim_y_cuda_list, \ 567 | phase='train') 568 | elif phase == 'eval': 569 | self.viz.plot_metric(loss=self.eval_loss_list, psnr=self.eval_psnr_list, ssim=self.eval_ssim_list, \ 570 | psnr_cuda=self.eval_psnr_cuda_list, ssim_cuda=self.eval_ssim_cuda_list, \ 571 | psnr_y_cuda=self.eval_psnr_y_cuda_list, ssim_y_cuda=self.eval_ssim_y_cuda_list, \ 572 | phase='eval') 573 | 574 | def _get_network_description(self, net): 575 | """Get the string and total parameters of the network""" 576 | if isinstance(net, nn.DataParallel): 577 | net = net.module 578 | return str(net), sum(map(lambda x: x.numel(), net.parameters())) 579 | 580 | def print_network(self, net): 581 | """Print the str and parameter number of a network. 582 | Args: 583 | net (nn.Module) 584 | """ 585 | net_str, net_params = self._get_network_description(net) 586 | if isinstance(net, nn.DataParallel): 587 | net_cls_str = (f'{net.__class__.__name__} - ' 588 | f'{net.module.__class__.__name__}') 589 | else: 590 | net_cls_str = f'{net.__class__.__name__}' 591 | 592 | self.logger.info( 593 | f'Network: {net_cls_str}, with parameters: {net_params:,d}') 594 | self.logger.info(net_str) 595 | 596 | def before_run(self): 597 | # NOTE: when resuming from a checkpoint, if 'initial_lr' is not saved, 598 | # it will be set according to the optimizer params 599 | for group in self.optimizer.param_groups: 600 | group.setdefault('initial_lr', group['lr']) 601 | self.base_lr = [ 602 | group['initial_lr'] for group in self.optimizer.param_groups 603 | ] 604 | 605 | def before_train_iter(self): 606 | self.regular_lr = self.get_regular_lr() 607 | self._set_lr(self.regular_lr) 608 | 609 | def get_regular_lr(self): 610 | return [self.get_lr(_base_lr) for _base_lr in self.base_lr] 611 | 612 | def get_lr(self, base_lr): 613 | progress = self.cur_iter 614 | target_lr = self.min_lr 615 | 616 | idx = get_position_from_periods(progress, self.cumulative_periods) 617 | current_weight = self.restart_weights[idx] 618 | nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1] 619 | current_periods = self.periods[idx] 620 | 621 | alpha = min((progress - nearest_restart) / current_periods, 1) 622 | return annealing_cos(base_lr, target_lr, alpha, current_weight) 623 | 624 | def _set_lr(self, lr_groups): 625 | for param_group, lr in zip(self.optimizer.param_groups, lr_groups): 626 | param_group['lr'] = lr 627 | 628 | def foveated_metric(self, LR, LR_fv, HR, mn, hw, crop, kernel_size, stride_size): 629 | m, n = mn 630 | h, w = hw 631 | crop_h, crop_w = crop 632 | HR_fold = F.unfold(HR.unsqueeze(0), kernel_size=(kernel_size, kernel_size), stride=stride_size) # [N, 3*11*11, Hr*Wr] 633 | LR_fv_fold = F.unfold(LR_fv.unsqueeze(0), kernel_size=(kernel_size, kernel_size), stride=stride_size) # [N, 3*11*11, Hr*Wr] 634 | B, C, N = HR_fold.size() 635 | HR_fold = HR_fold.permute(0, 2, 1).view(B*N , 3, kernel_size, kernel_size) # [N, 3*11*11, Hr*Wr] 636 | LR_fv_fold = LR_fv_fold.permute(0, 2, 1).view(B*N , 3, kernel_size, kernel_size) 637 | Hr = (h - kernel_size) // stride_size + 1 638 | Wr = (w - kernel_size) // stride_size + 1 639 | 640 | B, C, H, W = HR_fold.size() 641 | mask = torch.ones((B, 1, H, W)).float() 642 | psnr_score, ssim_score = calc_psnr_and_ssim_cuda(HR_fold, LR_fv_fold, mask, is_tensor=False, batch_avg=True) 643 | psnr_score = psnr_score.view(Hr, Wr) 644 | ssim_score = ssim_score.view(Hr, Wr) 645 | 646 | psnr_y_idx = (torch.argmax(psnr_score) // 
Wr) * stride_size 647 | psnr_x_idx = (torch.argmax(psnr_score) % Wr) * stride_size 648 | ssim_y_idx = (torch.argmax(ssim_score) // Wr) * stride_size 649 | ssim_x_idx = (torch.argmax(ssim_score) % Wr) * stride_size 650 | 651 | LR_fv[:, m:m+crop_h, n] = torch.tensor([255., 0., 0.]).unsqueeze(1).repeat((1,crop_h)) 652 | LR_fv[:, m:m+crop_h, n+crop_w-1] = torch.tensor([255., 0., 0.]).unsqueeze(1).repeat((1,crop_h)) 653 | LR_fv[:, m, n:n+crop_w] = torch.tensor([255., 0., 0.]).unsqueeze(1).repeat((1,crop_w)) 654 | LR_fv[:, m+crop_h-1, n:n+crop_w] = torch.tensor([255., 0., 0.]).unsqueeze(1).repeat((1,crop_w)) 655 | 656 | psnr_score_discrete = torch.zeros_like(psnr_score) 657 | ssim_score_discrete = torch.zeros_like(ssim_score) 658 | 659 | psnr_score = (psnr_score - psnr_score.min()) / (psnr_score.max() - psnr_score.min()) 660 | ssim_score = (ssim_score - ssim_score.min()) / (ssim_score.max() - ssim_score.min()) 661 | 662 | psnr_score_discrete[psnr_score <= 1.0] = 1.0 663 | psnr_score_discrete[psnr_score <= 0.9] = 0.9 664 | psnr_score_discrete[psnr_score <= 0.8] = 0.8 665 | psnr_score_discrete[psnr_score <= 0.7] = 0.7 666 | psnr_score_discrete[psnr_score <= 0.6] = 0.6 667 | psnr_score_discrete[psnr_score <= 0.5] = 0.5 668 | psnr_score_discrete[psnr_score <= 0.4] = 0.4 669 | psnr_score_discrete[psnr_score <= 0.3] = 0.3 670 | psnr_score_discrete[psnr_score <= 0.2] = 0.2 671 | psnr_score_discrete[psnr_score <= 0.1] = 0.1 672 | 673 | ssim_score_discrete[ssim_score <= 1.0] = 1.0 674 | ssim_score_discrete[ssim_score <= 0.9] = 0.9 675 | ssim_score_discrete[ssim_score <= 0.8] = 0.8 676 | ssim_score_discrete[ssim_score <= 0.7] = 0.7 677 | ssim_score_discrete[ssim_score <= 0.6] = 0.6 678 | ssim_score_discrete[ssim_score <= 0.5] = 0.5 679 | ssim_score_discrete[ssim_score <= 0.4] = 0.4 680 | ssim_score_discrete[ssim_score <= 0.3] = 0.3 681 | ssim_score_discrete[ssim_score <= 0.2] = 0.2 682 | ssim_score_discrete[ssim_score <= 0.1] = 0.1 683 | 684 | # self.viz.viz.image(HR.cpu().numpy(), win='{}'.format('HR'), opts=dict(title='{}, Image size : {}'.format('HR', HR.size()))) 685 | # self.viz.viz.image(LR.cpu().numpy(), win='{}'.format('LR'), opts=dict(title='{}, Image size : {}'.format('LR', LR.size()))) 686 | self.viz.viz.image(LR_fv.cpu().numpy(), win='{}'.format('FV'), opts=dict(title='{}, Image size : {}'.format('FV', LR_fv.size()))) 687 | # self.viz.viz.image(psnr_score.cpu().numpy(), win='{}'.format('PSNR_score'), opts=dict(title='{}, Image size : {}'.format('PSNR_score', psnr_score.size()))) 688 | # self.viz.viz.image(ssim_score.cpu().numpy(), win='{}'.format('SSIM_score'), opts=dict(title='{}, Image size : {}'.format('SSIM_score', ssim_score.size()))) 689 | # self.viz.viz.image(psnr_score_discrete.cpu().numpy(), win='{}'.format('PSNR_score_discrete'), opts=dict(title='{}, Image size : {}'.format('PSNR_score_discrete', psnr_score_discrete.size()))) 690 | # self.viz.viz.image(ssim_score_discrete.cpu().numpy(), win='{}'.format('SSIM_score_discrete'), opts=dict(title='{}, Image size : {}'.format('SSIM_score_discrete', ssim_score_discrete.size()))) 691 | 692 | return psnr_score, ssim_score 693 | 694 | def result_comp(self, LR_fv, SP_W): 695 | C, H, W = LR_fv.size() 696 | LR_fv[:, :, SP_W] = torch.tensor([255., 255., 255.]).unsqueeze(1).repeat((1,H)) 697 | self.viz.viz.image(LR_fv.cpu().numpy(), win='{}'.format('LR_fv'), opts=dict(title='{}, Image size : {}'.format('LR_fv', LR_fv.size()))) 698 | -------------------------------------------------------------------------------- /untar_models.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | tar xvf train_model.tar -C train/REDS/ 4 | rm train_model.tar 5 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import logging 4 | import cv2 5 | import os 6 | import shutil 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | import pytorch_ssim as pytorch_ssim 13 | 14 | class Logger(object): 15 | def __init__(self, log_file_name, logger_name, log_level=logging.DEBUG): 16 | ### create a logger 17 | self.__logger = logging.getLogger(logger_name) 18 | 19 | ### set the log level 20 | self.__logger.setLevel(log_level) 21 | 22 | ### create a handler to write log file 23 | file_handler = logging.FileHandler(log_file_name) 24 | 25 | ### create a handler to print on console 26 | console_handler = logging.StreamHandler() 27 | 28 | ### define the output format of handlers 29 | formatter = logging.Formatter('[%(asctime)s] - [%(filename)s file line:%(lineno)d] - %(levelname)s: %(message)s') 30 | file_handler.setFormatter(formatter) 31 | console_handler.setFormatter(formatter) 32 | 33 | ### add handler to logger 34 | self.__logger.addHandler(file_handler) 35 | self.__logger.addHandler(console_handler) 36 | 37 | def get_log(self): 38 | return self.__logger 39 | 40 | 41 | def mkExpDir(args): 42 | if (os.path.exists(args.save_dir)): 43 | if (not args.reset): 44 | raise SystemExit('Error: save_dir "' + args.save_dir + '" already exists! Please set --reset True to delete the folder.') 45 | else: 46 | shutil.rmtree(args.save_dir) 47 | 48 | os.makedirs(args.save_dir) 49 | # os.makedirs(os.path.join(args.save_dir, 'img')) 50 | 51 | if ((not args.eval) and (not args.test)): 52 | os.makedirs(os.path.join(args.save_dir, 'model')) 53 | 54 | if ((args.eval and args.eval_save_results) or args.test): 55 | os.makedirs(os.path.join(args.save_dir, 'save_results')) 56 | 57 | args_file = open(os.path.join(args.save_dir, 'args.txt'), 'w') 58 | for k, v in vars(args).items(): 59 | args_file.write(k.rjust(30,' ') + '\t' + str(v) + '\n') 60 | 61 | _logger = Logger(log_file_name=os.path.join(args.save_dir, args.log_file_name), 62 | logger_name=args.logger_name).get_log() 63 | 64 | return _logger 65 | 66 | 67 | class MeanShift(nn.Conv2d): 68 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 69 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 70 | std = torch.Tensor(rgb_std) 71 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 72 | self.weight.data.div_(std.view(3, 1, 1, 1)) 73 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 74 | self.bias.data.div_(std) 75 | # self.requires_grad = False 76 | self.weight.requires_grad = False 77 | self.bias.requires_grad = False 78 | 79 | 80 | def calc_psnr(img1, img2): 81 | ### args: 82 | # img1: [h, w, c], range [0, 255] 83 | # img2: [h, w, c], range [0, 255] 84 | diff = (img1 - img2) / 255.0 85 | diff[:,:,0] = diff[:,:,0] * 65.738 / 256.0 86 | diff[:,:,1] = diff[:,:,1] * 129.057 / 256.0 87 | diff[:,:,2] = diff[:,:,2] * 25.064 / 256.0 88 | 89 | diff = np.sum(diff, axis=2) 90 | mse = np.mean(np.power(diff, 2)) 91 | return -10 * math.log10(mse) 92 | 93 | 94 | def calc_ssim(img1, img2): 95 | def ssim(img1, img2): 96 | C1 = (0.01 * 255)**2 97 | C2 = (0.03 * 255)**2 98 | 99 | img1 = img1.astype(np.float64) 100 | img2 = 
img2.astype(np.float64) 101 | kernel = cv2.getGaussianKernel(11, 1.5) 102 | window = np.outer(kernel, kernel.transpose()) 103 | 104 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 105 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 106 | mu1_sq = mu1**2 107 | mu2_sq = mu2**2 108 | mu1_mu2 = mu1 * mu2 109 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 110 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 111 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 112 | 113 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 114 | (sigma1_sq + sigma2_sq + C2)) 115 | return ssim_map.mean() 116 | 117 | ### args: 118 | # img1: [h, w, c], range [0, 255] 119 | # img2: [h, w, c], range [0, 255] 120 | # the same outputs as MATLAB's 121 | border = 0 122 | img1_y = np.dot(img1, [65.738,129.057,25.064])/256.0+16.0 123 | img2_y = np.dot(img2, [65.738,129.057,25.064])/256.0+16.0 124 | if not img1.shape == img2.shape: 125 | raise ValueError('Input images must have the same dimensions.') 126 | h, w = img1.shape[:2] 127 | img1_y = img1_y[border:h-border, border:w-border] 128 | img2_y = img2_y[border:h-border, border:w-border] 129 | 130 | if img1_y.ndim == 2: 131 | return ssim(img1_y, img2_y) 132 | elif img1.ndim == 3: 133 | if img1.shape[2] == 3: 134 | ssims = [] 135 | for i in range(3): 136 | ssims.append(ssim(img1[:, :, i], img2[:, :, i])) # per-channel SSIM 137 | return np.array(ssims).mean() 138 | elif img1.shape[2] == 1: 139 | return ssim(np.squeeze(img1), np.squeeze(img2)) 140 | else: 141 | raise ValueError('Wrong input image dimensions.') 142 | 143 | 144 | def calc_psnr_and_ssim(sr, hr): 145 | ### args: 146 | # sr: pytorch tensor, range [-1, 1] 147 | # hr: pytorch tensor, range [-1, 1] 148 | 149 | ### prepare data 150 | sr = (sr+1.) * 127.5 151 | hr = (hr+1.) * 127.5 152 | if (sr.size() != hr.size()): 153 | h_min = min(sr.size(2), hr.size(2)) 154 | w_min = min(sr.size(3), hr.size(3)) 155 | sr = sr[:, :, :h_min, :w_min] 156 | hr = hr[:, :, :h_min, :w_min] 157 | 158 | img1 = np.transpose(sr.squeeze().round().cpu().numpy(), (1,2,0)) 159 | img2 = np.transpose(hr.squeeze().round().cpu().numpy(), (1,2,0)) 160 | 161 | psnr = calc_psnr(img1, img2) 162 | ssim = calc_ssim(img1, img2) 163 | 164 | return psnr, ssim 165 | 166 | def psnr_cuda(img1, img2, mask, batch_avg=False): 167 | #### Image range [0, 1] #### 168 | if batch_avg: 169 | B, C, H, W = img1.size() 170 | mse = ((img1 - img2) ** 2).view(B, -1).mean(1) 171 | psnr = torch.where(mse == 0, -20 * torch.log10(torch.sqrt((1/255.)**2 / torch.tensor(C*H*W).to(mse.device))),\ 172 | -20 * torch.log10(torch.sqrt(mse))) 173 | else: 174 | # mse = torch.mean((img1 - img2) ** 2) 175 | B, C, H, W = img1.size() 176 | # C = 1 177 | mse = (((img1 - img2) ** 2)*mask).sum() / (mask.float().sum() * C) 178 | if mse == 0: 179 | psnr = -20 * torch.log10(torch.sqrt((1/255.)**2 / (torch.prod(torch.tensor(img1.size()))))) 180 | else: 181 | psnr = -20 * torch.log10(torch.sqrt(mse.cpu())) 182 | 183 | #### Image range [0, 255] #### 184 | # psnr = 20. * torch.log10(255. 
/ torch.sqrt(mse)) 185 | return psnr 186 | 187 | def gaussian(window_size, sigma): 188 | gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 189 | return gauss/gauss.sum() 190 | 191 | def create_window(window_size, channel): 192 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 193 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 194 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 195 | return window 196 | 197 | def _ssim(img1, img2, mask, window, window_size, channel, size_average = True, batch_avg = False): 198 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 199 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 200 | 201 | mu1_sq = mu1.pow(2) 202 | mu2_sq = mu2.pow(2) 203 | mu1_mu2 = mu1*mu2 204 | 205 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 206 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 207 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 208 | 209 | C1 = 0.01**2 210 | C2 = 0.03**2 211 | #### Image range [0, 255] #### 212 | # C1 = (0.01 * 255)**2 213 | # C2 = (0.03 * 255)**2 214 | 215 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 216 | # print(ssim_map.size()) 217 | B, C, H, W = ssim_map.size() 218 | 219 | if batch_avg and ssim_map.dim() == 4: 220 | B, C, H, W = ssim_map.size() 221 | return ssim_map.view(B, -1).mean(1) 222 | elif size_average: 223 | # return ssim_map.mean() 224 | return (ssim_map*mask).sum() / (mask.float().sum() * C) 225 | else: 226 | return ssim_map.mean(1).mean(1).mean(1) 227 | 228 | def ssim(img1, img2, mask, window_size = 11, size_average = True, batch_avg = False): 229 | (_, channel, _, _) = img1.size() 230 | window = create_window(window_size, channel) 231 | 232 | if img1.is_cuda: 233 | window = window.cuda(img1.get_device()) 234 | window = window.type_as(img1) 235 | 236 | return _ssim(img1, img2, mask, window, window_size, channel, size_average, batch_avg) 237 | 238 | def ssim_cuda(img1, img2, mask, batch_avg=False): 239 | #### Image range [0, 1] #### 240 | return ssim(img1, img2, mask, batch_avg=batch_avg) 241 | 242 | def calc_psnr_and_ssim_cuda(sr, hr, mask, is_tensor=True, batch_avg=False): 243 | #### Convert Image range to [0, 1] #### 244 | min_val, max_val = hr.min(), hr.max() 245 | if((max_val - min_val) > 2): 246 | sr = sr / 255. 247 | hr = hr / 255. 248 | elif((max_val - min_val) > 1): 249 | sr = (sr+1.) / 2. 250 | hr = (hr+1.) / 2. 251 | #### Convert Image range to [0, 255] #### 252 | # sr = sr * 255. 253 | # hr = hr * 255. 254 | return psnr_cuda(sr, hr, mask, batch_avg=batch_avg), ssim_cuda(sr, hr, mask, batch_avg=batch_avg) 255 | 256 | def _convert_input_type_range(img): 257 | """Convert the type and range of the input image. 258 | It converts the input image to np.float32 type and range of [0, 1]. 259 | It is mainly used for pre-processing the input image in colorspace 260 | conversion functions such as rgb2ycbcr and ycbcr2rgb. 261 | Args: 262 | img (ndarray): The input image. It accepts: 263 | 1. np.uint8 type with range [0, 255]; 264 | 2. np.float32 type with range [0, 1]. 265 | Returns: 266 | (ndarray): The converted image with type of np.float32 and range of 267 | [0, 1]. 

def _convert_input_type_range(img):
    """Convert the type and range of the input image.

    It converts the input image to np.float32 type with range [0, 1].
    It is mainly used for pre-processing the input image in colorspace
    conversion functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The input image. It accepts:
            1. np.uint8 type with range [0, 255];
            2. np.float32 type with range [0, 1].

    Returns:
        (ndarray): The converted image with type of np.float32 and range of
            [0, 1].
    """
    img_type = img.dtype
    img = img.astype(np.float32)
    if img_type == np.float32:
        pass
    elif img_type == np.uint8:
        img /= 255.
    else:
        raise TypeError('The img type should be np.float32 or np.uint8, '
                        f'but got {img_type}')
    return img


def _convert_output_type_range(img, dst_type):
    """Convert the type and range of the image according to dst_type.

    It converts the image to the desired type and range. If `dst_type` is
    np.uint8, images will be converted to np.uint8 type with range [0, 255].
    If `dst_type` is np.float32, it converts the image to np.float32 type
    with range [0, 1].
    It is mainly used for post-processing images in colorspace conversion
    functions such as rgb2ycbcr and ycbcr2rgb.

    Args:
        img (ndarray): The image to be converted with np.float32 type and
            range [0, 255].
        dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
            converts the image to np.uint8 type with range [0, 255]. If
            dst_type is np.float32, it converts the image to np.float32 type
            with range [0, 1].

    Returns:
        (ndarray): The converted image with desired type and range.
    """
    if dst_type not in (np.uint8, np.float32):
        raise TypeError('The dst_type should be np.float32 or np.uint8, '
                        f'but got {dst_type}')
    if dst_type == np.uint8:
        img = img.round()
    else:
        img /= 255.
    return img.astype(dst_type)


def bgr2ycbcr(img, y_only=False):
    """Convert a BGR image to YCbCr image.

    The BGR version of rgb2ycbcr. It implements the ITU-R BT.601 conversion
    for standard-definition television. See more details in
    https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
    It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
    In OpenCV, it implements a JPEG conversion. See more details in
    https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.

    Args:
        img (Tensor): The input image of shape (B, H, W, C) with float
            values in range [0, 1].
        y_only (bool): Whether to only return the Y channel. Default: False.

    Returns:
        Tensor: The converted YCbCr image of shape (B, C', H, W), with Y in
            range [16, 235] and Cb/Cr in range [16, 240].
    """
    # img_type = img.dtype
    # img = _convert_input_type_range(img)
    if y_only:
        out_img = torch.matmul(img, torch.tensor([24.966, 128.553, 65.481]).to(img.device)) + 16.0
        out_img = out_img.unsqueeze(3).permute(0, 3, 1, 2)
    else:
        out_img = torch.matmul(
            img, torch.tensor([[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                               [65.481, -37.797, 112.0]]).to(img.device)) + torch.tensor([16, 128, 128]).to(img.device)
        out_img = out_img.permute(0, 3, 1, 2)
    # out_img = _convert_output_type_range(out_img, img_type)
    return out_img
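
# --- usage sketch (illustrative only) ---
# bgr2ycbcr operates on NHWC torch tensors in [0, 1] (the numpy range
# conversion above is commented out) and returns NCHW; with y_only=True the
# output is the single luma channel in [16, 235]. Shapes are assumptions:
#
#   bgr = torch.rand(4, 128, 128, 3)     # (B, H, W, C), BGR channel order
#   y = bgr2ycbcr(bgr, y_only=True)      # -> (4, 1, 128, 128)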

def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein
    and the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel of shape [55, 3]
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255*np.arange(0, RY)/RY)
    col = col + RY
    # YG
    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0, YG)/YG)
    colorwheel[col:col+YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col:col+GC, 1] = 255
    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0, GC)/GC)
    col = col + GC
    # CB
    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
    colorwheel[col:col+CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col:col+BM, 2] = 255
    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0, BM)/BM)
    col = col + BM
    # MR
    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
    colorwheel[col:col+MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    and the Matlab source code of Deqing Sun.

    Args:
        u (np.ndarray): Input horizontal flow of shape [H, W]
        v (np.ndarray): Input vertical flow of shape [H, W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H, W, 3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55, 3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u) / np.pi
    fk = (a+1) / 2*(ncols-1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:, i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1-f)*col0 + f*col1
        idx = (rad <= 1)
        col[idx] = 1 - rad[idx] * (1-col[idx])
        col[~idx] = col[~idx] * 0.75  # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2-i if convert_to_bgr else i
        flow_image[:, :, ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two-dimensional flow field of shape [H, W, 2].

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H, W, 2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H, W, 3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)
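
# --- usage sketch (illustrative only) ---
# flow_to_color normalizes the flow by its maximum magnitude and maps angle
# to hue via the Middlebury color wheel. The array below is a made-up input:
#
#   flow = np.random.randn(240, 320, 2).astype(np.float32)   # (H, W, 2)
#   vis = flow_to_color(flow)                                 # (H, W, 3) uint8, RGB
--------------------------------------------------------------------------------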