├── .idea └── workspace.xml ├── README.md ├── codes ├── calculate_metrics.py ├── data │ ├── LQGT_dataset.py │ ├── LQGT_enhance_dataset.py │ ├── LQ_dataset.py │ ├── REDS_dataset.py │ ├── Vimeo90K_dataset.py │ ├── __init__.py │ ├── data_sampler.py │ ├── util.py │ └── video_test_dataset.py ├── data_scripts │ ├── create_lmdb.py │ ├── extract_subimages.py │ ├── generate_LR_Vimeo90K.m │ ├── generate_mod_LR_bic.m │ ├── generate_mod_LR_bic.py │ ├── prepare_DIV2K_x4_dataset.sh │ ├── regroup_REDS.py │ ├── rename.py │ └── test_dataloader.py ├── metrics │ ├── calculate_PSNR_SSIM.m │ └── calculate_PSNR_SSIM.py ├── models │ ├── SRGAN_model.py │ ├── SR_model.py │ ├── Video_base_model.py │ ├── __init__.py │ ├── archs │ │ ├── CSRNet_arch.py │ │ ├── DUF_arch.py │ │ ├── EDVR_arch.py │ │ ├── RRDBNet_arch.py │ │ ├── SRResNet_arch.py │ │ ├── TOF_arch.py │ │ ├── __init__.py │ │ ├── arch_util.py │ │ ├── dcn │ │ │ ├── __init__.py │ │ │ ├── deform_conv.py │ │ │ ├── setup.py │ │ │ └── src │ │ │ │ ├── deform_conv_cuda.cpp │ │ │ │ └── deform_conv_cuda_kernel.cu │ │ └── discriminator_vgg_arch.py │ ├── base_model.py │ ├── loss.py │ ├── lr_scheduler.py │ └── networks.py ├── options │ ├── __init__.py │ ├── options.py │ ├── test │ │ ├── test_ESRGAN.yml │ │ ├── test_Enhance.yml │ │ ├── test_SRGAN.yml │ │ └── test_SRResNet.yml │ └── train │ │ ├── train_EDVR_M.yml │ │ ├── train_EDVR_woTSA_M.yml │ │ ├── train_ESRGAN.yml │ │ ├── train_Enhance.yml │ │ ├── train_SRGAN.yml │ │ └── train_SRResNet.yml ├── run_scripts.sh ├── scripts │ └── transfer_params_MSRResNet.py ├── test.py ├── test_CSRNet.py ├── train.py └── utils │ ├── __init__.py │ └── util.py ├── experiments └── pretrain_models │ └── csrnet.pth └── figures ├── csrnet_fig1.png └── csrnet_fig6.png /README.md: -------------------------------------------------------------------------------- 1 | # Conditional Sequential Modulation for Efficient Global Image Retouching [Paper Link](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123580664.pdf) 2 | By Jingwen He*, Yihao Liu*, [Yu Qiao](http://mmlab.siat.ac.cn/yuqiao/), and [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ&hl=en) (* indicates equal contribution) 3 | 4 | 5 |
<img src="figures/csrnet_fig1.png">
6 | 7 |

8 | Left: Compared with existing state-of-the-art methods, our method achieves 9 | superior performance with extremely few parameters (1/13 of HDRNet and 1/250 10 | of White-Box). The diameter of each circle represents the number of trainable 11 | parameters. Right: Image retouching examples. 12 | 13 | 14 | 15 | 16 |
<img src="figures/csrnet_fig6.png">
17 | 18 | 19 | 20 |

21 | The first row shows smooth transition effects between different styles (expert A 22 | to B) by image interpolation. In the second row, we use image interpolation to control 23 | the retouching strength from the input image to the automatically retouched result. We denote 24 | the interpolation coefficient α for each image (a minimal sketch of this blending is given at the end of this README). 25 | 26 | ### BibTeX 27 | @article{he2020conditional, 28 | title={Conditional Sequential Modulation for Efficient Global Image Retouching}, 29 | author={He, Jingwen and Liu, Yihao and Qiao, Yu and Dong, Chao}, 30 | journal={arXiv preprint arXiv:2009.10390}, 31 | year={2020} 32 | } 33 | 34 | 35 | ## Dependencies and Installation 36 | 37 | - Python 3 (we recommend [Anaconda](https://www.anaconda.com/download/#linux)) 38 | - [PyTorch >= 1.0](https://pytorch.org/) 39 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 40 | - Python packages: `pip install numpy opencv-python lmdb pyyaml` 41 | - TensorBoard: 42 | - PyTorch >= 1.1: `pip install tb-nightly future` 43 | - PyTorch == 1.0: `pip install tensorboardX` 44 | 45 | 46 | ## Datasets 47 | 48 | Here we provide the preprocessed [MIT-Adobe FiveK dataset](https://drive.google.com/drive/folders/1qrGLFzW7RBlBO1FqgrLPrq9p2_p11ZFs?usp=sharing), which contains both training and testing pairs. 49 | - training pairs: {GT: expert_C_train; Input: raw_input_train} 50 | - testing pairs: {GT: expert_C_test; Input: raw_input_test} 51 | 52 | ## How to Test 53 | 1. Modify the configuration file [`options/test/test_Enhance.yml`](codes/options/test/test_Enhance.yml), e.g., `dataroot_GT`, `dataroot_LQ`, and `pretrain_model_G`. 54 | (We provide a pretrained model in [`experiments/pretrain_models/csrnet.pth`](experiments/pretrain_models/).) 55 | 1. Run the command: 56 | ```bash 57 | python test_CSRNet.py -opt options/test/test_Enhance.yml 58 | ``` 59 | 1. Modify the Python file [`calculate_metrics.py`](codes/calculate_metrics.py): set `input_path` and `GT_path` (lines 139-140). Then run: 60 | ```bash 61 | python calculate_metrics.py 62 | ``` 63 | 64 | ## How to Train 65 | 1. Modify the configuration file [`options/train/train_Enhance.yml`](codes/options/train/train_Enhance.yml), e.g., `dataroot_GT` and `dataroot_LQ`. 66 | 1. Run the command: 67 | ```bash 68 | python train.py -opt options/train/train_Enhance.yml 69 | ``` 70 | 71 | ## Acknowledgement 72 | 73 | - This code is based on [mmsr](https://github.com/open-mmlab/mmsr). 74 | - Thanks to Yihao Liu for contributing part of this work.
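### Example: interpolating the retouching strength

The strength control shown in the second figure is plain pixel-wise blending between the input and the network output. Below is a minimal sketch (the function and argument names are illustrative, not part of this repo; images are assumed to be float arrays in [0, 1]):

```python
import numpy as np

def blend_retouching(input_img: np.ndarray, retouched: np.ndarray, alpha: float) -> np.ndarray:
    """Interpolate between the input and the retouched output.

    alpha = 0.0 returns the input unchanged and alpha = 1.0 the fully
    retouched result; intermediate values scale the retouching strength.
    Blending two retouched outputs (e.g. expert A and expert B styles)
    in the same way yields the style transitions in the first row.
    """
    return (1.0 - alpha) * input_img + alpha * retouched
```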
75 | -------------------------------------------------------------------------------- /codes/calculate_metrics.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import cv2 3 | from PIL import Image 4 | import numpy as np 5 | import math 6 | import os 7 | import tifffile as tiff 8 | from skimage import color 9 | 10 | def ProPhotoRGB2XYZ(pp_rgb, reverse=False): 11 | if not reverse: 12 | M = [[0.7976749, 0.1351917, 0.0313534], \ 13 | [0.2880402, 0.7118741, 0.0000857], \ 14 | [0.0000000, 0.0000000, 0.8252100]] 15 | else: 16 | M = [[ 1.34594337, -0.25560752, -0.05111183],\ 17 | [-0.54459882, 1.5081673, 0.02053511],\ 18 | [ 0, 0, 1.21181275]] 19 | M = np.array(M) 20 | sp = pp_rgb.shape 21 | xyz = np.transpose(np.dot(M, np.transpose(pp_rgb.reshape((sp[0] * sp[1], sp[2]))))) 22 | return xyz.reshape((sp[0], sp[1], 3)) 23 | 24 | def linearize_ProPhotoRGB(pp_rgb, reverse=False): 25 | if not reverse: 26 | gamma = 1.8 27 | else: 28 | gamma = 1.0/1.8 29 | pp_rgb = np.power(pp_rgb, gamma) 30 | return pp_rgb 31 | 32 | def XYZ_chromatic_adapt(xyz, src_white='D65', dest_white='D50'): 33 | if src_white == 'D65' and dest_white == 'D50': 34 | M = [[1.0478112, 0.0228866, -0.0501270], \ 35 | [0.0295424, 0.9904844, -0.0170491], \ 36 | [-0.0092345, 0.0150436, 0.7521316]] 37 | elif src_white == 'D50' and dest_white == 'D65': 38 | M = [[0.9555766, -0.0230393, 0.0631636], \ 39 | [-0.0282895, 1.0099416, 0.0210077], \ 40 | [0.0122982, -0.0204830, 1.3299098]] 41 | else: 42 | raise ValueError('invalid pair of source and destination white reference %s,%s' 43 | % (src_white, dest_white)) 44 | M = np.array(M) 45 | sp = xyz.shape 46 | assert sp[2] == 3 47 | xyz = np.transpose(np.dot(M, np.transpose(xyz.reshape((sp[0] * sp[1], 3))))) 48 | return xyz.reshape((sp[0], sp[1], 3)) 49 | 50 | def read_tiff_16bit_img_into_XYZ(tiff_fn, exposure=0): 51 | pp_rgb = tiff.imread(tiff_fn) 52 | pp_rgb = np.float64(pp_rgb) / (2 ** 16 - 1.0) 53 | if not pp_rgb.shape[2] == 3: 54 | print('pp_rgb shape', pp_rgb.shape) 55 | raise ValueError('image channel number is not 3') 56 | pp_rgb = linearize_ProPhotoRGB(pp_rgb) 57 | pp_rgb *= np.power(2, exposure) 58 | xyz = ProPhotoRGB2XYZ(pp_rgb) 59 | xyz = XYZ_chromatic_adapt(xyz, src_white='D50', dest_white='D65') 60 | return xyz 61 | 62 | def read_tiff_16bit_img_into_LAB(tiff_fn, exposure=0, normalize_Lab=False): 63 | xyz = read_tiff_16bit_img_into_XYZ(tiff_fn, exposure) 64 | lab = color.xyz2lab(xyz) 65 | if normalize_Lab: 66 | raise NotImplementedError('normalize_Lab_image is not provided in this script') 67 | return lab 68 | 69 | 70 | 71 | def calculate_Lab_RMSE(img1, img2): 72 | # img1 and img2 have range [0, 255] 73 | #img1 = img1.astype(np.float64)#/255 74 | #img2 = img2.astype(np.float64)#/255 75 | num_pix = img1.shape[0]*img1.shape[1] 76 | 77 | Lab_RMSE = np.mean(np.sqrt(np.sum((img1 - img2)**2, axis=2))) # correct 1 78 | #Lab_RMSE = np.sum(np.sqrt(np.sum((img1 - img2) ** 2, axis=2))) / num_pix # correct 2, same as correct 1 79 | 80 | #Lab_RMSE = np.sqrt(np.sum(((img1 - img2) ** 2)) / num_pix) # slightly different 81 | 82 | return Lab_RMSE 83 | 84 | def calculate_psnr(img1, img2): 85 | # img1 and img2 have range [0, 255] 86 | img1 = img1.astype(np.float64) 87 | img2 = img2.astype(np.float64) 88 | mse = np.mean((img1 - img2)**2) 89 | if mse == 0: 90 | return float('inf') 91 | return 20 * math.log10(255.0 / math.sqrt(mse)) 92 | 93 | 94 | def ssim_my(img1, img2): 95 | C1 = (0.01 * 255)**2 96 | C2 = (0.03 * 255)**2 97 | 98 | img1 = img1.astype(np.float64) 99 |
img2 = img2.astype(np.float64) 100 | kernel = cv2.getGaussianKernel(11, 1.5) 101 | window = np.outer(kernel, kernel.transpose()) 102 | 103 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 104 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 105 | mu1_sq = mu1**2 106 | mu2_sq = mu2**2 107 | mu1_mu2 = mu1 * mu2 108 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 109 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 110 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 111 | 112 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 113 | (sigma1_sq + sigma2_sq + C2)) 114 | return ssim_map.mean() 115 | 116 | 117 | def calculate_ssim(img1, img2): 118 | '''calculate SSIM 119 | the same outputs as MATLAB's 120 | img1, img2: [0, 255] 121 | ''' 122 | if not img1.shape == img2.shape: 123 | raise ValueError('Input images must have the same dimensions.') 124 | if img1.ndim == 2: 125 | return ssim_my(img1, img2) 126 | elif img1.ndim == 3: 127 | if img1.shape[2] == 3: 128 | ssims = [] 129 | for i in range(3): 130 | ssims.append(ssim_my(img1, img2)) 131 | return np.array(ssims).mean() 132 | elif img1.shape[2] == 1: 133 | return ssim_my(np.squeeze(img1), np.squeeze(img2)) 134 | else: 135 | raise ValueError('Wrong input image dimensions.') 136 | 137 | # ########################################################## 138 | # Please specify the paths for input dir and ground truth dir. 139 | input_path="" 140 | GT_path="" 141 | 142 | input_fname_list = os.listdir(input_path) 143 | input_fname_list.sort() 144 | input_path_list = [os.path.join(input_path, fname) for fname in input_fname_list] 145 | 146 | GT_fname_list = os.listdir(GT_path) 147 | GT_fname_list.sort() 148 | GT_path_list = [os.path.join(GT_path, fname) for fname in GT_fname_list] 149 | 150 | assert len(input_path_list) == len(GT_path_list) 151 | print(len(input_path_list)) 152 | 153 | 154 | psnr_list = [] 155 | ssim_list = [] 156 | Lab_RMSE_list = [] 157 | for i in range(len(input_path_list)): 158 | assert input_fname_list[i].split('.')[0] == GT_fname_list[i].split('.')[0] 159 | img1 = cv2.imread(input_path_list[i], cv2.IMREAD_COLOR) 160 | img2 = cv2.imread(GT_path_list[i], cv2.IMREAD_COLOR) 161 | 162 | 163 | img1_rgb = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB) 164 | img2_rgb = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB) 165 | 166 | img1_lab = cv2.cvtColor(img1, cv2.COLOR_BGR2Lab) 167 | img2_lab = cv2.cvtColor(img2, cv2.COLOR_BGR2Lab) 168 | 169 | 170 | 171 | psnr = calculate_psnr(img1_rgb, img2_rgb) 172 | ssim = calculate_ssim(img1_rgb, img2_rgb) 173 | 174 | Lab_RMSE = calculate_Lab_RMSE(img1_lab, img2_lab) 175 | 176 | print('img: {} PSNR: {} SSIM: {} Lab_RMSE: {}'.format(input_fname_list[i].split('.')[0], psnr, ssim, Lab_RMSE)) 177 | 178 | psnr_list.append(psnr) 179 | ssim_list.append(ssim) 180 | Lab_RMSE_list.append(Lab_RMSE) 181 | 182 | print('Average PSNR: {} SSIM: {} Lab_RMSE: {} Total image: {}'.format(np.mean(psnr_list), np.mean(ssim_list), np.mean(Lab_RMSE_list), len(psnr_list))) -------------------------------------------------------------------------------- /codes/data/LQGT_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | 9 | 10 | class LQGTDataset(data.Dataset): 11 | """ 12 | Read LQ (Low Quality, e.g. 
LR (Low Resolution), blurry, etc) and GT image pairs. 13 | If only GT images are provided, generate LQ images on-the-fly. 14 | """ 15 | 16 | def __init__(self, opt): 17 | super(LQGTDataset, self).__init__() 18 | self.opt = opt 19 | self.data_type = self.opt['data_type'] 20 | self.paths_LQ, self.paths_GT = None, None 21 | self.sizes_LQ, self.sizes_GT = None, None 22 | self.LQ_env, self.GT_env = None, None # environments for lmdb 23 | 24 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) 25 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) 26 | assert self.paths_GT, 'Error: GT path is empty.' 27 | if self.paths_LQ and self.paths_GT: 28 | assert len(self.paths_LQ) == len( 29 | self.paths_GT 30 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format( 31 | len(self.paths_LQ), len(self.paths_GT)) 32 | self.random_scale_list = [1] 33 | 34 | def _init_lmdb(self): 35 | # https://github.com/chainer/chainermn/issues/129 36 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 37 | meminit=False) 38 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 39 | meminit=False) 40 | 41 | def __getitem__(self, index): 42 | if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): 43 | self._init_lmdb() 44 | GT_path, LQ_path = None, None 45 | scale = self.opt['scale'] 46 | GT_size = self.opt['GT_size'] 47 | 48 | # get GT image 49 | GT_path = self.paths_GT[index] 50 | resolution = [int(s) for s in self.sizes_GT[index].split('_') 51 | ] if self.data_type == 'lmdb' else None 52 | img_GT = util.read_img(self.GT_env, GT_path, resolution) 53 | if self.opt['phase'] != 'train': # modcrop in the validation / test phase 54 | img_GT = util.modcrop(img_GT, scale) 55 | #### downsample in base network 56 | img_GT = util.modcrop(img_GT, 2) 57 | if self.opt['color']: # change color space if necessary 58 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 59 | 60 | # get LQ image 61 | if self.paths_LQ: 62 | LQ_path = self.paths_LQ[index] 63 | resolution = [int(s) for s in self.sizes_LQ[index].split('_') 64 | ] if self.data_type == 'lmdb' else None 65 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) 66 | #### downsample in base network 67 | img_LQ = util.modcrop(img_LQ, 2) 68 | else: # down-sampling on-the-fly 69 | # randomly scale during training 70 | if self.opt['phase'] == 'train': 71 | random_scale = random.choice(self.random_scale_list) 72 | H_s, W_s, _ = img_GT.shape 73 | 74 | def _mod(n, random_scale, scale, thres): 75 | rlt = int(n * random_scale) 76 | rlt = (rlt // scale) * scale 77 | return thres if rlt < thres else rlt 78 | 79 | H_s = _mod(H_s, random_scale, scale, GT_size) 80 | W_s = _mod(W_s, random_scale, scale, GT_size) 81 | img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR) 82 | if img_GT.ndim == 2: 83 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) 84 | 85 | H, W, _ = img_GT.shape 86 | # using matlab imresize 87 | img_LQ = util.imresize_np(img_GT, 1 / scale, True) 88 | if img_LQ.ndim == 2: 89 | img_LQ = np.expand_dims(img_LQ, axis=2) 90 | 91 | if self.opt['phase'] == 'train': 92 | # if the image size is too small 93 | H, W, _ = img_GT.shape 94 | if H < GT_size or W < GT_size: 95 | img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR) 96 | # using matlab imresize 97 | img_LQ = util.imresize_np(img_GT, 1 / scale, True) 98 | if 
img_LQ.ndim == 2: 99 | img_LQ = np.expand_dims(img_LQ, axis=2) 100 | 101 | H, W, C = img_LQ.shape 102 | LQ_size = GT_size // scale 103 | 104 | # randomly crop 105 | rnd_h = random.randint(0, max(0, H - LQ_size)) 106 | rnd_w = random.randint(0, max(0, W - LQ_size)) 107 | img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] 108 | rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) 109 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] 110 | 111 | # augmentation - flip, rotate 112 | img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], 113 | self.opt['use_rot']) 114 | 115 | if self.opt['color']: # change color space if necessary 116 | img_LQ = util.channel_convert(C, self.opt['color'], 117 | [img_LQ])[0] # TODO: C is undefined in the val phase (it is only set inside the train branch above) 118 | 119 | # BGR to RGB, HWC to CHW, numpy to tensor 120 | if img_GT.shape[2] == 3: 121 | img_GT = img_GT[:, :, [2, 1, 0]] 122 | img_LQ = img_LQ[:, :, [2, 1, 0]] 123 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 124 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 125 | 126 | if LQ_path is None: 127 | LQ_path = GT_path 128 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} 129 | 130 | def __len__(self): 131 | return len(self.paths_GT) 132 | -------------------------------------------------------------------------------- /codes/data/LQGT_enhance_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import data.util as util 5 | 6 | class LQGT_enhance_dataset(data.Dataset): 7 | def __init__(self, opt): 8 | super(LQGT_enhance_dataset, self).__init__() 9 | self.opt = opt 10 | self.data_type = self.opt['data_type'] 11 | self.paths_LQ, self.paths_GT = None, None 12 | self.sizes_LQ, self.sizes_GT = None, None 13 | self.LQ_env, self.GT_env = None, None # environments for lmdb 14 | 15 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) 16 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) 17 | assert self.paths_GT, 'Error: GT path is empty.'
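        # NOTE: unlike LQGT_dataset, this enhancement dataset keeps full-resolution pairs; __getitem__ below applies no cropping, resizing or flip/rotation augmentation, only color conversion and the BGR->RGB / HWC->CHW tensor transforms.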
18 | if self.paths_LQ and self.paths_GT: 19 | assert len(self.paths_LQ) == len( 20 | self.paths_GT 21 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format( 22 | len(self.paths_LQ), len(self.paths_GT)) 23 | 24 | def __getitem__(self, index): 25 | GT_path, LQ_path = None, None 26 | 27 | # get GT image 28 | GT_path = self.paths_GT[index] 29 | LQ_path = self.paths_LQ[index] 30 | img_GT = util.read_img(self.GT_env, GT_path) 31 | img_LQ = util.read_img(self.LQ_env, LQ_path) 32 | 33 | if self.opt['color']: 34 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 35 | img_LQ = util.channel_convert(img_LQ.shape[2], self.opt['color'], [img_LQ])[0] 36 | 37 | # BGR to RGB, HWC to CHW, numpy to tensor 38 | if img_GT.shape[2] == 3: 39 | img_GT = img_GT[:, :, [2, 1, 0]] 40 | img_LQ = img_LQ[:, :, [2, 1, 0]] 41 | 42 | H, W, _ = img_LQ.shape 43 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 44 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 45 | 46 | if LQ_path is None: 47 | LQ_path = GT_path 48 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} 49 | 50 | def __len__(self): 51 | return len(self.paths_GT) 52 | -------------------------------------------------------------------------------- /codes/data/LQ_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import lmdb 3 | import torch 4 | import torch.utils.data as data 5 | import data.util as util 6 | 7 | 8 | class LQDataset(data.Dataset): 9 | '''Read LQ images only in the test phase.''' 10 | 11 | def __init__(self, opt): 12 | super(LQDataset, self).__init__() 13 | self.opt = opt 14 | self.data_type = self.opt['data_type'] 15 | self.paths_LQ, self.paths_GT = None, None 16 | self.LQ_env = None # environment for lmdb 17 | 18 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) 19 | assert self.paths_LQ, 'Error: LQ paths are empty.' 
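        # NOTE: for lmdb data the environment is opened lazily on the first __getitem__ call (see _init_lmdb below), so each dataloader worker process holds its own handle; see https://github.com/chainer/chainermn/issues/129.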
20 | 21 | def _init_lmdb(self): 22 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 23 | meminit=False) 24 | 25 | def __getitem__(self, index): 26 | if self.data_type == 'lmdb' and self.LQ_env is None: 27 | self._init_lmdb() 28 | LQ_path = None 29 | 30 | # get LQ image 31 | LQ_path = self.paths_LQ[index] 32 | resolution = [int(s) for s in self.sizes_LQ[index].split('_') 33 | ] if self.data_type == 'lmdb' else None 34 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) 35 | H, W, C = img_LQ.shape 36 | 37 | if self.opt['color']: # change color space if necessary 38 | img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0] 39 | 40 | # BGR to RGB, HWC to CHW, numpy to tensor 41 | if img_LQ.shape[2] == 3: 42 | img_LQ = img_LQ[:, :, [2, 1, 0]] 43 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 44 | 45 | return {'LQ': img_LQ, 'LQ_path': LQ_path} 46 | 47 | def __len__(self): 48 | return len(self.paths_LQ) 49 | -------------------------------------------------------------------------------- /codes/data/REDS_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | REDS dataset 3 | support reading images from lmdb, image folder and memcached 4 | ''' 5 | import os.path as osp 6 | import random 7 | import pickle 8 | import logging 9 | import numpy as np 10 | import cv2 11 | import lmdb 12 | import torch 13 | import torch.utils.data as data 14 | import data.util as util 15 | try: 16 | import mc # import memcached 17 | except ImportError: 18 | pass 19 | 20 | logger = logging.getLogger('base') 21 | 22 | 23 | class REDSDataset(data.Dataset): 24 | ''' 25 | Reading the training REDS dataset 26 | key example: 000_00000000 27 | GT: Ground-Truth; 28 | LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames 29 | support reading N LQ frames, N = 1, 3, 5, 7 30 | ''' 31 | 32 | def __init__(self, opt): 33 | super(REDSDataset, self).__init__() 34 | self.opt = opt 35 | # temporal augmentation 36 | self.interval_list = opt['interval_list'] 37 | self.random_reverse = opt['random_reverse'] 38 | logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format( 39 | ','.join(str(x) for x in opt['interval_list']), self.random_reverse)) 40 | 41 | self.half_N_frames = opt['N_frames'] // 2 42 | self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] 43 | self.data_type = self.opt['data_type'] 44 | self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs 45 | #### directly load image keys 46 | if self.data_type == 'lmdb': 47 | self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT']) 48 | logger.info('Using lmdb meta info for cache keys.') 49 | elif opt['cache_keys']: 50 | logger.info('Using cache keys: {}'.format(opt['cache_keys'])) 51 | self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys'] 52 | else: 53 | raise ValueError( 54 | 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]') 55 | 56 | # remove the REDS4 for testing 57 | self.paths_GT = [ 58 | v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020'] 59 | ] 60 | assert self.paths_GT, 'Error: GT path is empty.' 
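        # NOTE: clips 000, 011, 015 and 020 form the REDS4 evaluation split, which is why they are filtered out of the training keys above.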
61 | 62 | if self.data_type == 'lmdb': 63 | self.GT_env, self.LQ_env = None, None 64 | elif self.data_type == 'mc': # memcached 65 | self.mclient = None 66 | elif self.data_type == 'img': 67 | pass 68 | else: 69 | raise ValueError('Wrong data type: {}'.format(self.data_type)) 70 | 71 | def _init_lmdb(self): 72 | # https://github.com/chainer/chainermn/issues/129 73 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 74 | meminit=False) 75 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 76 | meminit=False) 77 | 78 | def _ensure_memcached(self): 79 | if self.mclient is None: 80 | # specify the config files 81 | server_list_config_file = None 82 | client_config_file = None 83 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, 84 | client_config_file) 85 | 86 | def _read_img_mc(self, path): 87 | ''' Return BGR, HWC, [0, 255], uint8''' 88 | value = mc.pyvector() 89 | self.mclient.Get(path, value) 90 | value_buf = mc.ConvertBuffer(value) 91 | img_array = np.frombuffer(value_buf, np.uint8) 92 | img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) 93 | return img 94 | 95 | def _read_img_mc_BGR(self, path, name_a, name_b): 96 | ''' Read BGR channels separately and then combine for 1M limits in cluster''' 97 | img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png')) 98 | img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png')) 99 | img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png')) 100 | img = cv2.merge((img_B, img_G, img_R)) 101 | return img 102 | 103 | def __getitem__(self, index): 104 | if self.data_type == 'mc': 105 | self._ensure_memcached() 106 | elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): 107 | self._init_lmdb() 108 | 109 | scale = self.opt['scale'] 110 | GT_size = self.opt['GT_size'] 111 | key = self.paths_GT[index] 112 | name_a, name_b = key.split('_') 113 | center_frame_idx = int(name_b) 114 | 115 | #### determine the neighbor frames 116 | interval = random.choice(self.interval_list) 117 | if self.opt['border_mode']: 118 | direction = 1 # 1: forward; 0: backward 119 | N_frames = self.opt['N_frames'] 120 | if self.random_reverse and random.random() < 0.5: 121 | direction = random.choice([0, 1]) 122 | if center_frame_idx + interval * (N_frames - 1) > 99: 123 | direction = 0 124 | elif center_frame_idx - interval * (N_frames - 1) < 0: 125 | direction = 1 126 | # get the neighbor list 127 | if direction == 1: 128 | neighbor_list = list( 129 | range(center_frame_idx, center_frame_idx + interval * N_frames, interval)) 130 | else: 131 | neighbor_list = list( 132 | range(center_frame_idx, center_frame_idx - interval * N_frames, -interval)) 133 | name_b = '{:08d}'.format(neighbor_list[0]) 134 | else: 135 | # ensure not exceeding the borders 136 | while (center_frame_idx + self.half_N_frames * interval > 137 | 99) or (center_frame_idx - self.half_N_frames * interval < 0): 138 | center_frame_idx = random.randint(0, 99) 139 | # get the neighbor list 140 | neighbor_list = list( 141 | range(center_frame_idx - self.half_N_frames * interval, 142 | center_frame_idx + self.half_N_frames * interval + 1, interval)) 143 | if self.random_reverse and random.random() < 0.5: 144 | neighbor_list.reverse() 145 | name_b = '{:08d}'.format(neighbor_list[self.half_N_frames]) 146 | 147 | assert len( 148 | neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format( 149 | len(neighbor_list)) 150 | 
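        # e.g. in non-border mode with center_frame_idx=40, interval=2 and N_frames=5, neighbor_list = [36, 38, 40, 42, 44] and name_b = '00000040'.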
151 | #### get the GT image (as the center frame) 152 | if self.data_type == 'mc': 153 | img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b) 154 | img_GT = img_GT.astype(np.float32) / 255. 155 | elif self.data_type == 'lmdb': 156 | img_GT = util.read_img(self.GT_env, key, (3, 720, 1280)) 157 | else: 158 | img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png')) 159 | 160 | #### get LQ images 161 | LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280) 162 | img_LQ_l = [] 163 | for v in neighbor_list: 164 | img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v)) 165 | if self.data_type == 'mc': 166 | if self.LR_input: 167 | img_LQ = self._read_img_mc(img_LQ_path) 168 | else: 169 | img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v)) 170 | img_LQ = img_LQ.astype(np.float32) / 255. 171 | elif self.data_type == 'lmdb': 172 | img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple) 173 | else: 174 | img_LQ = util.read_img(None, img_LQ_path) 175 | img_LQ_l.append(img_LQ) 176 | 177 | if self.opt['phase'] == 'train': 178 | C, H, W = LQ_size_tuple # LQ size 179 | # randomly crop 180 | if self.LR_input: 181 | LQ_size = GT_size // scale 182 | rnd_h = random.randint(0, max(0, H - LQ_size)) 183 | rnd_w = random.randint(0, max(0, W - LQ_size)) 184 | img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l] 185 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) 186 | img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] 187 | else: 188 | rnd_h = random.randint(0, max(0, H - GT_size)) 189 | rnd_w = random.randint(0, max(0, W - GT_size)) 190 | img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l] 191 | img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] 192 | 193 | # augmentation - flip, rotate 194 | img_LQ_l.append(img_GT) 195 | rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot']) 196 | img_LQ_l = rlt[0:-1] 197 | img_GT = rlt[-1] 198 | 199 | # stack LQ images to NHWC, N is the frame number 200 | img_LQs = np.stack(img_LQ_l, axis=0) 201 | # BGR to RGB, HWC to CHW, numpy to tensor 202 | img_GT = img_GT[:, :, [2, 1, 0]] 203 | img_LQs = img_LQs[:, :, :, [2, 1, 0]] 204 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 205 | img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs, 206 | (0, 3, 1, 2)))).float() 207 | return {'LQs': img_LQs, 'GT': img_GT, 'key': key} 208 | 209 | def __len__(self): 210 | return len(self.paths_GT) 211 | -------------------------------------------------------------------------------- /codes/data/Vimeo90K_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Vimeo90K dataset 3 | support reading images from lmdb, image folder and memcached 4 | ''' 5 | import os.path as osp 6 | import random 7 | import pickle 8 | import logging 9 | import numpy as np 10 | import cv2 11 | import lmdb 12 | import torch 13 | import torch.utils.data as data 14 | import data.util as util 15 | try: 16 | import mc # import memcached 17 | except ImportError: 18 | pass 19 | logger = logging.getLogger('base') 20 | 21 | 22 | class Vimeo90KDataset(data.Dataset): 23 | ''' 24 | Reading the training Vimeo90K dataset 25 | key example: 00001_0001 (_1, ..., _7) 26 | GT (Ground-Truth): 4th frame; 27 | LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame 28 | ''' 29 | 30 | 
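    # NOTE: the LQ frame indices are computed below as i + (9 - N_frames) // 2 for i in range(N_frames); e.g. N_frames=5 selects frames [2, 3, 4, 5, 6], centered on the GT frame 4, matching the table in the docstring above.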
def __init__(self, opt): 31 | super(Vimeo90KDataset, self).__init__() 32 | self.opt = opt 33 | # temporal augmentation 34 | self.interval_list = opt['interval_list'] 35 | self.random_reverse = opt['random_reverse'] 36 | logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format( 37 | ','.join(str(x) for x in opt['interval_list']), self.random_reverse)) 38 | 39 | self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] 40 | self.data_type = self.opt['data_type'] 41 | self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs 42 | 43 | #### determine the LQ frame list 44 | ''' 45 | N | frames 46 | 1 | 4 47 | 3 | 3,4,5 48 | 5 | 2,3,4,5,6 49 | 7 | 1,2,3,4,5,6,7 50 | ''' 51 | self.LQ_frames_list = [] 52 | for i in range(opt['N_frames']): 53 | self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2) 54 | 55 | #### directly load image keys 56 | if self.data_type == 'lmdb': 57 | self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT']) 58 | logger.info('Using lmdb meta info for cache keys.') 59 | elif opt['cache_keys']: 60 | logger.info('Using cache keys: {}'.format(opt['cache_keys'])) 61 | self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys'] 62 | else: 63 | raise ValueError( 64 | 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]') 65 | assert self.paths_GT, 'Error: GT path is empty.' 66 | 67 | if self.data_type == 'lmdb': 68 | self.GT_env, self.LQ_env = None, None 69 | elif self.data_type == 'mc': # memcached 70 | self.mclient = None 71 | elif self.data_type == 'img': 72 | pass 73 | else: 74 | raise ValueError('Wrong data type: {}'.format(self.data_type)) 75 | 76 | def _init_lmdb(self): 77 | # https://github.com/chainer/chainermn/issues/129 78 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 79 | meminit=False) 80 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 81 | meminit=False) 82 | 83 | def _ensure_memcached(self): 84 | if self.mclient is None: 85 | # specify the config files 86 | server_list_config_file = None 87 | client_config_file = None 88 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, 89 | client_config_file) 90 | 91 | def _read_img_mc(self, path): 92 | ''' Return BGR, HWC, [0, 255], uint8''' 93 | value = mc.pyvector() 94 | self.mclient.Get(path, value) 95 | value_buf = mc.ConvertBuffer(value) 96 | img_array = np.frombuffer(value_buf, np.uint8) 97 | img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED) 98 | return img 99 | 100 | def __getitem__(self, index): 101 | if self.data_type == 'mc': 102 | self._ensure_memcached() 103 | elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None): 104 | self._init_lmdb() 105 | 106 | scale = self.opt['scale'] 107 | GT_size = self.opt['GT_size'] 108 | key = self.paths_GT[index] 109 | name_a, name_b = key.split('_') 110 | #### get the GT image (as the center frame) 111 | if self.data_type == 'mc': 112 | img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png')) 113 | img_GT = img_GT.astype(np.float32) / 255. 
114 | elif self.data_type == 'lmdb': 115 | img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448)) 116 | else: 117 | img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png')) 118 | 119 | #### get LQ images 120 | LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448) 121 | img_LQ_l = [] 122 | for v in self.LQ_frames_list: 123 | if self.data_type == 'mc': 124 | img_LQ = self._read_img_mc( 125 | osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v))) 126 | img_LQ = img_LQ.astype(np.float32) / 255. 127 | elif self.data_type == 'lmdb': 128 | img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple) 129 | else: 130 | img_LQ = util.read_img(None, 131 | osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v))) 132 | img_LQ_l.append(img_LQ) 133 | 134 | if self.opt['phase'] == 'train': 135 | C, H, W = LQ_size_tuple # LQ size 136 | # randomly crop 137 | if self.LR_input: 138 | LQ_size = GT_size // scale 139 | rnd_h = random.randint(0, max(0, H - LQ_size)) 140 | rnd_w = random.randint(0, max(0, W - LQ_size)) 141 | img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l] 142 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) 143 | img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :] 144 | else: 145 | rnd_h = random.randint(0, max(0, H - GT_size)) 146 | rnd_w = random.randint(0, max(0, W - GT_size)) 147 | img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l] 148 | img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] 149 | 150 | # augmentation - flip, rotate 151 | img_LQ_l.append(img_GT) 152 | rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot']) 153 | img_LQ_l = rlt[0:-1] 154 | img_GT = rlt[-1] 155 | 156 | # stack LQ images to NHWC, N is the frame number 157 | img_LQs = np.stack(img_LQ_l, axis=0) 158 | # BGR to RGB, HWC to CHW, numpy to tensor 159 | img_GT = img_GT[:, :, [2, 1, 0]] 160 | img_LQs = img_LQs[:, :, :, [2, 1, 0]] 161 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 162 | img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs, 163 | (0, 3, 1, 2)))).float() 164 | return {'LQs': img_LQs, 'GT': img_GT, 'key': key} 165 | 166 | def __len__(self): 167 | return len(self.paths_GT) 168 | -------------------------------------------------------------------------------- /codes/data/__init__.py: -------------------------------------------------------------------------------- 1 | """create dataset and dataloader""" 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 8 | phase = dataset_opt['phase'] 9 | if phase == 'train': 10 | if opt['dist']: 11 | world_size = torch.distributed.get_world_size() 12 | num_workers = dataset_opt['n_workers'] 13 | assert dataset_opt['batch_size'] % world_size == 0 14 | batch_size = dataset_opt['batch_size'] // world_size 15 | shuffle = False 16 | else: 17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) 18 | batch_size = dataset_opt['batch_size'] 19 | shuffle = True 20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 21 | num_workers=num_workers, sampler=sampler, drop_last=True, 22 | pin_memory=False) 23 | else: 24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, 25 | pin_memory=False) 26 | 27 | 28 | def create_dataset(dataset_opt): 29 | mode = 
dataset_opt['mode'] 30 | # datasets for image restoration and image enhancement 31 | if mode == 'LQ': 32 | from data.LQ_dataset import LQDataset as D 33 | elif mode == 'LQGT': 34 | from data.LQGT_dataset import LQGTDataset as D 35 | elif mode == 'LQGT_cond': 36 | from data.LQGT_cond_dataset import LQGT_cond_Dataset as D # NOTE: LQGT_cond_dataset.py is not included in this repository 37 | elif mode == 'LQGT_enhance': 38 | from data.LQGT_enhance_dataset import LQGT_enhance_dataset as D 39 | # datasets for video restoration 40 | elif mode == 'REDS': 41 | from data.REDS_dataset import REDSDataset as D 42 | elif mode == 'Vimeo90K': 43 | from data.Vimeo90K_dataset import Vimeo90KDataset as D 44 | elif mode == 'video_test': 45 | from data.video_test_dataset import VideoTestDataset as D 46 | else: 47 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 48 | dataset = D(dataset_opt) 49 | 50 | logger = logging.getLogger('base') 51 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 52 | dataset_opt['name'])) 53 | return dataset 54 | -------------------------------------------------------------------------------- /codes/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Supports enlarging the dataset for *iteration-oriented* training, saving time when restarting the 4 | dataloader after each epoch 5 | """ 6 | import math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | import torch.distributed as dist 10 | 11 | 12 | class DistIterSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 19 | 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas.
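        ratio (optional): Virtual enlargement ratio of the dataset; num_samples per replica is ceil(len(dataset) * ratio / num_replicas), so one sampler "epoch" covers roughly ratio passes over the data (default: 100).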
28 | """ 29 | 30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 31 | if num_replicas is None: 32 | if not dist.is_available(): 33 | raise RuntimeError("Requires distributed package to be available") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | rank = dist.get_rank() 39 | self.dataset = dataset 40 | self.num_replicas = num_replicas 41 | self.rank = rank 42 | self.epoch = 0 43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 44 | self.total_size = self.num_samples * self.num_replicas 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | g.manual_seed(self.epoch) 50 | indices = torch.randperm(self.total_size, generator=g).tolist() 51 | 52 | dsize = len(self.dataset) 53 | indices = [v % dsize for v in indices] 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | -------------------------------------------------------------------------------- /codes/data/video_test_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import torch 3 | import torch.utils.data as data 4 | import data.util as util 5 | 6 | 7 | class VideoTestDataset(data.Dataset): 8 | """ 9 | A video test dataset. Support: 10 | Vid4 11 | REDS4 12 | Vimeo90K-Test 13 | 14 | no need to prepare LMDB files 15 | """ 16 | 17 | def __init__(self, opt): 18 | super(VideoTestDataset, self).__init__() 19 | self.opt = opt 20 | self.cache_data = opt['cache_data'] 21 | self.half_N_frames = opt['N_frames'] // 2 22 | self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ'] 23 | self.data_type = self.opt['data_type'] 24 | self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []} 25 | if self.data_type == 'lmdb': 26 | raise ValueError('No need to use LMDB during validation/test.') 27 | #### Generate data info and cache data 28 | self.imgs_LQ, self.imgs_GT = {}, {} 29 | if opt['name'].lower() in ['vid4', 'reds4']: 30 | subfolders_LQ = util.glob_file_list(self.LQ_root) 31 | subfolders_GT = util.glob_file_list(self.GT_root) 32 | for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT): 33 | subfolder_name = osp.basename(subfolder_GT) 34 | img_paths_LQ = util.glob_file_list(subfolder_LQ) 35 | img_paths_GT = util.glob_file_list(subfolder_GT) 36 | max_idx = len(img_paths_LQ) 37 | assert max_idx == len( 38 | img_paths_GT), 'Different number of images in LQ and GT folders' 39 | self.data_info['path_LQ'].extend(img_paths_LQ) 40 | self.data_info['path_GT'].extend(img_paths_GT) 41 | self.data_info['folder'].extend([subfolder_name] * max_idx) 42 | for i in range(max_idx): 43 | self.data_info['idx'].append('{}/{}'.format(i, max_idx)) 44 | border_l = [0] * max_idx 45 | for i in range(self.half_N_frames): 46 | border_l[i] = 1 47 | border_l[max_idx - i - 1] = 1 48 | self.data_info['border'].extend(border_l) 49 | 50 | if self.cache_data: 51 | self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ) 52 | self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT) 53 | elif opt['name'].lower() in ['vimeo90k-test']: 54 | pass # TODO 55 | else: 56 | raise ValueError( 57 | 
'Not support video test dataset. Support Vid4, REDS4 and Vimeo90k-Test.') 58 | 59 | def __getitem__(self, index): 60 | # path_LQ = self.data_info['path_LQ'][index] 61 | # path_GT = self.data_info['path_GT'][index] 62 | folder = self.data_info['folder'][index] 63 | idx, max_idx = self.data_info['idx'][index].split('/') 64 | idx, max_idx = int(idx), int(max_idx) 65 | border = self.data_info['border'][index] 66 | 67 | if self.cache_data: 68 | select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'], 69 | padding=self.opt['padding']) 70 | imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx)) 71 | img_GT = self.imgs_GT[folder][idx] 72 | else: 73 | pass # TODO 74 | 75 | return { 76 | 'LQs': imgs_LQ, 77 | 'GT': img_GT, 78 | 'folder': folder, 79 | 'idx': self.data_info['idx'][index], 80 | 'border': border 81 | } 82 | 83 | def __len__(self): 84 | return len(self.data_info['path_GT']) 85 | -------------------------------------------------------------------------------- /codes/data_scripts/extract_subimages.py: -------------------------------------------------------------------------------- 1 | """A multi-thread tool to crop large images to sub-images for faster IO.""" 2 | import os 3 | import os.path as osp 4 | import sys 5 | from multiprocessing import Pool 6 | import numpy as np 7 | import cv2 8 | from PIL import Image 9 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 10 | from utils.util import ProgressBar # noqa: E402 11 | import data.util as data_util # noqa: E402 12 | 13 | 14 | def main(): 15 | mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs) 16 | opt = {} 17 | opt['n_thread'] = 20 18 | opt['compression_level'] = 3 # 3 is the default value in cv2 19 | # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer 20 | # compression time. If read raw images during training, use 0 for faster IO speed. 21 | if mode == 'single': 22 | opt['input_folder'] = '../../datasets/DIV2K_train_HR' 23 | opt['save_folder'] = '../../datasets/DIV2K_sub' 24 | opt['crop_sz'] = 480 # the size of each sub-image 25 | opt['step'] = 240 # step of the sliding crop window 26 | opt['thres_sz'] = 48 # size threshold 27 | extract_signle(opt) 28 | elif mode == 'pair': 29 | GT_folder = '../../datasets/DIV2K_train_HR' 30 | LR_folder = '../../datasets/DIV2K_train_LR_bicubic/X4' 31 | save_GT_folder = '../../datasets/DIV2K_sub' 32 | save_LR_folder = '../../datasets/DIV2K800_sub_bicLRx4' 33 | scale_ratio = 4 34 | crop_sz = 480 # the size of each sub-image (GT) 35 | step = 240 # step of the sliding crop window (GT) 36 | thres_sz = 48 # size threshold 37 | ######################################################################## 38 | # check that all the GT and LR images have correct scale ratio 39 | img_GT_list = data_util._get_paths_from_images(GT_folder) 40 | img_LR_list = data_util._get_paths_from_images(LR_folder) 41 | assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.' 
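        # NOTE: the zip pairing below assumes _get_paths_from_images returns both folders in the same sorted order, so GT/LR images correspond by index (DIV2K pairs share basenames once rename.py strips the x2/x3/x4/x8 suffixes).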
42 | for path_GT, path_LR in zip(img_GT_list, img_LR_list): 43 | img_GT = Image.open(path_GT) 44 | img_LR = Image.open(path_LR) 45 | w_GT, h_GT = img_GT.size 46 | w_LR, h_LR = img_LR.size 47 | assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR width [{:d}] for {:s}.'.format( # noqa: E501 48 | w_GT, scale_ratio, w_LR, path_GT) 49 | assert h_GT / h_LR == scale_ratio, 'GT height [{:d}] is not {:d}X as LR height [{:d}] for {:s}.'.format( # noqa: E501 50 | h_GT, scale_ratio, h_LR, path_GT) 51 | # check crop size, step and threshold size 52 | assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format( 53 | scale_ratio) 54 | assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio) 55 | assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format( 56 | scale_ratio) 57 | print('process GT...') 58 | opt['input_folder'] = GT_folder 59 | opt['save_folder'] = save_GT_folder 60 | opt['crop_sz'] = crop_sz 61 | opt['step'] = step 62 | opt['thres_sz'] = thres_sz 63 | extract_signle(opt) 64 | print('process LR...') 65 | opt['input_folder'] = LR_folder 66 | opt['save_folder'] = save_LR_folder 67 | opt['crop_sz'] = crop_sz // scale_ratio 68 | opt['step'] = step // scale_ratio 69 | opt['thres_sz'] = thres_sz // scale_ratio 70 | extract_signle(opt) 71 | assert len(data_util._get_paths_from_images(save_GT_folder)) == len( 72 | data_util._get_paths_from_images( 73 | save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.' 74 | else: 75 | raise ValueError('Wrong mode.') 76 | 77 | 78 | def extract_signle(opt): 79 | input_folder = opt['input_folder'] 80 | save_folder = opt['save_folder'] 81 | if not osp.exists(save_folder): 82 | os.makedirs(save_folder) 83 | print('mkdir [{:s}] ...'.format(save_folder)) 84 | else: 85 | print('Folder [{:s}] already exists. 
Exit...'.format(save_folder)) 86 | sys.exit(1) 87 | img_list = data_util._get_paths_from_images(input_folder) 88 | 89 | def update(arg): 90 | pbar.update(arg) 91 | 92 | pbar = ProgressBar(len(img_list)) 93 | 94 | pool = Pool(opt['n_thread']) 95 | for path in img_list: 96 | pool.apply_async(worker, args=(path, opt), callback=update) 97 | pool.close() 98 | pool.join() 99 | print('All subprocesses done.') 100 | 101 | 102 | def worker(path, opt): 103 | crop_sz = opt['crop_sz'] 104 | step = opt['step'] 105 | thres_sz = opt['thres_sz'] 106 | img_name = osp.basename(path) 107 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 108 | 109 | n_channels = len(img.shape) 110 | if n_channels == 2: 111 | h, w = img.shape 112 | elif n_channels == 3: 113 | h, w, c = img.shape 114 | else: 115 | raise ValueError('Wrong image shape - {}'.format(n_channels)) 116 | 117 | h_space = np.arange(0, h - crop_sz + 1, step) 118 | if h - (h_space[-1] + crop_sz) > thres_sz: 119 | h_space = np.append(h_space, h - crop_sz) 120 | w_space = np.arange(0, w - crop_sz + 1, step) 121 | if w - (w_space[-1] + crop_sz) > thres_sz: 122 | w_space = np.append(w_space, w - crop_sz) 123 | 124 | index = 0 125 | for x in h_space: 126 | for y in w_space: 127 | index += 1 128 | if n_channels == 2: 129 | crop_img = img[x:x + crop_sz, y:y + crop_sz] 130 | else: 131 | crop_img = img[x:x + crop_sz, y:y + crop_sz, :] 132 | crop_img = np.ascontiguousarray(crop_img) 133 | cv2.imwrite( 134 | osp.join(opt['save_folder'], 135 | img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img, 136 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) 137 | return 'Processing {:s} ...'.format(img_name) 138 | 139 | 140 | if __name__ == '__main__': 141 | main() 142 | -------------------------------------------------------------------------------- /codes/data_scripts/generate_LR_Vimeo90K.m: -------------------------------------------------------------------------------- 1 | function generate_LR_Vimeo90K() 2 | %% matlab code to genetate bicubic-downsampled for Vimeo90K dataset 3 | 4 | up_scale = 4; 5 | mod_scale = 4; 6 | idx = 0; 7 | filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png'); 8 | for i = 1 : length(filepaths) 9 | [~,imname,ext] = fileparts(filepaths(i).name); 10 | folder_path = filepaths(i).folder; 11 | save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4'); 12 | if ~exist(save_LR_folder, 'dir') 13 | mkdir(save_LR_folder); 14 | end 15 | if isempty(imname) 16 | disp('Ignore . folder.'); 17 | elseif strcmp(imname, '.') 18 | disp('Ignore .. 
folder.'); 19 | else 20 | idx = idx + 1; 21 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 22 | fprintf(str_rlt); 23 | % read image 24 | img = imread(fullfile(folder_path, [imname, ext])); 25 | img = im2double(img); 26 | % modcrop 27 | img = modcrop(img, mod_scale); 28 | % LR 29 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 30 | if exist('save_LR_folder', 'var') 31 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); 32 | end 33 | end 34 | end 35 | end 36 | 37 | %% modcrop 38 | function img = modcrop(img, modulo) 39 | if size(img,3) == 1 40 | sz = size(img); 41 | sz = sz - mod(sz, modulo); 42 | img = img(1:sz(1), 1:sz(2)); 43 | else 44 | tmpsz = size(img); 45 | sz = tmpsz(1:2); 46 | sz = sz - mod(sz, modulo); 47 | img = img(1:sz(1), 1:sz(2),:); 48 | end 49 | end 50 | -------------------------------------------------------------------------------- /codes/data_scripts/generate_mod_LR_bic.m: -------------------------------------------------------------------------------- 1 | function generate_mod_LR_bic() 2 | %% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images. 3 | 4 | %% set parameters 5 | % comment the unnecessary line 6 | input_folder = '../../datasets/DIV2K/DIV2K800'; 7 | % save_mod_folder = ''; 8 | save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4'; 9 | % save_bic_folder = ''; 10 | 11 | up_scale = 4; 12 | mod_scale = 4; 13 | 14 | if exist('save_mod_folder', 'var') 15 | if exist(save_mod_folder, 'dir') 16 | disp(['It will cover ', save_mod_folder]); 17 | else 18 | mkdir(save_mod_folder); 19 | end 20 | end 21 | if exist('save_LR_folder', 'var') 22 | if exist(save_LR_folder, 'dir') 23 | disp(['It will cover ', save_LR_folder]); 24 | else 25 | mkdir(save_LR_folder); 26 | end 27 | end 28 | if exist('save_bic_folder', 'var') 29 | if exist(save_bic_folder, 'dir') 30 | disp(['It will cover ', save_bic_folder]); 31 | else 32 | mkdir(save_bic_folder); 33 | end 34 | end 35 | 36 | idx = 0; 37 | filepaths = dir(fullfile(input_folder,'*.*')); 38 | for i = 1 : length(filepaths) 39 | [paths,imname,ext] = fileparts(filepaths(i).name); 40 | if isempty(imname) 41 | disp('Ignore . folder.'); 42 | elseif strcmp(imname, '.') 43 | disp('Ignore .. 
folder.'); 44 | else 45 | idx = idx + 1; 46 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 47 | fprintf(str_rlt); 48 | % read image 49 | img = imread(fullfile(input_folder, [imname, ext])); 50 | img = im2double(img); 51 | % modcrop 52 | img = modcrop(img, mod_scale); 53 | if exist('save_mod_folder', 'var') 54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png'])); 55 | end 56 | % LR 57 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 58 | if exist('save_LR_folder', 'var') 59 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); 60 | end 61 | % Bicubic 62 | if exist('save_bic_folder', 'var') 63 | im_B = imresize(im_LR, up_scale, 'bicubic'); 64 | imwrite(im_B, fullfile(save_bic_folder, [imname, '.png'])); 65 | end 66 | end 67 | end 68 | end 69 | 70 | %% modcrop 71 | function img = modcrop(img, modulo) 72 | if size(img,3) == 1 73 | sz = size(img); 74 | sz = sz - mod(sz, modulo); 75 | img = img(1:sz(1), 1:sz(2)); 76 | else 77 | tmpsz = size(img); 78 | sz = tmpsz(1:2); 79 | sz = sz - mod(sz, modulo); 80 | img = img(1:sz(1), 1:sz(2),:); 81 | end 82 | end 83 | -------------------------------------------------------------------------------- /codes/data_scripts/generate_mod_LR_bic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import numpy as np 5 | 6 | try: 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | from data.util import imresize_np 9 | except ImportError: 10 | pass 11 | 12 | 13 | def generate_mod_LR_bic(): 14 | # set parameters 15 | up_scale = 4 16 | mod_scale = 4 17 | # set data dir 18 | sourcedir = '/data/datasets/img' 19 | savedir = '/data/datasets/mod' 20 | 21 | saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale)) 22 | saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale)) 23 | saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale)) 24 | 25 | if not os.path.isdir(sourcedir): 26 | print('Error: No source data found') 27 | exit(0) 28 | if not os.path.isdir(savedir): 29 | os.mkdir(savedir) 30 | 31 | if not os.path.isdir(os.path.join(savedir, 'HR')): 32 | os.mkdir(os.path.join(savedir, 'HR')) 33 | if not os.path.isdir(os.path.join(savedir, 'LR')): 34 | os.mkdir(os.path.join(savedir, 'LR')) 35 | if not os.path.isdir(os.path.join(savedir, 'Bic')): 36 | os.mkdir(os.path.join(savedir, 'Bic')) 37 | 38 | if not os.path.isdir(saveHRpath): 39 | os.mkdir(saveHRpath) 40 | else: 41 | print('It will cover ' + str(saveHRpath)) 42 | 43 | if not os.path.isdir(saveLRpath): 44 | os.mkdir(saveLRpath) 45 | else: 46 | print('It will cover ' + str(saveLRpath)) 47 | 48 | if not os.path.isdir(saveBicpath): 49 | os.mkdir(saveBicpath) 50 | else: 51 | print('It will cover ' + str(saveBicpath)) 52 | 53 | filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')] 54 | num_files = len(filepaths) 55 | 56 | # prepare data with augementation 57 | for i in range(num_files): 58 | filename = filepaths[i] 59 | print('No.{} -- Processing {}'.format(i, filename)) 60 | # read image 61 | image = cv2.imread(os.path.join(sourcedir, filename)) 62 | 63 | width = int(np.floor(image.shape[1] / mod_scale)) 64 | height = int(np.floor(image.shape[0] / mod_scale)) 65 | # modcrop 66 | if len(image.shape) == 3: 67 | image_HR = image[0:mod_scale * height, 0:mod_scale * width, :] 68 | else: 69 | image_HR = image[0:mod_scale * height, 0:mod_scale * width] 70 | # LR 71 | image_LR = imresize_np(image_HR, 1 / up_scale, True) 72 | # bic 73 | image_Bic = imresize_np(image_LR, 
up_scale, True) 74 | 75 | cv2.imwrite(os.path.join(saveHRpath, filename), image_HR) 76 | cv2.imwrite(os.path.join(saveLRpath, filename), image_LR) 77 | cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic) 78 | 79 | 80 | if __name__ == "__main__": 81 | generate_mod_LR_bic() 82 | -------------------------------------------------------------------------------- /codes/data_scripts/prepare_DIV2K_x4_dataset.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | echo "Prepare DIV2K X4 datasets..." 4 | cd ../../datasets 5 | mkdir DIV2K 6 | cd DIV2K 7 | 8 | #### Step 1 9 | echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..." 10 | # GT 11 | FOLDER=DIV2K_train_HR 12 | FILE=DIV2K_train_HR.zip 13 | if [ ! -d "$FOLDER" ]; then 14 | if [ ! -f "$FILE" ]; then 15 | echo "Downloading $FILE..." 16 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE 17 | fi 18 | unzip $FILE 19 | fi 20 | # LR 21 | FOLDER=DIV2K_train_LR_bicubic 22 | FILE=DIV2K_train_LR_bicubic_X4.zip 23 | if [ ! -d "$FOLDER" ]; then 24 | if [ ! -f "$FILE" ]; then 25 | echo "Downloading $FILE..." 26 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE 27 | fi 28 | unzip $FILE 29 | fi 30 | 31 | #### Step 2 32 | echo "Step 2: Rename the LR images..." 33 | cd ../../codes/data_scripts 34 | python rename.py 35 | 36 | #### Step 4 37 | echo "Step 4: Crop to sub-images..." 38 | python extract_subimages.py 39 | 40 | #### Step 5 41 | echo "Step5: Create LMDB files..." 42 | python create_lmdb.py 43 | -------------------------------------------------------------------------------- /codes/data_scripts/regroup_REDS.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | train_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic/X4' 5 | val_path = '/home/xtwang/datasets/REDS/val_sharp_bicubic/X4' 6 | 7 | # mv the val set 8 | val_folders = glob.glob(os.path.join(val_path, '*')) 9 | for folder in val_folders: 10 | new_folder_idx = '{:03d}'.format(int(folder.split('/')[-1]) + 240) 11 | os.system('cp -r {} {}'.format(folder, os.path.join(train_path, new_folder_idx))) 12 | -------------------------------------------------------------------------------- /codes/data_scripts/rename.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | 5 | def main(): 6 | folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4' 7 | DIV2K(folder) 8 | print('Finished.') 9 | 10 | 11 | def DIV2K(path): 12 | img_path_l = glob.glob(os.path.join(path, '*')) 13 | for img_path in img_path_l: 14 | new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '') 15 | os.rename(img_path, new_path) 16 | 17 | 18 | if __name__ == "__main__": 19 | main() -------------------------------------------------------------------------------- /codes/data_scripts/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | import math 4 | import torchvision.utils 5 | 6 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 7 | from data import create_dataloader, create_dataset # noqa: E402 8 | from utils import util # noqa: E402 9 | 10 | 11 | def main(): 12 | dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub 13 | opt = {} 14 | opt['dist'] = False 15 | opt['gpu_ids'] = [0] 16 | if dataset == 'REDS': 17 | opt['name'] = 'test_REDS' 18 | opt['dataroot_GT'] = 
-------------------------------------------------------------------------------- /codes/data_scripts/test_dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os.path as osp 3 | import math 4 | import torchvision.utils 5 | 6 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 7 | from data import create_dataloader, create_dataset  # noqa: E402 8 | from utils import util  # noqa: E402 9 | 10 | 11 | def main(): 12 | dataset = 'DIV2K800_sub'  # REDS | Vimeo90K | DIV2K800_sub 13 | opt = {} 14 | opt['dist'] = False 15 | opt['gpu_ids'] = [0] 16 | if dataset == 'REDS': 17 | opt['name'] = 'test_REDS' 18 | opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb' 19 | opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb' 20 | opt['mode'] = 'REDS' 21 | opt['N_frames'] = 5 22 | opt['phase'] = 'train' 23 | opt['use_shuffle'] = True 24 | opt['n_workers'] = 8 25 | opt['batch_size'] = 16 26 | opt['GT_size'] = 256 27 | opt['LQ_size'] = 64 28 | opt['scale'] = 4 29 | opt['use_flip'] = True 30 | opt['use_rot'] = True 31 | opt['interval_list'] = [1] 32 | opt['random_reverse'] = False 33 | opt['border_mode'] = False 34 | opt['cache_keys'] = None 35 | opt['data_type'] = 'lmdb'  # img | lmdb | mc 36 | elif dataset == 'Vimeo90K': 37 | opt['name'] = 'test_Vimeo90K' 38 | opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb' 39 | opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' 40 | opt['mode'] = 'Vimeo90K' 41 | opt['N_frames'] = 7 42 | opt['phase'] = 'train' 43 | opt['use_shuffle'] = True 44 | opt['n_workers'] = 8 45 | opt['batch_size'] = 16 46 | opt['GT_size'] = 256 47 | opt['LQ_size'] = 64 48 | opt['scale'] = 4 49 | opt['use_flip'] = True 50 | opt['use_rot'] = True 51 | opt['interval_list'] = [1] 52 | opt['random_reverse'] = False 53 | opt['border_mode'] = False 54 | opt['cache_keys'] = None 55 | opt['data_type'] = 'lmdb'  # img | lmdb | mc 56 | elif dataset == 'DIV2K800_sub': 57 | opt['name'] = 'DIV2K800' 58 | opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb' 59 | opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb' 60 | opt['mode'] = 'LQGT' 61 | opt['phase'] = 'train' 62 | opt['use_shuffle'] = True 63 | opt['n_workers'] = 8 64 | opt['batch_size'] = 16 65 | opt['GT_size'] = 128 66 | opt['scale'] = 4 67 | opt['use_flip'] = True 68 | opt['use_rot'] = True 69 | opt['color'] = 'RGB' 70 | opt['data_type'] = 'lmdb'  # img | lmdb 71 | else: 72 | raise ValueError('Please implement by yourself.') 73 | 74 | util.mkdir('tmp') 75 | train_set = create_dataset(opt) 76 | train_loader = create_dataloader(train_set, opt, opt, None) 77 | nrow = int(math.sqrt(opt['batch_size'])) 78 | padding = 2 if opt['phase'] == 'train' else 0 79 | 80 | print('start...') 81 | for i, data in enumerate(train_loader): 82 | if i > 5: 83 | break 84 | print(i) 85 | if dataset == 'REDS' or dataset == 'Vimeo90K': 86 | LQs = data['LQs'] 87 | else: 88 | LQ = data['LQ'] 89 | GT = data['GT'] 90 | 91 | if dataset == 'REDS' or dataset == 'Vimeo90K': 92 | for j in range(LQs.size(1)): 93 | torchvision.utils.save_image(LQs[:, j, :, :, :], 94 | 'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow, 95 | padding=padding, normalize=False) 96 | else: 97 | torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow, 98 | padding=padding, normalize=False) 99 | torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding, 100 | normalize=False) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 |
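The two metric scripts that follow both assume images in [0, 255] and derive PSNR from the mean squared error. The MATLAB version writes it as `10*log10(255^2/mse)` and the Python version as `20*log10(255/sqrt(mse))`; these are the same identity, as the quick NumPy check below shows (illustrative arrays only):

```python
import numpy as np

gt = np.random.default_rng(0).uniform(0, 255, (64, 64))
sr = gt + np.random.default_rng(1).normal(0.0, 2.0, gt.shape)

mse = np.mean((gt - sr) ** 2)
psnr_from_mse = 10 * np.log10(255.0 ** 2 / mse)       # MATLAB script's form
psnr_from_rmse = 20 * np.log10(255.0 / np.sqrt(mse))  # Python script's form
assert np.isclose(psnr_from_mse, psnr_from_rmse)
```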
-------------------------------------------------------------------------------- /codes/metrics/calculate_PSNR_SSIM.m: -------------------------------------------------------------------------------- 1 | function calculate_PSNR_SSIM() 2 | 3 | % GT and SR folder 4 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'; 5 | folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'; 6 | scale = 4; 7 | suffix = ''; % suffix for SR images 8 | test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels 9 | if test_Y 10 | fprintf('Testing Y channel.\n'); 11 | else 12 | fprintf('Testing RGB channels.\n'); 13 | end 14 | filepaths = dir(fullfile(folder_GT, '*.png')); 15 | PSNR_all = zeros(1, length(filepaths)); 16 | SSIM_all = zeros(1, length(filepaths)); 17 | 18 | for idx_im = 1:length(filepaths) 19 | im_name = filepaths(idx_im).name; 20 | im_GT = imread(fullfile(folder_GT, im_name)); 21 | im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png'])); 22 | 23 | if test_Y % evaluate on Y channel in YCbCr color space 24 | if size(im_GT, 3) == 3 25 | im_GT_YCbCr = rgb2ycbcr(im2double(im_GT)); 26 | im_GT_in = im_GT_YCbCr(:,:,1); 27 | im_SR_YCbCr = rgb2ycbcr(im2double(im_SR)); 28 | im_SR_in = im_SR_YCbCr(:,:,1); 29 | else 30 | im_GT_in = im2double(im_GT); 31 | im_SR_in = im2double(im_SR); 32 | end 33 | else % evaluate on RGB channels 34 | im_GT_in = im2double(im_GT); 35 | im_SR_in = im2double(im_SR); 36 | end 37 | 38 | % calculate PSNR and SSIM 39 | PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale); 40 | SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale); 41 | fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im)); 42 | end 43 | 44 | fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all)); 45 | end 46 | 47 | function res = calculate_PSNR(GT, SR, border) 48 | % remove border 49 | GT = GT(border+1:end-border, border+1:end-border, :); 50 | SR = SR(border+1:end-border, border+1:end-border, :); 51 | % calculate PSNR (assume in [0,255]) 52 | error = GT(:) - SR(:); 53 | mse = mean(error.^2); 54 | res = 10 * log10(255^2/mse); 55 | end 56 | 57 | function res = calculate_SSIM(GT, SR, border) 58 | GT = GT(border+1:end-border, border+1:end-border, :); 59 | SR = SR(border+1:end-border, border+1:end-border, :); 60 | % calculate SSIM 61 | mssim = zeros(1, size(SR, 3)); 62 | for i = 1:size(SR,3) 63 | [mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i)); 64 | end 65 | res = mean(mssim); 66 | end 67 |
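The `ssim_index` routine below computes the SSIM map from Gaussian-weighted local means and variances with stabilizing constants C1 = (0.01*L)^2 and C2 = (0.03*L)^2. In the degenerate case of two constant patches the variance terms vanish and the index reduces to the luminance comparison alone, which makes the formula easy to sanity-check by hand (plain-Python sketch with made-up intensities):

```python
# SSIM of two constant patches mu1 and mu2: the sigma terms are zero, so
# ssim = ((2*mu1*mu2 + C1) * C2) / ((mu1^2 + mu2^2 + C1) * C2)
#      = (2*mu1*mu2 + C1) / (mu1^2 + mu2^2 + C1)
C1 = (0.01 * 255) ** 2  # 6.5025
C2 = (0.03 * 255) ** 2  # 58.5225
mu1, mu2 = 100.0, 120.0
ssim_const = (2 * mu1 * mu2 + C1) / (mu1 ** 2 + mu2 ** 2 + C1)
print(round(ssim_const, 4))  # 0.9836
```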
68 | function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L) 69 | 70 | %======================================================================== 71 | %SSIM Index, Version 1.0 72 | %Copyright(c) 2003 Zhou Wang 73 | %All Rights Reserved. 74 | % 75 | %The author is with Howard Hughes Medical Institute, and Laboratory 76 | %for Computational Vision at Center for Neural Science and Courant 77 | %Institute of Mathematical Sciences, New York University. 78 | % 79 | %---------------------------------------------------------------------- 80 | %Permission to use, copy, or modify this software and its documentation 81 | %for educational and research purposes only and without fee is hereby 82 | %granted, provided that this copyright notice and the original authors' 83 | %names appear on all copies and supporting documentation. This program 84 | %shall not be used, rewritten, or adapted as the basis of a commercial 85 | %software or hardware product without first obtaining permission of the 86 | %authors. The authors make no representations about the suitability of 87 | %this software for any purpose. It is provided "as is" without express 88 | %or implied warranty. 89 | %---------------------------------------------------------------------- 90 | % 91 | %This is an implementation of the algorithm for calculating the 92 | %Structural SIMilarity (SSIM) index between two images. Please refer 93 | %to the following paper: 94 | % 95 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 96 | %quality assessment: From error visibility to structural similarity," 97 | %IEEE Transactions on Image Processing, vol. 13, no. 4, Apr. 2004. 98 | % 99 | %Kindly report any suggestions or corrections to zhouwang@ieee.org 100 | % 101 | %---------------------------------------------------------------------- 102 | % 103 | %Input : (1) img1: the first image being compared 104 | % (2) img2: the second image being compared 105 | % (3) K: constants in the SSIM index formula (see the above 106 | % reference). default value: K = [0.01 0.03] 107 | % (4) window: local window for statistics (see the above 108 | % reference). default window is Gaussian given by 109 | % window = fspecial('gaussian', 11, 1.5); 110 | % (5) L: dynamic range of the images. default: L = 255 111 | % 112 | %Output: (1) mssim: the mean SSIM index value between 2 images. 113 | % If one of the images being compared is regarded as 114 | % perfect quality, then mssim can be considered as the 115 | % quality measure of the other image. 116 | % If img1 = img2, then mssim = 1. 117 | % (2) ssim_map: the SSIM index map of the test image. The map 118 | % has a smaller size than the input images. The actual size: 119 | % size(img1) - size(window) + 1. 120 | % 121 | %Default Usage: 122 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 123 | % 124 | % [mssim ssim_map] = ssim_index(img1, img2); 125 | % 126 | %Advanced Usage: 127 | % User defined parameters. For example 128 | % 129 | % K = [0.05 0.05]; 130 | % window = ones(8); 131 | % L = 100; 132 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L); 133 | % 134 | %See the results: 135 | % 136 | % mssim %Gives the mssim value 137 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 138 | % 139 | %======================================================================== 140 | 141 | 142 | if (nargin < 2 || nargin > 5) 143 | mssim = -Inf; 144 | ssim_map = -Inf; 145 | return; 146 | end 147 | 148 | if (size(img1) ~= size(img2)) 149 | mssim = -Inf; 150 | ssim_map = -Inf; 151 | return; 152 | end 153 | 154 | [M, N] = size(img1); 155 | 156 | if (nargin == 2) 157 | if ((M < 11) || (N < 11)) 158 | mssim = -Inf; 159 | ssim_map = -Inf; 160 | return 161 | end 162 | window = fspecial('gaussian', 11, 1.5); % 163 | K(1) = 0.01; % default settings 164 | K(2) = 0.03; % 165 | L = 255; % 166 | end 167 | 168 | if (nargin == 3) 169 | if ((M < 11) || (N < 11)) 170 | mssim = -Inf; 171 | ssim_map = -Inf; 172 | return 173 | end 174 | window = fspecial('gaussian', 11, 1.5); 175 | L = 255; 176 | if (length(K) == 2) 177 | if (K(1) < 0 || K(2) < 0) 178 | mssim = -Inf; 179 | ssim_map = -Inf; 180 | return; 181 | end 182 | else 183 | mssim = -Inf; 184 | ssim_map = -Inf; 185 | return; 186 | end 187 | end 188 | 189 | if (nargin == 4) 190 | [H, W] = size(window); 191 | if ((H*W) < 4 || (H > M) || (W > N)) 192 | mssim = -Inf; 193 | ssim_map = -Inf; 194 | return 195 | end 196 | L = 255; 197 | if (length(K) == 2) 198 | if (K(1) < 0 || K(2) < 0) 199 | mssim = -Inf; 200 | ssim_map = -Inf; 201 | return; 202 | end 203 | else 204 | mssim = -Inf; 205 | ssim_map = -Inf; 206 | return; 207 | end 208 | end 209 | 210 | if (nargin == 5) 211 | [H, W] = size(window); 212 | if ((H*W) < 4 || (H > M) || (W > N)) 213 | mssim = -Inf; 214 | ssim_map = -Inf; 215 | return 216 | end 217 | if (length(K) == 2) 218 | if (K(1) < 0 || K(2) < 0) 219 | mssim = -Inf; 220 | ssim_map = -Inf; 221 | return; 222 | end 223 | else 224 | mssim = -Inf; 225 |
ssim_map = -Inf; 226 | return; 227 | end 228 | end 229 | 230 | C1 = (K(1)*L)^2; 231 | C2 = (K(2)*L)^2; 232 | window = window/sum(sum(window)); 233 | img1 = double(img1); 234 | img2 = double(img2); 235 | 236 | mu1 = filter2(window, img1, 'valid'); 237 | mu2 = filter2(window, img2, 'valid'); 238 | mu1_sq = mu1.*mu1; 239 | mu2_sq = mu2.*mu2; 240 | mu1_mu2 = mu1.*mu2; 241 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 242 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 243 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 244 | 245 | if (C1 > 0 && C2 > 0) 246 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 247 | else 248 | numerator1 = 2*mu1_mu2 + C1; 249 | numerator2 = 2*sigma12 + C2; 250 | denominator1 = mu1_sq + mu2_sq + C1; 251 | denominator2 = sigma1_sq + sigma2_sq + C2; 252 | ssim_map = ones(size(mu1)); 253 | index = (denominator1.*denominator2 > 0); 254 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 255 | index = (denominator1 ~= 0) & (denominator2 == 0); 256 | ssim_map(index) = numerator1(index)./denominator1(index); 257 | end 258 | 259 | mssim = mean2(ssim_map); 260 | 261 | end 262 | -------------------------------------------------------------------------------- /codes/metrics/calculate_PSNR_SSIM.py: -------------------------------------------------------------------------------- 1 | ''' 2 | calculate the PSNR and SSIM. 3 | same as MATLAB's results 4 | ''' 5 | import os 6 | import math 7 | import numpy as np 8 | import cv2 9 | import glob 10 | 11 | 12 | def main(): 13 | # Configurations 14 | 15 | # GT - Ground-truth; 16 | # Gen: Generated / Restored / Recovered images 17 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5' 18 | folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5' 19 | 20 | crop_border = 4 21 | suffix = '' # suffix for Gen images 22 | test_Y = False # True: test Y channel only; False: test RGB channels 23 | 24 | PSNR_all = [] 25 | SSIM_all = [] 26 | img_list = sorted(glob.glob(folder_GT + '/*')) 27 | 28 | if test_Y: 29 | print('Testing Y channel.') 30 | else: 31 | print('Testing RGB channels.') 32 | 33 | for i, img_path in enumerate(img_list): 34 | base_name = os.path.splitext(os.path.basename(img_path))[0] 35 | im_GT = cv2.imread(img_path) / 255. 36 | im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255. 37 | 38 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space 39 | im_GT_in = bgr2ycbcr(im_GT) 40 | im_Gen_in = bgr2ycbcr(im_Gen) 41 | else: 42 | im_GT_in = im_GT 43 | im_Gen_in = im_Gen 44 | 45 | # crop borders 46 | if im_GT_in.ndim == 3: 47 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] 48 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] 49 | elif im_GT_in.ndim == 2: 50 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] 51 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] 52 | else: 53 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim)) 54 | 55 | # calculate PSNR and SSIM 56 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255) 57 | 58 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255) 59 | print('{:3d} - {:25}. 
\tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format( 60 | i + 1, base_name, PSNR, SSIM)) 61 | PSNR_all.append(PSNR) 62 | SSIM_all.append(SSIM) 63 | print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format( 64 | sum(PSNR_all) / len(PSNR_all), 65 | sum(SSIM_all) / len(SSIM_all))) 66 | 67 | 68 | def calculate_psnr(img1, img2): 69 | # img1 and img2 have range [0, 255] 70 | img1 = img1.astype(np.float64) 71 | img2 = img2.astype(np.float64) 72 | mse = np.mean((img1 - img2)**2) 73 | if mse == 0: 74 | return float('inf') 75 | return 20 * math.log10(255.0 / math.sqrt(mse)) 76 | 77 | 78 | def ssim(img1, img2): 79 | C1 = (0.01 * 255)**2 80 | C2 = (0.03 * 255)**2 81 | 82 | img1 = img1.astype(np.float64) 83 | img2 = img2.astype(np.float64) 84 | kernel = cv2.getGaussianKernel(11, 1.5) 85 | window = np.outer(kernel, kernel.transpose()) 86 | 87 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid 88 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 89 | mu1_sq = mu1**2 90 | mu2_sq = mu2**2 91 | mu1_mu2 = mu1 * mu2 92 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 93 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 94 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 95 | 96 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 97 | (sigma1_sq + sigma2_sq + C2)) 98 | return ssim_map.mean() 99 | 100 | 101 | def calculate_ssim(img1, img2): 102 | '''calculate SSIM 103 | the same outputs as MATLAB's 104 | img1, img2: [0, 255] 105 | ''' 106 | if not img1.shape == img2.shape: 107 | raise ValueError('Input images must have the same dimensions.') 108 | if img1.ndim == 2: 109 | return ssim(img1, img2) 110 | elif img1.ndim == 3: 111 | if img1.shape[2] == 3: 112 | ssims = [] 113 | for i in range(3): 114 | ssims.append(ssim(img1[:, :, i], img2[:, :, i]))  # SSIM per channel, then averaged 115 | return np.array(ssims).mean() 116 | elif img1.shape[2] == 1: 117 | return ssim(np.squeeze(img1), np.squeeze(img2)) 118 | else: 119 | raise ValueError('Wrong input image dimensions.') 120 | 121 | 122 | def bgr2ycbcr(img, only_y=True): 123 | '''same as matlab rgb2ycbcr 124 | only_y: only return Y channel 125 | Input: 126 | uint8, [0, 255] 127 | float, [0, 1] 128 | ''' 129 | in_img_type = img.dtype 130 | img = img.astype(np.float32) 131 | if in_img_type != np.uint8: 132 | img *= 255. 133 | # convert 134 | if only_y: 135 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 136 | else: 137 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 138 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 139 | if in_img_type == np.uint8: 140 | rlt = rlt.round() 141 | else: 142 | rlt /= 255.
143 | return rlt.astype(in_img_type) 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 | -------------------------------------------------------------------------------- /codes/models/SRGAN_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DataParallel, DistributedDataParallel 6 | import models.networks as networks 7 | import models.lr_scheduler as lr_scheduler 8 | from .base_model import BaseModel 9 | from models.loss import GANLoss 10 | 11 | logger = logging.getLogger('base') 12 | 13 | 14 | class SRGANModel(BaseModel): 15 | def __init__(self, opt): 16 | super(SRGANModel, self).__init__(opt) 17 | if opt['dist']: 18 | self.rank = torch.distributed.get_rank() 19 | else: 20 | self.rank = -1 # non dist training 21 | train_opt = opt['train'] 22 | 23 | # define networks and load pretrained models 24 | self.netG = networks.define_G(opt).to(self.device) 25 | if opt['dist']: 26 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 27 | else: 28 | self.netG = DataParallel(self.netG) 29 | if self.is_train: 30 | self.netD = networks.define_D(opt).to(self.device) 31 | if opt['dist']: 32 | self.netD = DistributedDataParallel(self.netD, 33 | device_ids=[torch.cuda.current_device()]) 34 | else: 35 | self.netD = DataParallel(self.netD) 36 | 37 | self.netG.train() 38 | self.netD.train() 39 | 40 | # define losses, optimizer and scheduler 41 | if self.is_train: 42 | # G pixel loss 43 | if train_opt['pixel_weight'] > 0: 44 | l_pix_type = train_opt['pixel_criterion'] 45 | if l_pix_type == 'l1': 46 | self.cri_pix = nn.L1Loss().to(self.device) 47 | elif l_pix_type == 'l2': 48 | self.cri_pix = nn.MSELoss().to(self.device) 49 | else: 50 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) 51 | self.l_pix_w = train_opt['pixel_weight'] 52 | else: 53 | logger.info('Remove pixel loss.') 54 | self.cri_pix = None 55 | 56 | # G feature loss 57 | if train_opt['feature_weight'] > 0: 58 | l_fea_type = train_opt['feature_criterion'] 59 | if l_fea_type == 'l1': 60 | self.cri_fea = nn.L1Loss().to(self.device) 61 | elif l_fea_type == 'l2': 62 | self.cri_fea = nn.MSELoss().to(self.device) 63 | else: 64 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) 65 | self.l_fea_w = train_opt['feature_weight'] 66 | else: 67 | logger.info('Remove feature loss.') 68 | self.cri_fea = None 69 | if self.cri_fea: # load VGG perceptual loss 70 | self.netF = networks.define_F(opt, use_bn=False).to(self.device) 71 | if opt['dist']: 72 | self.netF = DistributedDataParallel(self.netF, 73 | device_ids=[torch.cuda.current_device()]) 74 | else: 75 | self.netF = DataParallel(self.netF) 76 | 77 | # GD gan loss 78 | self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) 79 | self.l_gan_w = train_opt['gan_weight'] 80 | # D_update_ratio and D_init_iters 81 | self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 82 | self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 83 | 84 | # optimizers 85 | # G 86 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 87 | optim_params = [] 88 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 89 | if v.requires_grad: 90 | optim_params.append(v) 91 | else: 92 | if self.rank <= 0: 93 | logger.warning('Params 
[{:s}] will not optimize.'.format(k)) 94 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 95 | weight_decay=wd_G, 96 | betas=(train_opt['beta1_G'], train_opt['beta2_G'])) 97 | self.optimizers.append(self.optimizer_G) 98 | # D 99 | wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 100 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], 101 | weight_decay=wd_D, 102 | betas=(train_opt['beta1_D'], train_opt['beta2_D'])) 103 | self.optimizers.append(self.optimizer_D) 104 | 105 | # schedulers 106 | if train_opt['lr_scheme'] == 'MultiStepLR': 107 | for optimizer in self.optimizers: 108 | self.schedulers.append( 109 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 110 | restarts=train_opt['restarts'], 111 | weights=train_opt['restart_weights'], 112 | gamma=train_opt['lr_gamma'], 113 | clear_state=train_opt['clear_state'])) 114 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 115 | for optimizer in self.optimizers: 116 | self.schedulers.append( 117 | lr_scheduler.CosineAnnealingLR_Restart( 118 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], 119 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 120 | else: 121 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 122 | 123 | self.log_dict = OrderedDict() 124 | 125 | self.print_network() # print network 126 | self.load() # load G and D if needed 127 | 128 | def feed_data(self, data, need_GT=True): 129 | self.var_L = data['LQ'].to(self.device) # LQ 130 | if need_GT: 131 | self.var_H = data['GT'].to(self.device) # GT 132 | input_ref = data['ref'] if 'ref' in data else data['GT'] 133 | self.var_ref = input_ref.to(self.device) 134 | 135 | def optimize_parameters(self, step): 136 | # G 137 | for p in self.netD.parameters(): 138 | p.requires_grad = False 139 | 140 | self.optimizer_G.zero_grad() 141 | self.fake_H = self.netG(self.var_L) 142 | 143 | l_g_total = 0 144 | if step % self.D_update_ratio == 0 and step > self.D_init_iters: 145 | if self.cri_pix: # pixel loss 146 | l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) 147 | l_g_total += l_g_pix 148 | if self.cri_fea: # feature loss 149 | real_fea = self.netF(self.var_H).detach() 150 | fake_fea = self.netF(self.fake_H) 151 | l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) 152 | l_g_total += l_g_fea 153 | 154 | pred_g_fake = self.netD(self.fake_H) 155 | if self.opt['train']['gan_type'] == 'gan': 156 | l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) 157 | elif self.opt['train']['gan_type'] == 'ragan': 158 | pred_d_real = self.netD(self.var_ref).detach() 159 | l_g_gan = self.l_gan_w * ( 160 | self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) + 161 | self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2 162 | l_g_total += l_g_gan 163 | 164 | l_g_total.backward() 165 | self.optimizer_G.step() 166 | 167 | # D 168 | for p in self.netD.parameters(): 169 | p.requires_grad = True 170 | 171 | self.optimizer_D.zero_grad() 172 | l_d_total = 0 173 | pred_d_real = self.netD(self.var_ref) 174 | pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G 175 | if self.opt['train']['gan_type'] == 'gan': 176 | l_d_real = self.cri_gan(pred_d_real, True) 177 | l_d_fake = self.cri_gan(pred_d_fake, False) 178 | l_d_total = l_d_real + l_d_fake 179 | elif self.opt['train']['gan_type'] == 'ragan': 180 | l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True) 181 | l_d_fake = 
self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False) 182 | l_d_total = (l_d_real + l_d_fake) / 2 183 | 184 | l_d_total.backward() 185 | self.optimizer_D.step() 186 | 187 | # set log 188 | if step % self.D_update_ratio == 0 and step > self.D_init_iters: 189 | if self.cri_pix: 190 | self.log_dict['l_g_pix'] = l_g_pix.item() 191 | if self.cri_fea: 192 | self.log_dict['l_g_fea'] = l_g_fea.item() 193 | self.log_dict['l_g_gan'] = l_g_gan.item() 194 | 195 | self.log_dict['l_d_real'] = l_d_real.item() 196 | self.log_dict['l_d_fake'] = l_d_fake.item() 197 | self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) 198 | self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) 199 | 200 | def test(self): 201 | self.netG.eval() 202 | with torch.no_grad(): 203 | self.fake_H = self.netG(self.var_L) 204 | self.netG.train() 205 | 206 | def get_current_log(self): 207 | return self.log_dict 208 | 209 | def get_current_visuals(self, need_GT=True): 210 | out_dict = OrderedDict() 211 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 212 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() 213 | if need_GT: 214 | out_dict['GT'] = self.var_H.detach()[0].float().cpu() 215 | return out_dict 216 | 217 | def print_network(self): 218 | # Generator 219 | s, n = self.get_network_description(self.netG) 220 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 221 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 222 | self.netG.module.__class__.__name__) 223 | else: 224 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 225 | if self.rank <= 0: 226 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 227 | logger.info(s) 228 | if self.is_train: 229 | # Discriminator 230 | s, n = self.get_network_description(self.netD) 231 | if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD, 232 | DistributedDataParallel): 233 | net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, 234 | self.netD.module.__class__.__name__) 235 | else: 236 | net_struc_str = '{}'.format(self.netD.__class__.__name__) 237 | if self.rank <= 0: 238 | logger.info('Network D structure: {}, with parameters: {:,d}'.format( 239 | net_struc_str, n)) 240 | logger.info(s) 241 | 242 | if self.cri_fea: # F, Perceptual Network 243 | s, n = self.get_network_description(self.netF) 244 | if isinstance(self.netF, nn.DataParallel) or isinstance( 245 | self.netF, DistributedDataParallel): 246 | net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, 247 | self.netF.module.__class__.__name__) 248 | else: 249 | net_struc_str = '{}'.format(self.netF.__class__.__name__) 250 | if self.rank <= 0: 251 | logger.info('Network F structure: {}, with parameters: {:,d}'.format( 252 | net_struc_str, n)) 253 | logger.info(s) 254 | 255 | def load(self): 256 | load_path_G = self.opt['path']['pretrain_model_G'] 257 | if load_path_G is not None: 258 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 259 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 260 | load_path_D = self.opt['path']['pretrain_model_D'] 261 | if self.opt['is_train'] and load_path_D is not None: 262 | logger.info('Loading model for D [{:s}] ...'.format(load_path_D)) 263 | self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) 264 | 265 | def save(self, iter_step): 266 | self.save_network(self.netG, 'G', iter_step) 267 | self.save_network(self.netD, 'D', iter_step) 268 | 
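In `optimize_parameters` above, the `'ragan'` branch implements the relativistic average GAN objective: each batch is scored against the mean logit of the opposite batch. A standalone sketch of just the discriminator-side arithmetic (hypothetical logits; assumes the criterion reduces to BCE-with-logits against the 1.0/0.0 targets passed to `GANLoss` above):

```python
import torch
import torch.nn.functional as F

# Hypothetical discriminator logits for one batch of real and fake images.
pred_d_real = torch.randn(16, 1)
pred_d_fake = torch.randn(16, 1)

# Relativistic average discriminator loss, mirroring the 'ragan' branch:
# real samples should out-score the average fake, and fakes should
# under-score the average real.
l_d_real = F.binary_cross_entropy_with_logits(
    pred_d_real - pred_d_fake.mean(), torch.ones_like(pred_d_real))
l_d_fake = F.binary_cross_entropy_with_logits(
    pred_d_fake - pred_d_real.mean(), torch.zeros_like(pred_d_fake))
l_d_total = (l_d_real + l_d_fake) / 2
```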
-------------------------------------------------------------------------------- /codes/models/SR_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parallel import DataParallel, DistributedDataParallel 7 | import models.networks as networks 8 | import models.lr_scheduler as lr_scheduler 9 | from .base_model import BaseModel 10 | from models.loss import CharbonnierLoss 11 | 12 | logger = logging.getLogger('base') 13 | 14 | 15 | class SRModel(BaseModel): 16 | def __init__(self, opt): 17 | super(SRModel, self).__init__(opt) 18 | 19 | if opt['dist']: 20 | self.rank = torch.distributed.get_rank() 21 | else: 22 | self.rank = -1 # non dist training 23 | train_opt = opt['train'] 24 | 25 | # define network and load pretrained models 26 | self.netG = networks.define_G(opt).to(self.device) 27 | if opt['dist']: 28 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 29 | else: 30 | self.netG = DataParallel(self.netG) 31 | # print network 32 | self.print_network() 33 | self.load() 34 | 35 | if self.is_train: 36 | self.netG.train() 37 | 38 | # loss 39 | loss_type = train_opt['pixel_criterion'] 40 | if loss_type == 'l1': 41 | self.cri_pix = nn.L1Loss().to(self.device) 42 | elif loss_type == 'l2': 43 | self.cri_pix = nn.MSELoss().to(self.device) 44 | elif loss_type == 'cb': 45 | self.cri_pix = CharbonnierLoss().to(self.device) 46 | else: 47 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) 48 | self.l_pix_w = train_opt['pixel_weight'] 49 | 50 | # optimizers 51 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 52 | optim_params = [] 53 | if train_opt['finetune_adafm']: 54 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 55 | v.requires_grad = False 56 | if k.find('adafm') >= 0: 57 | v.requires_grad = True 58 | optim_params.append(v) 59 | logger.info('Params [{:s}] will optimize.'.format(k)) 60 | else: 61 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 62 | if v.requires_grad: 63 | optim_params.append(v) 64 | else: 65 | if self.rank <= 0: 66 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 67 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 68 | weight_decay=wd_G, 69 | betas=(train_opt['beta1'], train_opt['beta2'])) 70 | self.optimizers.append(self.optimizer_G) 71 | 72 | # schedulers 73 | if train_opt['lr_scheme'] == 'MultiStepLR': 74 | for optimizer in self.optimizers: 75 | self.schedulers.append( 76 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 77 | restarts=train_opt['restarts'], 78 | weights=train_opt['restart_weights'], 79 | gamma=train_opt['lr_gamma'], 80 | clear_state=train_opt['clear_state'])) 81 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 82 | for optimizer in self.optimizers: 83 | self.schedulers.append( 84 | lr_scheduler.CosineAnnealingLR_Restart( 85 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], 86 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 87 | else: 88 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 89 | 90 | self.log_dict = OrderedDict() 91 | 92 | def feed_data(self, data, need_GT=True, need_cond=False): 93 | self.var_L = data['LQ'].to(self.device) # LQ 94 | if need_GT: 95 | self.real_H = 
data['GT'].to(self.device) # GT 96 | if need_cond: 97 | self.cond = data['cond'].to(self.device) # cond 98 | self.input = [self.var_L, self.cond] 99 | else: 100 | self.input = self.var_L 101 | 102 | def optimize_parameters(self, step): 103 | self.optimizer_G.zero_grad() 104 | self.fake_H = self.netG(self.input) 105 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) 106 | l_pix.backward() 107 | self.optimizer_G.step() 108 | 109 | # set log 110 | self.log_dict['l_pix'] = l_pix.item() 111 | 112 | def test(self): 113 | self.netG.eval() 114 | with torch.no_grad(): 115 | self.fake_H = self.netG(self.input) 116 | self.netG.train() 117 | 118 | def test_x8(self): 119 | # from https://github.com/thstkdgus35/EDSR-PyTorch 120 | self.netG.eval() 121 | 122 | def _transform(v, op): 123 | # if self.precision != 'single': v = v.float() 124 | v2np = v.data.cpu().numpy() 125 | if op == 'v': 126 | tfnp = v2np[:, :, :, ::-1].copy() 127 | elif op == 'h': 128 | tfnp = v2np[:, :, ::-1, :].copy() 129 | elif op == 't': 130 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 131 | 132 | ret = torch.Tensor(tfnp).to(self.device) 133 | # if self.precision == 'half': ret = ret.half() 134 | 135 | return ret 136 | 137 | lr_list = [self.var_L] 138 | for tf in 'v', 'h', 't': 139 | lr_list.extend([_transform(t, tf) for t in lr_list]) 140 | with torch.no_grad(): 141 | sr_list = [self.netG(aug) for aug in lr_list] 142 | for i in range(len(sr_list)): 143 | if i > 3: 144 | sr_list[i] = _transform(sr_list[i], 't') 145 | if i % 4 > 1: 146 | sr_list[i] = _transform(sr_list[i], 'h') 147 | if (i % 4) % 2 == 1: 148 | sr_list[i] = _transform(sr_list[i], 'v') 149 | 150 | output_cat = torch.cat(sr_list, dim=0) 151 | self.fake_H = output_cat.mean(dim=0, keepdim=True) 152 | self.netG.train() 153 | 154 | def get_current_log(self): 155 | return self.log_dict 156 | 157 | def get_current_visuals(self, need_GT=True): 158 | out_dict = OrderedDict() 159 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 160 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() 161 | if need_GT: 162 | out_dict['GT'] = self.real_H.detach()[0].float().cpu() 163 | return out_dict 164 | 165 | def print_network(self): 166 | s, n = self.get_network_description(self.netG) 167 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 168 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 169 | self.netG.module.__class__.__name__) 170 | else: 171 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 172 | if self.rank <= 0: 173 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 174 | logger.info(s) 175 | 176 | def load(self): 177 | load_path_G = self.opt['path']['pretrain_model_G'] 178 | if load_path_G is not None: 179 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 180 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 181 | 182 | def update(self, new_model_dict): 183 | if isinstance(self.netG, nn.DataParallel): 184 | network = self.netG.module 185 | network.load_state_dict(new_model_dict) 186 | 187 | def save(self, iter_label): 188 | self.save_network(self.netG, 'G', iter_label) 189 | -------------------------------------------------------------------------------- /codes/models/Video_base_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parallel import 
DataParallel, DistributedDataParallel 7 | import models.networks as networks 8 | import models.lr_scheduler as lr_scheduler 9 | from .base_model import BaseModel 10 | from models.loss import CharbonnierLoss 11 | 12 | logger = logging.getLogger('base') 13 | 14 | 15 | class VideoBaseModel(BaseModel): 16 | def __init__(self, opt): 17 | super(VideoBaseModel, self).__init__(opt) 18 | 19 | if opt['dist']: 20 | self.rank = torch.distributed.get_rank() 21 | else: 22 | self.rank = -1 # non dist training 23 | train_opt = opt['train'] 24 | 25 | # define network and load pretrained models 26 | self.netG = networks.define_G(opt).to(self.device) 27 | if opt['dist']: 28 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 29 | else: 30 | self.netG = DataParallel(self.netG) 31 | # print network 32 | self.print_network() 33 | self.load() 34 | 35 | if self.is_train: 36 | self.netG.train() 37 | 38 | #### loss 39 | loss_type = train_opt['pixel_criterion'] 40 | if loss_type == 'l1': 41 | self.cri_pix = nn.L1Loss(reduction='sum').to(self.device) 42 | elif loss_type == 'l2': 43 | self.cri_pix = nn.MSELoss(reduction='sum').to(self.device) 44 | elif loss_type == 'cb': 45 | self.cri_pix = CharbonnierLoss().to(self.device) 46 | else: 47 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) 48 | self.l_pix_w = train_opt['pixel_weight'] 49 | 50 | #### optimizers 51 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 52 | if train_opt['ft_tsa_only']: 53 | normal_params = [] 54 | tsa_fusion_params = [] 55 | for k, v in self.netG.named_parameters(): 56 | if v.requires_grad: 57 | if 'tsa_fusion' in k: 58 | tsa_fusion_params.append(v) 59 | else: 60 | normal_params.append(v) 61 | else: 62 | if self.rank <= 0: 63 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 64 | optim_params = [ 65 | { # add normal params first 66 | 'params': normal_params, 67 | 'lr': train_opt['lr_G'] 68 | }, 69 | { 70 | 'params': tsa_fusion_params, 71 | 'lr': train_opt['lr_G'] 72 | }, 73 | ] 74 | else: 75 | optim_params = [] 76 | for k, v in self.netG.named_parameters(): 77 | if v.requires_grad: 78 | optim_params.append(v) 79 | else: 80 | if self.rank <= 0: 81 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 82 | 83 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 84 | weight_decay=wd_G, 85 | betas=(train_opt['beta1'], train_opt['beta2'])) 86 | self.optimizers.append(self.optimizer_G) 87 | 88 | #### schedulers 89 | if train_opt['lr_scheme'] == 'MultiStepLR': 90 | for optimizer in self.optimizers: 91 | self.schedulers.append( 92 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 93 | restarts=train_opt['restarts'], 94 | weights=train_opt['restart_weights'], 95 | gamma=train_opt['lr_gamma'], 96 | clear_state=train_opt['clear_state'])) 97 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 98 | for optimizer in self.optimizers: 99 | self.schedulers.append( 100 | lr_scheduler.CosineAnnealingLR_Restart( 101 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], 102 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 103 | else: 104 | raise NotImplementedError() 105 | 106 | self.log_dict = OrderedDict() 107 | 108 | def feed_data(self, data, need_GT=True): 109 | self.var_L = data['LQs'].to(self.device) 110 | if need_GT: 111 | self.real_H = data['GT'].to(self.device) 112 | 113 | def set_params_lr_zero(self): 114 | # fix normal module 115 | 
self.optimizers[0].param_groups[0]['lr'] = 0 116 | 117 | def optimize_parameters(self, step): 118 | if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']: 119 | self.set_params_lr_zero() 120 | 121 | self.optimizer_G.zero_grad() 122 | self.fake_H = self.netG(self.var_L) 123 | 124 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) 125 | l_pix.backward() 126 | self.optimizer_G.step() 127 | 128 | # set log 129 | self.log_dict['l_pix'] = l_pix.item() 130 | 131 | def test(self): 132 | self.netG.eval() 133 | with torch.no_grad(): 134 | self.fake_H = self.netG(self.var_L) 135 | self.netG.train() 136 | 137 | def get_current_log(self): 138 | return self.log_dict 139 | 140 | def get_current_visuals(self, need_GT=True): 141 | out_dict = OrderedDict() 142 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 143 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu() 144 | if need_GT: 145 | out_dict['GT'] = self.real_H.detach()[0].float().cpu() 146 | return out_dict 147 | 148 | def print_network(self): 149 | s, n = self.get_network_description(self.netG) 150 | if isinstance(self.netG, nn.DataParallel): 151 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 152 | self.netG.module.__class__.__name__) 153 | else: 154 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 155 | if self.rank <= 0: 156 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 157 | logger.info(s) 158 | 159 | def load(self): 160 | load_path_G = self.opt['path']['pretrain_model_G'] 161 | if load_path_G is not None: 162 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 163 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 164 | 165 | def save(self, iter_label): 166 | self.save_network(self.netG, 'G', iter_label) 167 | -------------------------------------------------------------------------------- /codes/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | model = opt['model'] 7 | # image restoration 8 | if model == 'sr': # PSNR-oriented super resolution 9 | from .SR_model import SRModel as M 10 | elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN 11 | from .SRGAN_model import SRGANModel as M 12 | # video restoration 13 | elif model == 'video_base': 14 | from .Video_base_model import VideoBaseModel as M 15 | else: 16 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) 17 | m = M(opt) 18 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 19 | return m 20 | -------------------------------------------------------------------------------- /codes/models/archs/CSRNet_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Condition(nn.Module): 9 | def __init__(self, in_nc=3, nf=32): 10 | super(Condition, self).__init__() 11 | stride = 2 12 | pad = 0 13 | self.pad = nn.ZeroPad2d(1) 14 | self.conv1 = nn.Conv2d(in_nc, nf, 7, stride, pad, bias=True) 15 | self.conv2 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) 16 | self.conv3 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True) 17 | self.act = nn.ReLU(inplace=True) 18 | 19 | def forward(self, x): 20 | conv1_out = self.act(self.conv1(self.pad(x))) 21 | conv2_out = 
self.act(self.conv2(self.pad(conv1_out))) 22 | conv3_out = self.act(self.conv3(self.pad(conv2_out))) 23 | out = torch.mean(conv3_out, dim=[2, 3], keepdim=False) 24 | 25 | return out 26 | 27 | 28 | # 3 layers with control 29 | class CSRNet(nn.Module): 30 | def __init__(self, in_nc=3, out_nc=3, base_nf=64, cond_nf=32): 31 | super(CSRNet, self).__init__() 32 | 33 | self.base_nf = base_nf 34 | self.out_nc = out_nc 35 | 36 | self.cond_net = Condition(in_nc=in_nc, nf=cond_nf) 37 | 38 | self.cond_scale1 = nn.Linear(cond_nf, base_nf, bias=True) 39 | self.cond_scale2 = nn.Linear(cond_nf, base_nf, bias=True) 40 | self.cond_scale3 = nn.Linear(cond_nf, 3, bias=True) 41 | 42 | self.cond_shift1 = nn.Linear(cond_nf, base_nf, bias=True) 43 | self.cond_shift2 = nn.Linear(cond_nf, base_nf, bias=True) 44 | self.cond_shift3 = nn.Linear(cond_nf, 3, bias=True) 45 | 46 | self.conv1 = nn.Conv2d(in_nc, base_nf, 1, 1, bias=True) 47 | self.conv2 = nn.Conv2d(base_nf, base_nf, 1, 1, bias=True) 48 | self.conv3 = nn.Conv2d(base_nf, out_nc, 1, 1, bias=True) 49 | 50 | self.act = nn.ReLU(inplace=True) 51 | 52 | def forward(self, x): 53 | cond = self.cond_net(x) 54 | 55 | scale1 = self.cond_scale1(cond) 56 | shift1 = self.cond_shift1(cond) 57 | 58 | scale2 = self.cond_scale2(cond) 59 | shift2 = self.cond_shift2(cond) 60 | 61 | scale3 = self.cond_scale3(cond) 62 | shift3 = self.cond_shift3(cond) 63 | 64 | out = self.conv1(x) 65 | out = out * scale1.view(-1, self.base_nf, 1, 1) + shift1.view(-1, self.base_nf, 1, 1) + out 66 | out = self.act(out) 67 | 68 | out = self.conv2(out) 69 | out = out * scale2.view(-1, self.base_nf, 1, 1) + shift2.view(-1, self.base_nf, 1, 1) + out 70 | out = self.act(out) 71 | 72 | out = self.conv3(out) 73 | out = out * scale3.view(-1, self.out_nc, 1, 1) + shift3.view(-1, self.out_nc, 1, 1) + out 74 | return out
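`CSRNet.forward` above is the paper's global feature modulation (GFM) with a residual connection: each 1x1 conv output is scaled and shifted channel-wise by vectors predicted from the global condition, then added back to itself. A minimal sketch of one modulation step in isolation (hypothetical tensor shapes):

```python
import torch
import torch.nn as nn

base_nf, cond_nf = 64, 32
feat = torch.randn(2, base_nf, 32, 32)  # output of one 1x1 conv layer
cond = torch.randn(2, cond_nf)          # global vector from the Condition net

cond_scale = nn.Linear(cond_nf, base_nf)
cond_shift = nn.Linear(cond_nf, base_nf)
scale = cond_scale(cond).view(-1, base_nf, 1, 1)
shift = cond_shift(cond).view(-1, base_nf, 1, 1)

# One GFM step with the residual term used in CSRNet.forward:
out = feat * scale + shift + feat
print(out.shape)  # torch.Size([2, 64, 32, 32])
```

Because the backbone uses only 1x1 convolutions and a single global condition vector, every pixel is retouched by the same sequence of color transforms, which is what keeps the parameter count so small.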
-------------------------------------------------------------------------------- /codes/models/archs/EDVR_arch.py: -------------------------------------------------------------------------------- 1 | ''' network architecture for EDVR ''' 2 | import functools 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import models.archs.arch_util as arch_util 7 | try: 8 | from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN 9 | except ImportError: 10 | raise ImportError('Failed to import DCNv2 module.') 11 | 12 | 13 | class Predeblur_ResNet_Pyramid(nn.Module): 14 | def __init__(self, nf=128, HR_in=False): 15 | ''' 16 | HR_in: True if the inputs are at the high (HR) spatial resolution 17 | ''' 18 | 19 | super(Predeblur_ResNet_Pyramid, self).__init__() 20 | self.HR_in = True if HR_in else False 21 | if self.HR_in: 22 | self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True) 23 | self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) 24 | self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) 25 | else: 26 | self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True) 27 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) 28 | self.RB_L1_1 = basic_block() 29 | self.RB_L1_2 = basic_block() 30 | self.RB_L1_3 = basic_block() 31 | self.RB_L1_4 = basic_block() 32 | self.RB_L1_5 = basic_block() 33 | self.RB_L2_1 = basic_block() 34 | self.RB_L2_2 = basic_block() 35 | self.RB_L3_1 = basic_block() 36 | self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) 37 | self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) 38 | 39 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 40 | 41 | def forward(self, x): 42 | if self.HR_in: 43 | L1_fea = self.lrelu(self.conv_first_1(x)) 44 | L1_fea = self.lrelu(self.conv_first_2(L1_fea)) 45 | L1_fea = self.lrelu(self.conv_first_3(L1_fea)) 46 | else: 47 | L1_fea = self.lrelu(self.conv_first(x)) 48 | L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea)) 49 | L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea)) 50 | L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear', 51 | align_corners=False) 52 | L2_fea = self.RB_L2_1(L2_fea) + L3_fea 53 | L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear', 54 | align_corners=False) 55 | L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea 56 | out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea))) 57 | return out 58 | 59 | 60 | class PCD_Align(nn.Module): 61 | ''' Alignment module using Pyramid, Cascading and Deformable convolution 62 | with 3 pyramid levels. 63 | ''' 64 | 65 | def __init__(self, nf=64, groups=8): 66 | super(PCD_Align, self).__init__() 67 | # L3: level 3, 1/4 spatial size 68 | self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff 69 | self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 70 | self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, 71 | extra_offset_mask=True) 72 | # L2: level 2, 1/2 spatial size 73 | self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff 74 | self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset 75 | self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 76 | self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, 77 | extra_offset_mask=True) 78 | self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea 79 | # L1: level 1, original spatial size 80 | self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff 81 | self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset 82 | self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 83 | self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, 84 | extra_offset_mask=True) 85 | self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea 86 | # Cascading DCN 87 | self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff 88 | self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 89 | 90 | self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups, 91 | extra_offset_mask=True) 92 | 93 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 94 | 95 | def forward(self, nbr_fea_l, ref_fea_l): 96 | '''align other neighboring frames to the reference frame in the feature level 97 | nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features 98 | ''' 99 | # L3 100 | L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1) 101 | L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset)) 102 | L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset)) 103 | L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset])) 104 | # L2 105 | L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1) 106 | L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset)) 107 | L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False) 108 | L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1))) 109 | L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset)) 110 | L2_fea = self.L2_dcnpack([nbr_fea_l[1], 
L2_offset]) 111 | L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False) 112 | L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1))) 113 | # L1 114 | L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1) 115 | L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset)) 116 | L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False) 117 | L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1))) 118 | L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset)) 119 | L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset]) 120 | L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False) 121 | L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1)) 122 | # Cascading 123 | offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1) 124 | offset = self.lrelu(self.cas_offset_conv1(offset)) 125 | offset = self.lrelu(self.cas_offset_conv2(offset)) 126 | L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset])) 127 | 128 | return L1_fea 129 | 130 | 131 | class TSA_Fusion(nn.Module): 132 | ''' Temporal Spatial Attention fusion module 133 | Temporal: correlation; 134 | Spatial: 3 pyramid levels. 135 | ''' 136 | 137 | def __init__(self, nf=64, nframes=5, center=2): 138 | super(TSA_Fusion, self).__init__() 139 | self.center = center 140 | # temporal attention (before fusion conv) 141 | self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 142 | self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 143 | 144 | # fusion conv: using 1x1 to save parameters and computation 145 | self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True) 146 | 147 | # spatial attention (after fusion conv) 148 | self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True) 149 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 150 | self.avgpool = nn.AvgPool2d(3, stride=2, padding=1) 151 | self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True) 152 | self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 153 | self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True) 154 | self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 155 | self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True) 156 | self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) 157 | self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 158 | self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True) 159 | self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True) 160 | 161 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 162 | 163 | def forward(self, aligned_fea): 164 | B, N, C, H, W = aligned_fea.size() # N video frames 165 | #### temporal attention 166 | emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone()) 167 | emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W) # [B, N, C(nf), H, W] 168 | 169 | cor_l = [] 170 | for i in range(N): 171 | emb_nbr = emb[:, i, :, :, :] 172 | cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1) # B, 1, H, W 173 | cor_l.append(cor_tmp) 174 | cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1)) # B, N, H, W 175 | cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W) 176 | aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob 177 | 178 | #### fusion 179 | fea = self.lrelu(self.fea_fusion(aligned_fea)) 180 | 181 | #### spatial attention 182 | att = self.lrelu(self.sAtt_1(aligned_fea)) 183 | att_max = self.maxpool(att) 184 | att_avg = self.avgpool(att) 185 | att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1))) 186 | # pyramid levels 187 | att_L = 
self.lrelu(self.sAtt_L1(att)) 188 | att_max = self.maxpool(att_L) 189 | att_avg = self.avgpool(att_L) 190 | att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1))) 191 | att_L = self.lrelu(self.sAtt_L3(att_L)) 192 | att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False) 193 | 194 | att = self.lrelu(self.sAtt_3(att)) 195 | att = att + att_L 196 | att = self.lrelu(self.sAtt_4(att)) 197 | att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False) 198 | att = self.sAtt_5(att) 199 | att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att))) 200 | att = torch.sigmoid(att) 201 | 202 | fea = fea * att * 2 + att_add 203 | return fea 204 | 205 | 206 | class EDVR(nn.Module): 207 | def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None, 208 | predeblur=False, HR_in=False, w_TSA=True): 209 | super(EDVR, self).__init__() 210 | self.nf = nf 211 | self.center = nframes // 2 if center is None else center 212 | self.is_predeblur = True if predeblur else False 213 | self.HR_in = True if HR_in else False 214 | self.w_TSA = w_TSA 215 | ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) 216 | 217 | #### extract features (for each frame) 218 | if self.is_predeblur: 219 | self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in) 220 | self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True) 221 | else: 222 | if self.HR_in: 223 | self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True) 224 | self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) 225 | self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) 226 | else: 227 | self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True) 228 | self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs) 229 | self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) 230 | self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 231 | self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True) 232 | self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 233 | 234 | self.pcd_align = PCD_Align(nf=nf, groups=groups) 235 | if self.w_TSA: 236 | self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center) 237 | else: 238 | self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True) 239 | 240 | #### reconstruction 241 | self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs) 242 | #### upsampling 243 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 244 | self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True) 245 | self.pixel_shuffle = nn.PixelShuffle(2) 246 | self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True) 247 | self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True) 248 | 249 | #### activation function 250 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 251 | 252 | def forward(self, x): 253 | B, N, C, H, W = x.size() # N video frames 254 | x_center = x[:, self.center, :, :, :].contiguous() 255 | 256 | #### extract LR features 257 | # L1 258 | if self.is_predeblur: 259 | L1_fea = self.pre_deblur(x.view(-1, C, H, W)) 260 | L1_fea = self.conv_1x1(L1_fea) 261 | if self.HR_in: 262 | H, W = H // 4, W // 4 263 | else: 264 | if self.HR_in: 265 | L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W))) 266 | L1_fea = self.lrelu(self.conv_first_2(L1_fea)) 267 | L1_fea = self.lrelu(self.conv_first_3(L1_fea)) 268 | H, W = H // 4, W // 4 269 | else: 270 | L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W))) 271 | L1_fea = self.feature_extraction(L1_fea) 272 | # L2 273 | L2_fea = 
self.lrelu(self.fea_L2_conv1(L1_fea)) 274 | L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea)) 275 | # L3 276 | L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea)) 277 | L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea)) 278 | 279 | L1_fea = L1_fea.view(B, N, -1, H, W) 280 | L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2) 281 | L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4) 282 | 283 | #### pcd align 284 | # ref feature list 285 | ref_fea_l = [ 286 | L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(), 287 | L3_fea[:, self.center, :, :, :].clone() 288 | ] 289 | aligned_fea = [] 290 | for i in range(N): 291 | nbr_fea_l = [ 292 | L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(), 293 | L3_fea[:, i, :, :, :].clone() 294 | ] 295 | aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l)) 296 | aligned_fea = torch.stack(aligned_fea, dim=1)  # [B, N, C, H, W] 297 | 298 | if not self.w_TSA: 299 | aligned_fea = aligned_fea.view(B, -1, H, W) 300 | fea = self.tsa_fusion(aligned_fea) 301 | 302 | out = self.recon_trunk(fea) 303 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 304 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 305 | out = self.lrelu(self.HRconv(out)) 306 | out = self.conv_last(out) 307 | if self.HR_in: 308 | base = x_center 309 | else: 310 | base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False) 311 | out += base 312 | return out 313 |
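EDVR's reconstruction tail above (like the RRDBNet and SRResNet decoders in the files that follow) upsamples by packing 4x the channels with a conv and trading them for 2x the spatial size via `nn.PixelShuffle(2)`, applied twice for a x4 model. A quick shape check of one stage (sketch with arbitrary feature sizes):

```python
import torch
import torch.nn as nn

# One upsampling stage: the conv packs 4x channels, PixelShuffle(2)
# unpacks them into a 2x larger feature map (as in upconv1/upconv2 above).
x = torch.randn(1, 64, 32, 32)
upconv = nn.Conv2d(64, 64 * 4, 3, 1, 1)
ps = nn.PixelShuffle(2)
y = ps(upconv(x))
print(y.shape)  # torch.Size([1, 64, 64, 64])
```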
intermediate channels 12 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 13 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 14 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 15 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 16 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 17 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 18 | 19 | # initialization 20 | arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 21 | 0.1) 22 | 23 | def forward(self, x): 24 | x1 = self.lrelu(self.conv1(x)) 25 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 26 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 27 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 28 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 29 | return x5 * 0.2 + x 30 | 31 | 32 | class RRDB(nn.Module): 33 | '''Residual in Residual Dense Block''' 34 | 35 | def __init__(self, nf, gc=32): 36 | super(RRDB, self).__init__() 37 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 38 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 39 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 40 | 41 | def forward(self, x): 42 | out = self.RDB1(x) 43 | out = self.RDB2(out) 44 | out = self.RDB3(out) 45 | return out * 0.2 + x 46 | 47 | 48 | class RRDBNet(nn.Module): 49 | def __init__(self, in_nc, out_nc, nf, nb, gc=32): 50 | super(RRDBNet, self).__init__() 51 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 52 | 53 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 54 | self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb) 55 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 56 | #### upsampling 57 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 58 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 59 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 60 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 61 | 62 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 63 | 64 | def forward(self, x): 65 | fea = self.conv_first(x) 66 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 67 | fea = fea + trunk 68 | 69 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 70 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 71 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 72 | 73 | return out 74 | -------------------------------------------------------------------------------- /codes/models/archs/SRResNet_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import models.archs.arch_util as arch_util 5 | 6 | 7 | class MSRResNet(nn.Module): 8 | ''' modified SRResNet''' 9 | 10 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): 11 | super(MSRResNet, self).__init__() 12 | self.upscale = upscale 13 | 14 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 15 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf) 16 | self.recon_trunk = arch_util.make_layer(basic_block, nb) 17 | 18 | # upsampling 19 | if self.upscale == 2: 20 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 21 | self.pixel_shuffle = nn.PixelShuffle(2) 22 | elif self.upscale == 3: 23 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) 24 | self.pixel_shuffle = nn.PixelShuffle(3) 25 | elif self.upscale == 4: 26 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 27 | self.upconv2 = 
nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 28 | self.pixel_shuffle = nn.PixelShuffle(2) 29 | 30 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 31 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 32 | 33 | # activation function 34 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 35 | 36 | # initialization 37 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], 38 | 0.1) 39 | if self.upscale == 4: 40 | arch_util.initialize_weights(self.upconv2, 0.1) 41 | 42 | def forward(self, x): 43 | fea = self.lrelu(self.conv_first(x)) 44 | out = self.recon_trunk(fea) 45 | 46 | if self.upscale == 4: 47 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 48 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 49 | elif self.upscale == 3 or self.upscale == 2: 50 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 51 | 52 | out = self.conv_last(self.lrelu(self.HRconv(out))) 53 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 54 | out += base 55 | return out 56 | -------------------------------------------------------------------------------- /codes/models/archs/TOF_arch.py: -------------------------------------------------------------------------------- 1 | '''PyTorch implementation of TOFlow 2 | Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018 3 | Code reference: 4 | 1. https://github.com/anchen1011/toflow 5 | 2. https://github.com/Coldog2333/pytoflow 6 | ''' 7 | 8 | import torch 9 | import torch.nn as nn 10 | from .arch_util import flow_warp 11 | 12 | 13 | def normalize(x): 14 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x) 15 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x) 16 | return (x - mean) / std 17 | 18 | 19 | def denormalize(x): 20 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x) 21 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x) 22 | return x * std + mean 23 | 24 | 25 | class SpyNet_Block(nn.Module): 26 | '''A submodule of SpyNet.''' 27 | 28 | def __init__(self): 29 | super(SpyNet_Block, self).__init__() 30 | 31 | self.block = nn.Sequential( 32 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), 33 | nn.BatchNorm2d(32), nn.ReLU(inplace=True), 34 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), 35 | nn.BatchNorm2d(64), nn.ReLU(inplace=True), 36 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), 37 | nn.BatchNorm2d(32), nn.ReLU(inplace=True), 38 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), 39 | nn.BatchNorm2d(16), nn.ReLU(inplace=True), 40 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)) 41 | 42 | def forward(self, x): 43 | ''' 44 | input: x: [ref im, nbr im, initial flow] - (B, 8, H, W) 45 | output: estimated flow - (B, 2, H, W) 46 | ''' 47 | return self.block(x) 48 | 49 | 50 | class SpyNet(nn.Module): 51 | '''SpyNet for estimating optical flow 52 | Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016''' 53 | 54 | def __init__(self): 55 | super(SpyNet, self).__init__() 56 | 57 | self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)]) 58 | 59 | def forward(self, ref, nbr): 60 | '''Estimating optical flow in coarse level, upsample, and estimate in fine level 61 | input: ref: reference image - [B, 3, H, W] 62 | nbr: the neighboring image to be warped - [B, 3, H, W] 
63 | output: estimated optical flow - [B, 2, H, W] 64 | ''' 65 | B, C, H, W = ref.size() 66 | ref = [ref] 67 | nbr = [nbr] 68 | 69 | for _ in range(3): 70 | ref.insert( 71 | 0, 72 | nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2, 73 | count_include_pad=False)) 74 | nbr.insert( 75 | 0, 76 | nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2, 77 | count_include_pad=False)) 78 | 79 | flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0]) 80 | 81 | for i in range(4): 82 | flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear', 83 | align_corners=True) * 2.0 84 | flow = flow_up + self.blocks[i](torch.cat( 85 | [ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)) 86 | return flow 87 | 88 | 89 | class TOFlow(nn.Module): 90 | def __init__(self, adapt_official=False): 91 | super(TOFlow, self).__init__() 92 | 93 | self.SpyNet = SpyNet() 94 | 95 | self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4) 96 | self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4) 97 | self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1) 98 | self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1) 99 | 100 | self.relu = nn.ReLU(inplace=True) 101 | 102 | self.adapt_official = adapt_official # True if using translated official weights else False 103 | 104 | def forward(self, x): 105 | """ 106 | input: x: input frames - [B, 7, 3, H, W] 107 | output: SR reference frame - [B, 3, H, W] 108 | """ 109 | 110 | B, T, C, H, W = x.size() 111 | x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W) 112 | 113 | ref_idx = 3 114 | x_ref = x[:, ref_idx, :, :, :] 115 | 116 | # In the official torch code, the 0-th frame is the reference frame 117 | if self.adapt_official: 118 | x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :] 119 | ref_idx = 0 120 | 121 | x_warped = [] 122 | for i in range(7): 123 | if i == ref_idx: 124 | x_warped.append(x_ref) 125 | else: 126 | x_nbr = x[:, i, :, :, :] 127 | flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1) 128 | x_warped.append(flow_warp(x_nbr, flow)) 129 | x_warped = torch.stack(x_warped, dim=1) 130 | 131 | x = x_warped.view(B, -1, H, W) 132 | x = self.relu(self.conv_3x7_64_9x9(x)) 133 | x = self.relu(self.conv_64_64_9x9(x)) 134 | x = self.relu(self.conv_64_64_1x1(x)) 135 | x = self.conv_64_3_1x1(x) + x_ref 136 | 137 | return denormalize(x) 138 | -------------------------------------------------------------------------------- /codes/models/archs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/codes/models/archs/__init__.py -------------------------------------------------------------------------------- /codes/models/archs/arch_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | 7 | def initialize_weights(net_l, scale=1): 8 | if not isinstance(net_l, list): 9 | net_l = [net_l] 10 | for net in net_l: 11 | for m in net.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 14 | m.weight.data *= scale # for residual block 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | m.weight.data *= scale 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif isinstance(m, nn.BatchNorm2d): 23 | init.constant_(m.weight, 1) 24 | 
init.constant_(m.bias.data, 0.0) 25 | 26 | 27 | def make_layer(block, n_layers): 28 | layers = [] 29 | for _ in range(n_layers): 30 | layers.append(block()) 31 | return nn.Sequential(*layers) 32 | 33 | 34 | class ResidualBlock_noBN(nn.Module): 35 | '''Residual block w/o BN 36 | ---Conv-ReLU-Conv-+- 37 | |________________| 38 | ''' 39 | 40 | def __init__(self, nf=64): 41 | super(ResidualBlock_noBN, self).__init__() 42 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 43 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 44 | 45 | # initialization 46 | initialize_weights([self.conv1, self.conv2], 0.1) 47 | 48 | def forward(self, x): 49 | identity = x 50 | out = F.relu(self.conv1(x), inplace=True) 51 | out = self.conv2(out) 52 | return identity + out 53 | 54 | 55 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 56 | """Warp an image or feature map with optical flow 57 | Args: 58 | x (Tensor): size (N, C, H, W) 59 | flow (Tensor): size (N, H, W, 2), normal value 60 | interp_mode (str): 'nearest' or 'bilinear' 61 | padding_mode (str): 'zeros' or 'border' or 'reflection' 62 | 63 | Returns: 64 | Tensor: warped image or feature map 65 | """ 66 | assert x.size()[-2:] == flow.size()[1:3] 67 | B, C, H, W = x.size() 68 | # mesh grid 69 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 70 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 71 | grid.requires_grad = False 72 | grid = grid.type_as(x) 73 | vgrid = grid + flow 74 | # scale grid to [-1,1] 75 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 76 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 77 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 78 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 79 | return output 80 | -------------------------------------------------------------------------------- /codes/models/archs/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, 2 | deform_conv, modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /codes/models/archs/dcn/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Function 7 | from torch.autograd.function import once_differentiable 8 | from torch.nn.modules.utils import _pair 9 | 10 | from . 
import deform_conv_cuda 11 | 12 | logger = logging.getLogger('base') 13 | 14 | 15 | class DeformConvFunction(Function): 16 | @staticmethod 17 | def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1, 18 | deformable_groups=1, im2col_step=64): 19 | if input is not None and input.dim() != 4: 20 | raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format( 21 | input.dim())) 22 | ctx.stride = _pair(stride) 23 | ctx.padding = _pair(padding) 24 | ctx.dilation = _pair(dilation) 25 | ctx.groups = groups 26 | ctx.deformable_groups = deformable_groups 27 | ctx.im2col_step = im2col_step 28 | 29 | ctx.save_for_backward(input, offset, weight) 30 | 31 | output = input.new_empty( 32 | DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) 33 | 34 | ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones 35 | 36 | if not input.is_cuda: 37 | raise NotImplementedError 38 | else: 39 | cur_im2col_step = min(ctx.im2col_step, input.shape[0]) 40 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' 41 | deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output, 42 | ctx.bufs_[0], ctx.bufs_[1], weight.size(3), 43 | weight.size(2), ctx.stride[1], ctx.stride[0], 44 | ctx.padding[1], ctx.padding[0], 45 | ctx.dilation[1], ctx.dilation[0], ctx.groups, 46 | ctx.deformable_groups, cur_im2col_step) 47 | return output 48 | 49 | @staticmethod 50 | @once_differentiable 51 | def backward(ctx, grad_output): 52 | input, offset, weight = ctx.saved_tensors 53 | 54 | grad_input = grad_offset = grad_weight = None 55 | 56 | if not grad_output.is_cuda: 57 | raise NotImplementedError 58 | else: 59 | cur_im2col_step = min(ctx.im2col_step, input.shape[0]) 60 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' 61 | 62 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 63 | grad_input = torch.zeros_like(input) 64 | grad_offset = torch.zeros_like(offset) 65 | deform_conv_cuda.deform_conv_backward_input_cuda( 66 | input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0], 67 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], 68 | ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, 69 | ctx.deformable_groups, cur_im2col_step) 70 | 71 | if ctx.needs_input_grad[2]: 72 | grad_weight = torch.zeros_like(weight) 73 | deform_conv_cuda.deform_conv_backward_parameters_cuda( 74 | input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1], 75 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], 76 | ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, 77 | ctx.deformable_groups, 1, cur_im2col_step) 78 | 79 | return (grad_input, grad_offset, grad_weight, None, None, None, None, None) 80 | 81 | @staticmethod 82 | def _output_size(input, weight, padding, dilation, stride): 83 | channels = weight.size(0) 84 | output_size = (input.size(0), channels) 85 | for d in range(input.dim() - 2): 86 | in_size = input.size(d + 2) 87 | pad = padding[d] 88 | kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 89 | stride_ = stride[d] 90 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) 91 | if not all(map(lambda s: s > 0, output_size)): 92 | raise ValueError("convolution input is too small (output would be {})".format('x'.join( 93 | map(str, output_size)))) 94 | return output_size 95 | 96 | 97 | class ModulatedDeformConvFunction(Function): 98 | @staticmethod 99 | def 
forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1, 100 | groups=1, deformable_groups=1): 101 | ctx.stride = stride 102 | ctx.padding = padding 103 | ctx.dilation = dilation 104 | ctx.groups = groups 105 | ctx.deformable_groups = deformable_groups 106 | ctx.with_bias = bias is not None 107 | if not ctx.with_bias: 108 | bias = input.new_empty(1) # fake tensor 109 | if not input.is_cuda: 110 | raise NotImplementedError 111 | if weight.requires_grad or mask.requires_grad or offset.requires_grad \ 112 | or input.requires_grad: 113 | ctx.save_for_backward(input, offset, mask, weight, bias) 114 | output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) 115 | ctx._bufs = [input.new_empty(0), input.new_empty(0)] 116 | deform_conv_cuda.modulated_deform_conv_cuda_forward( 117 | input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2], 118 | weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, 119 | ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias) 120 | return output 121 | 122 | @staticmethod 123 | @once_differentiable 124 | def backward(ctx, grad_output): 125 | if not grad_output.is_cuda: 126 | raise NotImplementedError 127 | input, offset, mask, weight, bias = ctx.saved_tensors 128 | grad_input = torch.zeros_like(input) 129 | grad_offset = torch.zeros_like(offset) 130 | grad_mask = torch.zeros_like(mask) 131 | grad_weight = torch.zeros_like(weight) 132 | grad_bias = torch.zeros_like(bias) 133 | deform_conv_cuda.modulated_deform_conv_cuda_backward( 134 | input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight, 135 | grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3], 136 | ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, 137 | ctx.groups, ctx.deformable_groups, ctx.with_bias) 138 | if not ctx.with_bias: 139 | grad_bias = None 140 | 141 | return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, 142 | None) 143 | 144 | @staticmethod 145 | def _infer_shape(ctx, input, weight): 146 | n = input.size(0) 147 | channels_out = weight.size(0) 148 | height, width = input.shape[2:4] 149 | kernel_h, kernel_w = weight.shape[2:4] 150 | height_out = (height + 2 * ctx.padding - (ctx.dilation * 151 | (kernel_h - 1) + 1)) // ctx.stride + 1 152 | width_out = (width + 2 * ctx.padding - (ctx.dilation * 153 | (kernel_w - 1) + 1)) // ctx.stride + 1 154 | return n, channels_out, height_out, width_out 155 | 156 | 157 | deform_conv = DeformConvFunction.apply 158 | modulated_deform_conv = ModulatedDeformConvFunction.apply 159 | 160 | 161 | class DeformConv(nn.Module): 162 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 163 | groups=1, deformable_groups=1, bias=False): 164 | super(DeformConv, self).__init__() 165 | 166 | assert not bias 167 | assert in_channels % groups == 0, \ 168 | 'in_channels {} cannot be divisible by groups {}'.format( 169 | in_channels, groups) 170 | assert out_channels % groups == 0, \ 171 | 'out_channels {} cannot be divisible by groups {}'.format( 172 | out_channels, groups) 173 | 174 | self.in_channels = in_channels 175 | self.out_channels = out_channels 176 | self.kernel_size = _pair(kernel_size) 177 | self.stride = _pair(stride) 178 | self.padding = _pair(padding) 179 | self.dilation = _pair(dilation) 180 | self.groups = groups 181 | self.deformable_groups = deformable_groups 182 | 183 | 
self.weight = nn.Parameter( 184 | torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) 185 | 186 | self.reset_parameters() 187 | 188 | def reset_parameters(self): 189 | n = self.in_channels 190 | for k in self.kernel_size: 191 | n *= k 192 | stdv = 1. / math.sqrt(n) 193 | self.weight.data.uniform_(-stdv, stdv) 194 | 195 | def forward(self, x, offset): 196 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, 197 | self.groups, self.deformable_groups) 198 | 199 | 200 | class DeformConvPack(DeformConv): 201 | def __init__(self, *args, **kwargs): 202 | super(DeformConvPack, self).__init__(*args, **kwargs) 203 | 204 | self.conv_offset = nn.Conv2d( 205 | self.in_channels, 206 | self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], 207 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), 208 | bias=True) 209 | self.init_offset() 210 | 211 | def init_offset(self): 212 | self.conv_offset.weight.data.zero_() 213 | self.conv_offset.bias.data.zero_() 214 | 215 | def forward(self, x): 216 | offset = self.conv_offset(x) 217 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, 218 | self.groups, self.deformable_groups) 219 | 220 | 221 | class ModulatedDeformConv(nn.Module): 222 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, 223 | groups=1, deformable_groups=1, bias=True): 224 | super(ModulatedDeformConv, self).__init__() 225 | self.in_channels = in_channels 226 | self.out_channels = out_channels 227 | self.kernel_size = _pair(kernel_size) 228 | self.stride = stride 229 | self.padding = padding 230 | self.dilation = dilation 231 | self.groups = groups 232 | self.deformable_groups = deformable_groups 233 | self.with_bias = bias 234 | 235 | self.weight = nn.Parameter( 236 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) 237 | if bias: 238 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 239 | else: 240 | self.register_parameter('bias', None) 241 | self.reset_parameters() 242 | 243 | def reset_parameters(self): 244 | n = self.in_channels 245 | for k in self.kernel_size: 246 | n *= k 247 | stdv = 1. 
/ math.sqrt(n) 248 | self.weight.data.uniform_(-stdv, stdv) 249 | if self.bias is not None: 250 | self.bias.data.zero_() 251 | 252 | def forward(self, x, offset, mask): 253 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, 254 | self.padding, self.dilation, self.groups, 255 | self.deformable_groups) 256 | 257 | 258 | class ModulatedDeformConvPack(ModulatedDeformConv): 259 | def __init__(self, *args, extra_offset_mask=False, **kwargs): 260 | super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) 261 | 262 | self.extra_offset_mask = extra_offset_mask 263 | self.conv_offset_mask = nn.Conv2d( 264 | self.in_channels, 265 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], 266 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding), 267 | bias=True) 268 | self.init_offset() 269 | 270 | def init_offset(self): 271 | self.conv_offset_mask.weight.data.zero_() 272 | self.conv_offset_mask.bias.data.zero_() 273 | 274 | def forward(self, x): 275 | if self.extra_offset_mask: 276 | # x = [input, features] 277 | out = self.conv_offset_mask(x[1]) 278 | x = x[0] 279 | else: 280 | out = self.conv_offset_mask(x) 281 | o1, o2, mask = torch.chunk(out, 3, dim=1) 282 | offset = torch.cat((o1, o2), dim=1) 283 | mask = torch.sigmoid(mask) 284 | 285 | offset_mean = torch.mean(torch.abs(offset)) 286 | if offset_mean > 100: 287 | logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean)) 288 | 289 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, 290 | self.padding, self.dilation, self.groups, 291 | self.deformable_groups) 292 | -------------------------------------------------------------------------------- /codes/models/archs/dcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | def make_cuda_ext(name, sources): 6 | 7 | return CUDAExtension( 8 | name='{}'.format(name), sources=[p for p in sources], extra_compile_args={ 9 | 'cxx': [], 10 | 'nvcc': [ 11 | '-D__CUDA_NO_HALF_OPERATORS__', 12 | '-D__CUDA_NO_HALF_CONVERSIONS__', 13 | '-D__CUDA_NO_HALF2_OPERATORS__', 14 | ] 15 | }) 16 | 17 | 18 | setup( 19 | name='deform_conv', ext_modules=[ 20 | make_cuda_ext(name='deform_conv_cuda', 21 | sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']) 22 | ], cmdclass={'build_ext': BuildExtension}, zip_safe=False) 23 | -------------------------------------------------------------------------------- /codes/models/archs/discriminator_vgg_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class Discriminator_VGG_128(nn.Module): 7 | def __init__(self, in_nc, nf): 8 | super(Discriminator_VGG_128, self).__init__() 9 | # [64, 128, 128] 10 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 11 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 12 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 13 | # [64, 64, 64] 14 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 15 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 16 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 17 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 18 | # [128, 32, 32] 19 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 20 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 21 | 
self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 22 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 23 | # [256, 16, 16] 24 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 25 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 26 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 27 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 28 | # [512, 8, 8] 29 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 30 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 31 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 32 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 33 | 34 | self.linear1 = nn.Linear(512 * 4 * 4, 100) 35 | self.linear2 = nn.Linear(100, 1) 36 | 37 | # activation function 38 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 39 | 40 | def forward(self, x): 41 | fea = self.lrelu(self.conv0_0(x)) 42 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 43 | 44 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 45 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 46 | 47 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 48 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 49 | 50 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 51 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 52 | 53 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 54 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 55 | 56 | fea = fea.view(fea.size(0), -1) 57 | fea = self.lrelu(self.linear1(fea)) 58 | out = self.linear2(fea) 59 | return out 60 | 61 | 62 | class VGGFeatureExtractor(nn.Module): 63 | def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, 64 | device=torch.device('cpu')): 65 | super(VGGFeatureExtractor, self).__init__() 66 | self.use_input_norm = use_input_norm 67 | if use_bn: 68 | model = torchvision.models.vgg19_bn(pretrained=True) 69 | else: 70 | model = torchvision.models.vgg19(pretrained=True) 71 | if self.use_input_norm: 72 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 73 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] 74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 75 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] 76 | self.register_buffer('mean', mean) 77 | self.register_buffer('std', std) 78 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) 79 | # No need to BP to variable 80 | for k, v in self.features.named_parameters(): 81 | v.requires_grad = False 82 | 83 | def forward(self, x): 84 | # Assume input range is [0, 1] 85 | if self.use_input_norm: 86 | x = (x - self.mean) / self.std 87 | output = self.features(x) 88 | return output 89 | -------------------------------------------------------------------------------- /codes/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | class BaseModel(): 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') 12 | self.is_train = opt['is_train'] 13 | self.schedulers = [] 14 | self.optimizers = [] 15 | 16 | def feed_data(self, data): 17 | pass 18 | 19 | def optimize_parameters(self): 20 | pass 21 | 22 | def get_current_visuals(self): 23 | pass 24 | 25 | def get_current_losses(self): 26 | pass 27 | 28 | def print_network(self): 29 | pass 30 | 31 | 
def save(self, label): 32 | pass 33 | 34 | def load(self): 35 | pass 36 | 37 | def _set_lr(self, lr_groups_l): 38 | """Set learning rate for warmup 39 | lr_groups_l: list for lr_groups. each for a optimizer""" 40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 41 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 42 | param_group['lr'] = lr 43 | 44 | def _get_init_lr(self): 45 | """Get the initial lr, which is set by the scheduler""" 46 | init_lr_groups_l = [] 47 | for optimizer in self.optimizers: 48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 49 | return init_lr_groups_l 50 | 51 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 52 | for scheduler in self.schedulers: 53 | scheduler.step() 54 | # set up warm-up learning rate 55 | if cur_iter < warmup_iter: 56 | # get initial lr for each group 57 | init_lr_g_l = self._get_init_lr() 58 | # modify warming-up learning rates 59 | warm_up_lr_l = [] 60 | for init_lr_g in init_lr_g_l: 61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 62 | # set learning rate 63 | self._set_lr(warm_up_lr_l) 64 | 65 | def get_current_learning_rate(self): 66 | return [param_group['lr'] for param_group in self.optimizers[0].param_groups] 67 | 68 | def get_network_description(self, network): 69 | """Get the string and total parameters of the network""" 70 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 71 | network = network.module 72 | return str(network), sum(map(lambda x: x.numel(), network.parameters())) 73 | 74 | def save_network(self, network, network_label, iter_label): 75 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 76 | save_path = os.path.join(self.opt['path']['models'], save_filename) 77 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 78 | network = network.module 79 | state_dict = network.state_dict() 80 | for key, param in state_dict.items(): 81 | state_dict[key] = param.cpu() 82 | torch.save(state_dict, save_path) 83 | 84 | def load_network(self, load_path, network, strict=True): 85 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 86 | network = network.module 87 | load_net = torch.load(load_path) 88 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 
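# (editor's note) Checkpoints saved from an nn.DataParallel or DistributedDataParallel
# wrapper prefix every key with 'module.' (e.g. 'module.conv_first.weight');
# the loop below strips that prefix so the same .pth file loads into both
# wrapped and bare networks.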
89 | for k, v in load_net.items(): 90 | if k.startswith('module.'): 91 | load_net_clean[k[7:]] = v 92 | else: 93 | load_net_clean[k] = v 94 | network.load_state_dict(load_net_clean, strict=strict) 95 | 96 | def save_training_state(self, epoch, iter_step): 97 | """Save training state during training, which will be used for resuming""" 98 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 99 | for s in self.schedulers: 100 | state['schedulers'].append(s.state_dict()) 101 | for o in self.optimizers: 102 | state['optimizers'].append(o.state_dict()) 103 | save_filename = '{}.state'.format(iter_step) 104 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 105 | torch.save(state, save_path) 106 | 107 | def resume_training(self, resume_state): 108 | """Resume the optimizers and schedulers for training""" 109 | resume_optimizers = resume_state['optimizers'] 110 | resume_schedulers = resume_state['schedulers'] 111 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 112 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 113 | for i, o in enumerate(resume_optimizers): 114 | self.optimizers[i].load_state_dict(o) 115 | for i, s in enumerate(resume_schedulers): 116 | self.schedulers[i].load_state_dict(s) 117 | -------------------------------------------------------------------------------- /codes/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-6): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | return loss 16 | 17 | 18 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 19 | class GANLoss(nn.Module): 20 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 21 | super(GANLoss, self).__init__() 22 | self.gan_type = gan_type.lower() 23 | self.real_label_val = real_label_val 24 | self.fake_label_val = fake_label_val 25 | 26 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 27 | self.loss = nn.BCEWithLogitsLoss() 28 | elif self.gan_type == 'lsgan': 29 | self.loss = nn.MSELoss() 30 | elif self.gan_type == 'wgan-gp': 31 | 32 | def wgan_loss(input, target): 33 | # target is boolean 34 | return -1 * input.mean() if target else input.mean() 35 | 36 | self.loss = wgan_loss 37 | else: 38 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 39 | 40 | def get_target_label(self, input, target_is_real): 41 | if self.gan_type == 'wgan-gp': 42 | return target_is_real 43 | if target_is_real: 44 | return torch.empty_like(input).fill_(self.real_label_val) 45 | else: 46 | return torch.empty_like(input).fill_(self.fake_label_val) 47 | 48 | def forward(self, input, target_is_real): 49 | target_label = self.get_target_label(input, target_is_real) 50 | loss = self.loss(input, target_label) 51 | return loss 52 | 53 | 54 | class GradientPenaltyLoss(nn.Module): 55 | def __init__(self, device=torch.device('cpu')): 56 | super(GradientPenaltyLoss, self).__init__() 57 | self.register_buffer('grad_outputs', torch.Tensor()) 58 | self.grad_outputs = self.grad_outputs.to(device) 59 | 60 | def get_grad_outputs(self, input): 61 | if self.grad_outputs.size() != input.size(): 62 | self.grad_outputs.resize_(input.size()).fill_(1.0) 63 | 
return self.grad_outputs 64 | 65 | def forward(self, interp, interp_crit): 66 | grad_outputs = self.get_grad_outputs(interp_crit) 67 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 68 | grad_outputs=grad_outputs, create_graph=True, 69 | retain_graph=True, only_inputs=True)[0] 70 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 71 | grad_interp_norm = grad_interp.norm(2, dim=1) 72 | 73 | loss = ((grad_interp_norm - 1)**2).mean() 74 | return loss 75 | -------------------------------------------------------------------------------- /codes/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restarts = [v + 1 for v in self.restarts] 16 | self.restart_weights = weights if weights else [1] 17 | assert len(self.restarts) == len( 18 | self.restart_weights), 'restarts and their weights do not match.' 19 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 20 | 21 | def get_lr(self): 22 | if self.last_epoch in self.restarts: 23 | if self.clear_state: 24 | self.optimizer.state = defaultdict(dict) 25 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 26 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 27 | if self.last_epoch not in self.milestones: 28 | return [group['lr'] for group in self.optimizer.param_groups] 29 | return [ 30 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 31 | for group in self.optimizer.param_groups 32 | ] 33 | 34 | 35 | class CosineAnnealingLR_Restart(_LRScheduler): 36 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 37 | self.T_period = T_period 38 | self.T_max = self.T_period[0] # current T period 39 | self.eta_min = eta_min 40 | self.restarts = restarts if restarts else [0] 41 | self.restarts = [v + 1 for v in self.restarts] 42 | self.restart_weights = weights if weights else [1] 43 | self.last_restart = 0 44 | assert len(self.restarts) == len( 45 | self.restart_weights), 'restarts and their weights do not match.' 
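# (editor's note) Because the restart iterations are shifted by +1 in __init__ above,
# a restart configured at iteration v takes effect when last_epoch == v + 1,
# i.e. one scheduler step later; the __main__ demo at the bottom of this file
# plots the resulting schedules.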
46 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | if self.last_epoch == 0: 50 | return self.base_lrs 51 | elif self.last_epoch in self.restarts: 52 | self.last_restart = self.last_epoch 53 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 54 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 55 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 56 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 57 | return [ 58 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 60 | ] 61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 63 | (group['lr'] - self.eta_min) + self.eta_min 64 | for group in self.optimizer.param_groups] 65 | 66 | 67 | if __name__ == "__main__": 68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 69 | betas=(0.9, 0.99)) 70 | ############################## 71 | # MultiStepLR_Restart 72 | ############################## 73 | ## Original 74 | lr_steps = [200000, 400000, 600000, 800000] 75 | restarts = None 76 | restart_weights = None 77 | 78 | ## two 79 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 80 | restarts = [500000] 81 | restart_weights = [1] 82 | 83 | ## four 84 | lr_steps = [ 85 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 86 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 87 | ] 88 | restarts = [250000, 500000, 750000] 89 | restart_weights = [1, 1, 1] 90 | 91 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 92 | clear_state=False) 93 | 94 | ############################## 95 | # Cosine Annealing Restart 96 | ############################## 97 | ## two 98 | T_period = [500000, 500000] 99 | restarts = [500000] 100 | restart_weights = [1] 101 | 102 | ## four 103 | T_period = [250000, 250000, 250000, 250000] 104 | restarts = [250000, 500000, 750000] 105 | restart_weights = [1, 1, 1] 106 | 107 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 108 | weights=restart_weights) 109 | 110 | ############################## 111 | # Draw figure 112 | ############################## 113 | N_iter = 1000000 114 | lr_l = list(range(N_iter)) 115 | for i in range(N_iter): 116 | scheduler.step() 117 | current_lr = optimizer.param_groups[0]['lr'] 118 | lr_l[i] = current_lr 119 | 120 | import matplotlib as mpl 121 | from matplotlib import pyplot as plt 122 | import matplotlib.ticker as mtick 123 | mpl.style.use('default') 124 | import seaborn 125 | seaborn.set(style='whitegrid') 126 | seaborn.set_context('paper') 127 | 128 | plt.figure(1) 129 | plt.subplot(111) 130 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 131 | plt.title('Title', fontsize=16, color='k') 132 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 133 | legend = plt.legend(loc='upper right', shadow=False) 134 | ax = plt.gca() 135 | labels = ax.get_xticks().tolist() 136 | for k, v in enumerate(labels): 137 | labels[k] = str(int(v / 1000)) + 'K' 138 | ax.set_xticklabels(labels) 139 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 140 | 141 
| ax.set_ylabel('Learning rate') 142 | ax.set_xlabel('Iteration') 143 | fig = plt.gcf() 144 | plt.show() 145 | -------------------------------------------------------------------------------- /codes/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import models.archs.SRResNet_arch as SRResNet_arch 3 | import models.archs.discriminator_vgg_arch as SRGAN_arch 4 | import models.archs.RRDBNet_arch as RRDBNet_arch 5 | # import models.archs.EDVR_arch as EDVR_arch 6 | import models.archs.CSRNet_arch as CSRNet_arch 7 | 8 | 9 | # Generator 10 | def define_G(opt): 11 | opt_net = opt['network_G'] 12 | which_model = opt_net['which_model_G'] 13 | 14 | # image restoration 15 | if which_model == 'MSRResNet': 16 | netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 17 | nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) 18 | elif which_model == 'RRDBNet': 19 | netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 20 | nf=opt_net['nf'], nb=opt_net['nb']) 21 | 22 | elif which_model == 'AdaFMNet': 23 | netG = AdaFMNet_arch.AdaFMNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 24 | nf=opt_net['nf'], nb=opt_net['nb'], adafm_ksize=opt_net['adafm_ksize']) 25 | 26 | elif which_model == 'CResMDNet': 27 | netG = CResMDNet_arch.CResMDNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 28 | nf=opt_net['nf'], nb=opt_net['nb'], cond_dim=opt_net['cond_dim']) 29 | elif which_model == 'BaseNet': 30 | netG = CResMDNet_arch.BaseNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 31 | nf=opt_net['nf'], nb=opt_net['nb']) 32 | elif which_model == 'CondNet': 33 | netG = CResMDNet_arch.CondNet(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 34 | 35 | # image enhancement 36 | elif which_model == 'CSRNet': 37 | netG = CSRNet_arch.CSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], base_nf=opt_net['base_nf'], 38 | cond_nf=opt_net['cond_nf']) 39 | 40 | # video restoration 41 | elif which_model == 'EDVR': 42 | netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'], 43 | groups=opt_net['groups'], front_RBs=opt_net['front_RBs'], 44 | back_RBs=opt_net['back_RBs'], center=opt_net['center'], 45 | predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'], 46 | w_TSA=opt_net['w_TSA']) 47 | else: 48 | raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) 49 | 50 | return netG 51 | 52 | 53 | # Discriminator 54 | def define_D(opt): 55 | opt_net = opt['network_D'] 56 | which_model = opt_net['which_model_D'] 57 | 58 | if which_model == 'discriminator_vgg_128': 59 | netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 60 | else: 61 | raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) 62 | return netD 63 | 64 | 65 | # Define network used for perceptual loss 66 | def define_F(opt, use_bn=False): 67 | gpu_ids = opt['gpu_ids'] 68 | device = torch.device('cuda' if gpu_ids else 'cpu') 69 | # PyTorch pretrained VGG19-54, before ReLU. 
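# (editor's note) In torchvision's plain VGG19, features[34] is the conv5_4 layer,
# so the [:35] slice inside VGGFeatureExtractor keeps conv5_4 but drops its ReLU
# (the 'VGG19-54' perceptual feature used by ESRGAN); index 49 is the matching
# conv in the batch-norm variant. Note also that in define_G above, the AdaFMNet /
# CResMDNet branches reference arch modules that are never imported, and the EDVR
# import is commented out at the top, so only MSRResNet, RRDBNet, and CSRNet are
# buildable as shipped unless those imports are restored.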
70 | if use_bn: 71 | feature_layer = 49 72 | else: 73 | feature_layer = 34 74 | netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, 75 | use_input_norm=True, device=device) 76 | netF.eval() # No need to train 77 | return netF 78 | -------------------------------------------------------------------------------- /codes/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/codes/options/__init__.py -------------------------------------------------------------------------------- /codes/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | # export CUDA_VISIBLE_DEVICES 13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 16 | 17 | opt['is_train'] = is_train 18 | if opt['distortion'] == 'sr': 19 | scale = opt['scale'] 20 | 21 | # datasets 22 | for phase, dataset in opt['datasets'].items(): 23 | phase = phase.split('_')[0] 24 | dataset['phase'] = phase 25 | if opt['distortion'] == 'sr': 26 | dataset['scale'] = scale 27 | is_lmdb = False 28 | if dataset.get('dataroot_GT', None) is not None: 29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 30 | if dataset['dataroot_GT'].endswith('lmdb'): 31 | is_lmdb = True 32 | if dataset.get('dataroot_LQ', None) is not None: 33 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 34 | if dataset['dataroot_LQ'].endswith('lmdb'): 35 | is_lmdb = True 36 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 37 | if dataset['mode'].endswith('mc'): # for memcached 38 | dataset['data_type'] = 'mc' 39 | dataset['mode'] = dataset['mode'].replace('_mc', '') 40 | 41 | # path 42 | for key, path in opt['path'].items(): 43 | if path and key in opt['path'] and key != 'strict_load': 44 | opt['path'][key] = osp.expanduser(path) 45 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 46 | if is_train: 47 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 48 | opt['path']['experiments_root'] = experiments_root 49 | opt['path']['models'] = osp.join(experiments_root, 'models') 50 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 51 | opt['path']['log'] = experiments_root 52 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 53 | 54 | # change some options for debug mode 55 | if 'debug' in opt['name']: 56 | opt['train']['val_freq'] = 8 57 | opt['logger']['print_freq'] = 1 58 | opt['logger']['save_checkpoint_freq'] = 8 59 | else: # test 60 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 61 | opt['path']['results_root'] = results_root 62 | opt['path']['log'] = results_root 63 | 64 | # network 65 | if opt['distortion'] == 'sr': 66 | opt['network_G']['scale'] = scale 67 | 68 | return opt 69 | 70 | 71 | def dict2str(opt, indent_l=1): 72 | '''dict to string for logger''' 73 | msg = '' 74 | for k, v in opt.items(): 75 | if isinstance(v, dict): 76 | msg += ' ' * (indent_l * 2) + k + ':[\n' 77 | msg += dict2str(v, 
indent_l + 1) 78 | msg += ' ' * (indent_l * 2) + ']\n' 79 | else: 80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 81 | return msg 82 | 83 | 84 | class NoneDict(dict): 85 | def __missing__(self, key): 86 | return None 87 | 88 | 89 | # convert to NoneDict, which return None for missing key. 90 | def dict_to_nonedict(opt): 91 | if isinstance(opt, dict): 92 | new_opt = dict() 93 | for key, sub_opt in opt.items(): 94 | new_opt[key] = dict_to_nonedict(sub_opt) 95 | return NoneDict(**new_opt) 96 | elif isinstance(opt, list): 97 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 98 | else: 99 | return opt 100 | 101 | 102 | def check_resume(opt, resume_iter): 103 | '''Check resume states and pretrain_model paths''' 104 | logger = logging.getLogger('base') 105 | if opt['path']['resume_state']: 106 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 107 | 'pretrain_model_D', None) is not None: 108 | logger.warning('pretrain_model path will be ignored when resuming training.') 109 | 110 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 111 | '{}_G.pth'.format(resume_iter)) 112 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 113 | if 'gan' in opt['model']: 114 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 115 | '{}_D.pth'.format(resume_iter)) 116 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 117 | -------------------------------------------------------------------------------- /codes/options/test/test_ESRGAN.yml: -------------------------------------------------------------------------------- 1 | name: RRDB_ESRGAN_x4 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels 7 | gpu_ids: [0] 8 | 9 | datasets: 10 | test_1: # the 1st test dataset 11 | name: set5 12 | mode: LQGT 13 | dataroot_GT: ../datasets/val_set5/Set5 14 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 15 | test_2: # the 2nd test dataset 16 | name: set14 17 | mode: LQGT 18 | dataroot_GT: ../datasets/val_set14/Set14 19 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 20 | 21 | #### network structures 22 | network_G: 23 | which_model_G: RRDBNet 24 | in_nc: 3 25 | out_nc: 3 26 | nf: 64 27 | nb: 23 28 | upscale: 4 29 | 30 | #### path 31 | path: 32 | pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth 33 | -------------------------------------------------------------------------------- /codes/options/test/test_Enhance.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: test_csrnet 3 | suffix: ~ # add suffix to saved images 4 | model: sr 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0] 8 | 9 | datasets: 10 | val: 11 | name: MIT_fivek_500 12 | mode: LQGT_enhance 13 | dataroot_GT: ../datasets/expert_C_test 14 | dataroot_LQ: ../datasets/raw_input_test 15 | 16 | #### network structures 17 | network_G: 18 | which_model_G: CSRNet 19 | in_nc: 3 20 | out_nc: 3 21 | base_nf: 64 22 | cond_nf: 32 23 | 24 | #### path 25 | path: 26 | root: 27 | pretrain_model_G: ../experiments/pretrain_models/csrnet.pth -------------------------------------------------------------------------------- /codes/options/test/test_SRGAN.yml: -------------------------------------------------------------------------------- 1 | name: MSRGANx4 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels 7 | gpu_ids: [0] 8 | 9 | datasets: 10 | test_1: # the 1st test dataset 11 | name: set5 12 | mode: LQGT 13 | dataroot_GT: ../datasets/val_set5/Set5 14 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 15 | test_2: # the 2nd test dataset 16 | name: set14 17 | mode: LQGT 18 | dataroot_GT: ../datasets/val_set14/Set14 19 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 20 | 21 | #### network structures 22 | network_G: 23 | which_model_G: MSRResNet 24 | in_nc: 3 25 | out_nc: 3 26 | nf: 64 27 | nb: 16 28 | upscale: 4 29 | 30 | #### path 31 | path: 32 | pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth 33 | -------------------------------------------------------------------------------- /codes/options/test/test_SRResNet.yml: -------------------------------------------------------------------------------- 1 | name: MSRResNetx4 2 | suffix: ~ # add suffix to saved images 3 | model: sr 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation.
If None(~), crop the scale pixels 7 | gpu_ids: [0] 8 | 9 | datasets: 10 | test_1: # the 1st test dataset 11 | name: set5 12 | mode: LQGT 13 | dataroot_GT: ../datasets/val_set5/Set5 14 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 15 | test_2: # the 2nd test dataset 16 | name: set14 17 | mode: LQGT 18 | dataroot_GT: ../datasets/val_set14/Set14 19 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 20 | test_3: 21 | name: bsd100 22 | mode: LQGT 23 | dataroot_GT: ../datasets/BSD/BSDS100 24 | dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4 25 | test_4: 26 | name: urban100 27 | mode: LQGT 28 | dataroot_GT: ../datasets/urban100 29 | dataroot_LQ: ../datasets/urban100_bicLRx4 30 | test_5: 31 | name: div2k100 32 | mode: LQGT 33 | dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR 34 | dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4 35 | 36 | 37 | #### network structures 38 | network_G: 39 | which_model_G: MSRResNet 40 | in_nc: 3 41 | out_nc: 3 42 | nf: 64 43 | nb: 16 44 | upscale: 4 45 | 46 | #### path 47 | path: 48 | pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth 49 | -------------------------------------------------------------------------------- /codes/options/train/train_EDVR_M.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new 3 | use_tb_logger: true 4 | model: video_base 5 | distortion: sr 6 | scale: 4 7 | gpu_ids: [0,1,2,3,4,5,6,7] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: REDS 13 | mode: REDS 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb 18 | dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb 19 | cache_keys: ~ 20 | 21 | N_frames: 5 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 256 26 | LQ_size: 64 27 | use_flip: true 28 | use_rot: true 29 | color: RGB 30 | val: 31 | name: REDS4 32 | mode: video_test 33 | dataroot_GT: ../datasets/REDS4/GT 34 | dataroot_LQ: ../datasets/REDS4/sharp_bicubic 35 | cache_data: True 36 | N_frames: 5 37 | padding: new_info 38 | 39 | #### network structures 40 | network_G: 41 | which_model_G: EDVR 42 | nf: 64 43 | nframes: 5 44 | groups: 8 45 | front_RBs: 5 46 | back_RBs: 10 47 | predeblur: false 48 | HR_in: false 49 | w_TSA: true 50 | 51 | #### path 52 | path: 53 | pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth 54 | strict_load: false 55 | resume_state: ~ 56 | 57 | #### training settings: learning rate scheme, loss 58 | train: 59 | lr_G: !!float 4e-4 60 | lr_scheme: CosineAnnealingLR_Restart 61 | beta1: 0.9 62 | beta2: 0.99 63 | niter: 600000 64 | ft_tsa_only: 50000 65 | warmup_iter: -1 # -1: no warm up 66 | T_period: [50000, 100000, 150000, 150000, 150000] 67 | restarts: [50000, 150000, 300000, 450000] 68 | restart_weights: [1, 1, 1, 1] 69 | eta_min: !!float 1e-7 70 | 71 | pixel_criterion: cb 72 | pixel_weight: 1.0 73 | val_freq: !!float 5e3 74 | 75 | manual_seed: 0 76 | 77 | #### logger 78 | logger: 79 | print_freq: 100 80 | save_checkpoint_freq: !!float 5e3 81 | -------------------------------------------------------------------------------- /codes/options/train/train_EDVR_woTSA_M.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S 3 | use_tb_logger: true 4 | model: video_base 5 | distortion: sr 6 | scale: 4 7 | 
gpu_ids: [0,1,2,3,4,5,6,7] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: REDS 13 | mode: REDS 14 | interval_list: [1] 15 | random_reverse: false 16 | border_mode: false 17 | dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb 18 | dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb 19 | cache_keys: ~ 20 | 21 | N_frames: 5 22 | use_shuffle: true 23 | n_workers: 3 # per GPU 24 | batch_size: 32 25 | GT_size: 256 26 | LQ_size: 64 27 | use_flip: true 28 | use_rot: true 29 | color: RGB 30 | 31 | #### network structures 32 | network_G: 33 | which_model_G: EDVR 34 | nf: 64 35 | nframes: 5 36 | groups: 8 37 | front_RBs: 5 38 | back_RBs: 10 39 | predeblur: false 40 | HR_in: false 41 | w_TSA: false 42 | 43 | #### path 44 | path: 45 | pretrain_model_G: ~ 46 | strict_load: true 47 | resume_state: ~ 48 | 49 | #### training settings: learning rate scheme, loss 50 | train: 51 | lr_G: !!float 4e-4 52 | lr_scheme: CosineAnnealingLR_Restart 53 | beta1: 0.9 54 | beta2: 0.99 55 | niter: 600000 56 | warmup_iter: -1 # -1: no warm up 57 | T_period: [150000, 150000, 150000, 150000] 58 | restarts: [150000, 300000, 450000] 59 | restart_weights: [1, 1, 1] 60 | eta_min: !!float 1e-7 61 | 62 | pixel_criterion: cb 63 | pixel_weight: 1.0 64 | val_freq: !!float 5e3 65 | 66 | manual_seed: 0 67 | 68 | #### logger 69 | logger: 70 | print_freq: 100 71 | save_checkpoint_freq: !!float 5e3 72 | -------------------------------------------------------------------------------- /codes/options/train/train_ESRGAN.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 003_RRDB_ESRGANx4_DIV2K 3 | use_tb_logger: true 4 | model: srgan 5 | distortion: sr 6 | scale: 4 7 | gpu_ids: [2] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: DIV2K 13 | mode: LQGT 14 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb 15 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb 16 | 17 | use_shuffle: true 18 | n_workers: 6 # per GPU 19 | batch_size: 16 20 | GT_size: 128 21 | use_flip: true 22 | use_rot: true 23 | color: RGB 24 | val: 25 | name: val_set14 26 | mode: LQGT 27 | dataroot_GT: ../datasets/val_set14/Set14 28 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 29 | 30 | #### network structures 31 | network_G: 32 | which_model_G: RRDBNet 33 | in_nc: 3 34 | out_nc: 3 35 | nf: 64 36 | nb: 23 37 | network_D: 38 | which_model_D: discriminator_vgg_128 39 | in_nc: 3 40 | nf: 64 41 | 42 | #### path 43 | path: 44 | pretrain_model_G: ../experiments/pretrained_models/RRDB_PSNR_x4.pth 45 | strict_load: true 46 | resume_state: ~ 47 | 48 | #### training settings: learning rate scheme, loss 49 | train: 50 | lr_G: !!float 1e-4 51 | weight_decay_G: 0 52 | beta1_G: 0.9 53 | beta2_G: 0.99 54 | lr_D: !!float 1e-4 55 | weight_decay_D: 0 56 | beta1_D: 0.9 57 | beta2_D: 0.99 58 | lr_scheme: MultiStepLR 59 | 60 | niter: 400000 61 | warmup_iter: -1 # no warm up 62 | lr_steps: [50000, 100000, 200000, 300000] 63 | lr_gamma: 0.5 64 | 65 | pixel_criterion: l1 66 | pixel_weight: !!float 1e-2 67 | feature_criterion: l1 68 | feature_weight: 1 69 | gan_type: ragan # gan | ragan 70 | gan_weight: !!float 5e-3 71 | 72 | D_update_ratio: 1 73 | D_init_iters: 0 74 | 75 | manual_seed: 10 76 | val_freq: !!float 5e3 77 | 78 | #### logger 79 | logger: 80 | print_freq: 100 81 | save_checkpoint_freq: !!float 5e3 82 | -------------------------------------------------------------------------------- /codes/options/train/train_Enhance.yml: 
-------------------------------------------------------------------------------- 1 | #### general settings 2 | name: csrnet 3 | use_tb_logger: true 4 | model: sr 5 | distortion: sr 6 | scale: 1 7 | gpu_ids: [0] 8 | 9 | #### datasets 10 | datasets: 11 | train: 12 | name: MIT_fivek 13 | mode: LQGT_enhance 14 | dataroot_GT: ../datasets/expert_C_train 15 | dataroot_LQ: ../datasets/raw_input_train 16 | 17 | use_shuffle: true 18 | n_workers: 16 19 | batch_size: 1 20 | color: RGB 21 | 22 | val: 23 | name: MIT_fivek_500 24 | mode: LQGT_enhance 25 | dataroot_GT: ../datasets/expert_C_test 26 | dataroot_LQ: ../datasets/raw_input_test 27 | 28 | #### network structures 29 | network_G: 30 | which_model_G: CSRNet 31 | in_nc: 3 32 | out_nc: 3 33 | base_nf: 64 34 | cond_nf: 32 35 | 36 | 37 | #### path 38 | path: 39 | root: 40 | pretrain_model_G: ~ 41 | strict_load: true 42 | resume_state: ~ 43 | 44 | #### training settings: learning rate scheme, loss 45 | train: 46 | lr_G: !!float 1e-4 47 | lr_scheme: MultiStepLR # MultiStepLR | CosineAnnealingLR_Restart 48 | beta1: 0.9 49 | beta2: 0.99 50 | niter: 600000 51 | warmup_iter: -1 # no warm up 52 | 53 | lr_steps: [100000, 200000, 300000, 400000, 500000] 54 | lr_gamma: 0.5 55 | 56 | pixel_criterion: l1 57 | pixel_weight: 1.0 58 | 59 | manual_seed: 10 60 | val_freq: !!float 5e3 61 | 62 | #### logger 63 | logger: 64 | print_freq: 100 65 | save_checkpoint_freq: !!float 5e3 66 | -------------------------------------------------------------------------------- /codes/options/train/train_SRGAN.yml: -------------------------------------------------------------------------------- 1 | # Not exactly the same as SRGAN in the original paper 2 | # With 16 Residual blocks w/o BN 3 | 4 | #### general settings 5 | name: 002_SRGANx4_MSRResNetx4Ini_DIV2K 6 | use_tb_logger: true 7 | model: srgan 8 | distortion: sr 9 | scale: 4 10 | gpu_ids: [1] 11 | 12 | #### datasets 13 | datasets: 14 | train: 15 | name: DIV2K 16 | mode: LQGT 17 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb 18 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb 19 | 20 | use_shuffle: true 21 | n_workers: 6 # per GPU 22 | batch_size: 16 23 | GT_size: 128 24 | use_flip: true 25 | use_rot: true 26 | color: RGB 27 | val: 28 | name: val_set14 29 | mode: LQGT 30 | dataroot_GT: ../datasets/val_set14/Set14 31 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4 32 | 33 | #### network structures 34 | network_G: 35 | which_model_G: MSRResNet 36 | in_nc: 3 37 | out_nc: 3 38 | nf: 64 39 | nb: 16 40 | upscale: 4 41 | network_D: 42 | which_model_D: discriminator_vgg_128 43 | in_nc: 3 44 | nf: 64 45 | 46 | #### path 47 | path: 48 | pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth 49 | strict_load: true 50 | resume_state: ~ 51 | 52 | #### training settings: learning rate scheme, loss 53 | train: 54 | lr_G: !!float 1e-4 55 | weight_decay_G: 0 56 | beta1_G: 0.9 57 | beta2_G: 0.99 58 | lr_D: !!float 1e-4 59 | weight_decay_D: 0 60 | beta1_D: 0.9 61 | beta2_D: 0.99 62 | lr_scheme: MultiStepLR 63 | 64 | niter: 400000 65 | warmup_iter: -1 # no warm up 66 | lr_steps: [50000, 100000, 200000, 300000] 67 | lr_gamma: 0.5 68 | 69 | pixel_criterion: l1 70 | pixel_weight: !!float 1e-2 71 | feature_criterion: l1 72 | feature_weight: 1 73 | gan_type: gan # gan | ragan 74 | gan_weight: !!float 5e-3 75 | 76 | D_update_ratio: 1 77 | D_init_iters: 0 78 | 79 | manual_seed: 10 80 | val_freq: !!float 5e3 81 | 82 | #### logger 83 | logger: 84 | print_freq: 100 85 | save_checkpoint_freq: !!float 5e3 86 |
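To make these schedules concrete: with MultiStepLR, the learning rate is multiplied by `lr_gamma` each time training passes a milestone in `lr_steps`. Below is a minimal illustrative sketch of that rule (the repo's actual scheduler lives in `codes/models/lr_scheduler.py`; the `multistep_lr` helper here is hypothetical), evaluated with the values from `train_SRGAN.yml` above:

```python
# Illustrative sketch only -- not the repo's implementation (see codes/models/lr_scheduler.py).
# `multistep_lr` is a hypothetical helper showing how `lr_steps`/`lr_gamma` decay the rate.
def multistep_lr(base_lr, lr_steps, lr_gamma, iteration):
    """Learning rate in effect after `iteration` training steps."""
    num_decays = sum(1 for step in lr_steps if iteration >= step)
    return base_lr * (lr_gamma ** num_decays)

# Values from train_SRGAN.yml: lr_G = 1e-4, lr_steps = [50000, 100000, 200000, 300000], lr_gamma = 0.5
for it in (0, 50000, 100000, 200000, 300000):
    print(it, multistep_lr(1e-4, [50000, 100000, 200000, 300000], 0.5, it))
# 1e-04, 5e-05, 2.5e-05, 1.25e-05, 6.25e-06
```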
-------------------------------------------------------------------------------- /codes/options/train/train_SRResNet.yml: -------------------------------------------------------------------------------- 1 | # Not exactly the same as SRResNet in the original paper 2 | # With 16 Residual blocks w/o BN 3 | 4 | #### general settings 5 | name: 001_MSRResNetx4_scratch_DIV2K 6 | use_tb_logger: true 7 | model: sr 8 | distortion: sr 9 | scale: 4 10 | gpu_ids: [0] 11 | 12 | #### datasets 13 | datasets: 14 | train: 15 | name: DIV2K 16 | mode: LQGT 17 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb 18 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb 19 | 20 | use_shuffle: true 21 | n_workers: 6 # per GPU 22 | batch_size: 16 23 | GT_size: 128 24 | use_flip: true 25 | use_rot: true 26 | color: RGB 27 | val: 28 | name: val_set5 29 | mode: LQGT 30 | dataroot_GT: ../datasets/val_set5/Set5 31 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4 32 | 33 | #### network structures 34 | network_G: 35 | which_model_G: MSRResNet 36 | in_nc: 3 37 | out_nc: 3 38 | nf: 64 39 | nb: 16 40 | upscale: 4 41 | 42 | #### path 43 | path: 44 | pretrain_model_G: ~ 45 | strict_load: true 46 | resume_state: ~ 47 | 48 | #### training settings: learning rate scheme, loss 49 | train: 50 | lr_G: !!float 2e-4 51 | lr_scheme: CosineAnnealingLR_Restart 52 | beta1: 0.9 53 | beta2: 0.99 54 | niter: 1000000 55 | warmup_iter: -1 # no warm up 56 | T_period: [250000, 250000, 250000, 250000] 57 | restarts: [250000, 500000, 750000] 58 | restart_weights: [1, 1, 1] 59 | eta_min: !!float 1e-7 60 | 61 | pixel_criterion: l1 62 | pixel_weight: 1.0 63 | 64 | manual_seed: 10 65 | val_freq: !!float 5e3 66 | 67 | #### logger 68 | logger: 69 | print_freq: 100 70 | save_checkpoint_freq: !!float 5e3 71 | -------------------------------------------------------------------------------- /codes/run_scripts.sh: -------------------------------------------------------------------------------- 1 | # single GPU training (image SR) 2 | python train.py -opt options/train/train_SRResNet.yml 3 | python train.py -opt options/train/train_SRGAN.yml 4 | python train.py -opt options/train/train_ESRGAN.yml 5 | 6 | 7 | # distributed training (video SR) 8 | # 8 GPUs 9 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch 10 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch -------------------------------------------------------------------------------- /codes/scripts/transfer_params_MSRResNet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import torch 4 | try: 5 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 6 | import models.archs.SRResNet_arch as SRResNet_arch 7 | except ImportError: 8 | pass 9 | 10 | pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth') 11 | crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3) 12 | crt_net = crt_model.state_dict() 13 | 14 | for k, v in crt_net.items(): 15 | if k in pretrained_net and 'upconv1' not in k: 16 | crt_net[k] = pretrained_net[k] 17 | print('replace ...
', k) 18 | 19 | # x4 -> x3: the x3 head needs 576 (= 64 * 3^2) upconv1 output channels vs 256 (= 64 * 2^2) in the x4 model, so tile the pretrained filters to fill them at half magnitude 20 | crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2 21 | crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2 22 | crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2 23 | crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2 24 | crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2 25 | crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2 26 | 27 | torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth') 28 | -------------------------------------------------------------------------------- /codes/test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import logging 3 | import time 4 | import argparse 5 | from collections import OrderedDict 6 | import torch 7 | 8 | import options.options as option 9 | import utils.util as util 10 | from data.util import bgr2ycbcr 11 | from data import create_dataset, create_dataloader 12 | from models import create_model 13 | 14 | #### options 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.') 17 | opt = option.parse(parser.parse_args().opt, is_train=False) 18 | opt = option.dict_to_nonedict(opt) 19 | 20 | util.mkdirs( 21 | (path for key, path in opt['path'].items() 22 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) 23 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, 24 | screen=True, tofile=True) 25 | logger = logging.getLogger('base') 26 | logger.info(option.dict2str(opt)) 27 | 28 | #### Create test dataset and dataloader 29 | test_loaders = [] 30 | for phase, dataset_opt in sorted(opt['datasets'].items()): 31 | test_set = create_dataset(dataset_opt) 32 | test_loader = create_dataloader(test_set, dataset_opt) 33 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) 34 | test_loaders.append(test_loader) 35 | 36 | model = create_model(opt) 37 | for test_loader in test_loaders: 38 | test_set_name = test_loader.dataset.opt['name'] 39 | logger.info('\nTesting [{:s}]...'.format(test_set_name)) 40 | test_start_time = time.time() 41 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name) 42 | util.mkdir(dataset_dir) 43 | 44 | test_results = OrderedDict() 45 | test_results['psnr'] = [] 46 | test_results['ssim'] = [] 47 | test_results['psnr_y'] = [] 48 | test_results['ssim_y'] = [] 49 | 50 | cond = test_loader.dataset.opt['cond'] 51 | cond_norm = test_loader.dataset.opt['cond_norm'] 52 | for data in test_loader: 53 | need_GT = False if test_loader.dataset.opt['dataroot_GT'] is None else True 54 | 55 | if cond is not None: 56 | for i in range(len(cond)): 57 | cond[i] = cond[i] / cond_norm[i] 58 | data['cond'] = torch.Tensor(cond).view(1, -1) 59 | need_cond = True 60 | elif test_loader.dataset.opt['mode'] in ['LQGT_cond']: 61 | need_cond = True 62 | else: 63 | need_cond = False 64 | 65 | model.feed_data(data, need_GT=need_GT, need_cond=need_cond) 66 | 67 | img_path = data['LQ_path'][0] 68 | img_name = osp.splitext(osp.basename(img_path))[0] 69 | 70 | model.test() 71 | visuals = model.get_current_visuals(need_GT=need_GT) 72 | 73 | sr_img = util.tensor2img(visuals['rlt']) # uint8 74 | 75 | # save images 76 | suffix = opt['suffix'] 77 | if suffix: 78 | save_img_path = 
osp.join(dataset_dir, img_name + suffix + '.png') 79 | else: 80 | save_img_path = osp.join(dataset_dir, img_name + '.png') 81 | util.save_img(sr_img, save_img_path) 82 | 83 | # calculate PSNR and SSIM 84 | if need_GT: 85 | gt_img = util.tensor2img(visuals['GT']) 86 | psnr = util.calculate_psnr(sr_img, gt_img) 87 | ssim = util.calculate_ssim(sr_img, gt_img) 88 | test_results['psnr'].append(psnr) 89 | test_results['ssim'].append(ssim) 90 | 91 | if gt_img.shape[2] == 3: # RGB image 92 | sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True) 93 | gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True) 94 | 95 | psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255) 96 | ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255) 97 | test_results['psnr_y'].append(psnr_y) 98 | test_results['ssim_y'].append(ssim_y) 99 | logger.info( 100 | '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'. 101 | format(img_name, psnr, ssim, psnr_y, ssim_y)) 102 | else: 103 | logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim)) 104 | else: 105 | logger.info(img_name) 106 | 107 | if need_GT: # metrics 108 | # Average PSNR/SSIM results 109 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) 110 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) 111 | logger.info( 112 | '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format( 113 | test_set_name, ave_psnr, ave_ssim)) 114 | if test_results['psnr_y'] and test_results['ssim_y']: 115 | ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) 116 | ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) 117 | logger.info( 118 | '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'. 
119 | format(ave_psnr_y, ave_ssim_y)) 120 | -------------------------------------------------------------------------------- /codes/test_CSRNet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import logging 3 | import time 4 | import argparse 5 | from collections import OrderedDict 6 | import torch 7 | 8 | import options.options as option 9 | import utils.util as util 10 | from data.util import bgr2ycbcr 11 | from data import create_dataset, create_dataloader 12 | from models import create_model 13 | 14 | #### options 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.') 17 | opt = option.parse(parser.parse_args().opt, is_train=False) 18 | opt = option.dict_to_nonedict(opt) 19 | 20 | util.mkdirs( 21 | (path for key, path in opt['path'].items() 22 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) 23 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, 24 | screen=True, tofile=True) 25 | logger = logging.getLogger('base') 26 | logger.info(option.dict2str(opt)) 27 | 28 | #### Create test dataset and dataloader 29 | test_loaders = [] 30 | for phase, dataset_opt in sorted(opt['datasets'].items()): 31 | test_set = create_dataset(dataset_opt) 32 | test_loader = create_dataloader(test_set, dataset_opt) 33 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) 34 | test_loaders.append(test_loader) 35 | 36 | model = create_model(opt) 37 | for test_loader in test_loaders: 38 | test_set_name = test_loader.dataset.opt['name'] 39 | logger.info('\nTesting [{:s}]...'.format(test_set_name)) 40 | test_start_time = time.time() 41 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name) 42 | util.mkdir(dataset_dir) 43 | 44 | test_results = OrderedDict() 45 | test_results['psnr'] = [] 46 | test_results['ssim'] = [] 47 | 48 | for data in test_loader: 49 | need_GT = True 50 | model.feed_data(data, need_GT=need_GT, need_cond=False) 51 | 52 | img_path = data['LQ_path'][0] 53 | img_name = osp.splitext(osp.basename(img_path))[0] 54 | 55 | model.test() 56 | visuals = model.get_current_visuals(need_GT=need_GT) 57 | 58 | sr_img = util.tensor2img(visuals['rlt']) # uint8 59 | 60 | # save images 61 | suffix = opt['suffix'] 62 | if suffix: 63 | save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') 64 | else: 65 | save_img_path = osp.join(dataset_dir, img_name + '.png') 66 | util.save_img(sr_img, save_img_path) 67 | 68 | # calculate PSNR and SSIM 69 | if need_GT: 70 | gt_img = util.tensor2img(visuals['GT']) 71 | psnr = util.calculate_psnr(sr_img, gt_img) 72 | ssim = util.calculate_ssim(sr_img, gt_img) 73 | test_results['psnr'].append(psnr) 74 | test_results['ssim'].append(ssim) 75 | logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim)) 76 | else: 77 | logger.info(img_name) 78 | 79 | if need_GT: # metrics 80 | # Average PSNR/SSIM results 81 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) 82 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) 83 | logger.info( 84 | '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format( 85 | test_set_name, ave_psnr, ave_ssim)) -------------------------------------------------------------------------------- /codes/utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/codes/utils/__init__.py -------------------------------------------------------------------------------- /codes/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | import torch.nn.functional as F 6 | from datetime import datetime 7 | import random 8 | import logging 9 | from collections import OrderedDict 10 | import numpy as np 11 | import cv2 12 | import torch 13 | from torchvision.utils import make_grid  # needed by tensor2img for 4D inputs 14 | from shutil import get_terminal_size 15 | 16 | import yaml 17 | try: 18 | from yaml import CLoader as Loader, CDumper as Dumper 19 | except ImportError: 20 | from yaml import Loader, Dumper 21 | 22 | 23 | def OrderedYaml(): 24 | '''yaml orderedDict support''' 25 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 26 | 27 | def dict_representer(dumper, data): 28 | return dumper.represent_dict(data.items()) 29 | 30 | def dict_constructor(loader, node): 31 | return OrderedDict(loader.construct_pairs(node)) 32 | 33 | Dumper.add_representer(OrderedDict, dict_representer) 34 | Loader.add_constructor(_mapping_tag, dict_constructor) 35 | return Loader, Dumper 36 | 37 | 38 | #################### 39 | # miscellaneous 40 | #################### 41 | 42 | 43 | def get_timestamp(): 44 | return datetime.now().strftime('%y%m%d-%H%M%S') 45 | 46 | 47 | def mkdir(path): 48 | if not os.path.exists(path): 49 | os.makedirs(path) 50 | 51 | 52 | def mkdirs(paths): 53 | if isinstance(paths, str): 54 | mkdir(paths) 55 | else: 56 | for path in paths: 57 | mkdir(path) 58 | 59 | 60 | def mkdir_and_rename(path): 61 | if os.path.exists(path): 62 | new_name = path + '_archived_' + get_timestamp() 63 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 64 | logger = logging.getLogger('base') 65 | logger.info('Path already exists. 
Rename it to [{:s}]'.format(new_name)) 66 | os.rename(path, new_name) 67 | os.makedirs(path) 68 | 69 | 70 | def set_random_seed(seed): 71 | random.seed(seed) 72 | np.random.seed(seed) 73 | torch.manual_seed(seed) 74 | torch.cuda.manual_seed_all(seed) 75 | 76 | 77 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 78 | '''set up logger''' 79 | lg = logging.getLogger(logger_name) 80 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 81 | datefmt='%y-%m-%d %H:%M:%S') 82 | lg.setLevel(level) 83 | if tofile: 84 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 85 | fh = logging.FileHandler(log_file, mode='w') 86 | fh.setFormatter(formatter) 87 | lg.addHandler(fh) 88 | if screen: 89 | sh = logging.StreamHandler() 90 | sh.setFormatter(formatter) 91 | lg.addHandler(sh) 92 | 93 | 94 | #################### 95 | # image convert 96 | #################### 97 | def crop_border(img_list, crop_border): 98 | """Crop borders of images 99 | Args: 100 | img_list (list [Numpy]): HWC 101 | crop_border (int): crop border for each end of height and width 102 | 103 | Returns: 104 | (list [Numpy]): cropped image list 105 | """ 106 | if crop_border == 0: 107 | return img_list 108 | else: 109 | return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list] 110 | 111 | 112 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 113 | ''' 114 | Converts a torch Tensor into an image Numpy array 115 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 116 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 117 | ''' 118 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 119 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 120 | n_dim = tensor.dim() 121 | if n_dim == 4: 122 | n_img = len(tensor) 123 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 124 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 125 | elif n_dim == 3: 126 | img_np = tensor.numpy() 127 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 128 | elif n_dim == 2: 129 | img_np = tensor.numpy() 130 | else: 131 | raise TypeError( 132 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 133 | if out_type == np.uint8: 134 | img_np = (img_np * 255.0).round() 135 | # Important. Unlike matlab, numpy.uint8() WILL NOT round by default. 136 | return img_np.astype(out_type) 137 | 138 | 139 | def save_img(img, img_path, mode='RGB'): 140 | cv2.imwrite(img_path, img) 141 | 142 | 143 | def DUF_downsample(x, scale=4): 144 | """Downsampling with Gaussian kernel used in the DUF official code 145 | 146 | Args: 147 | x (Tensor, [B, T, C, H, W]): frames to be downsampled. 148 | scale (int): downsampling factor: 2 | 3 | 4. 
149 | """ 150 | 151 | assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale) 152 | 153 | def gkern(kernlen=13, nsig=1.6): 154 | import scipy.ndimage.filters as fi 155 | inp = np.zeros((kernlen, kernlen)) 156 | # set element at the middle to one, a dirac delta 157 | inp[kernlen // 2, kernlen // 2] = 1 158 | # gaussian-smooth the dirac, resulting in a gaussian filter mask 159 | return fi.gaussian_filter(inp, nsig) 160 | 161 | B, T, C, H, W = x.size() 162 | x = x.view(-1, 1, H, W) 163 | pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter 164 | r_h, r_w = 0, 0 165 | if scale == 3: 166 | r_h = 3 - (H % 3) 167 | r_w = 3 - (W % 3) 168 | x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect') 169 | 170 | gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0) 171 | x = F.conv2d(x, gaussian_filter, stride=scale) 172 | x = x[:, :, 2:-2, 2:-2] 173 | x = x.view(B, T, C, x.size(2), x.size(3)) 174 | return x 175 | 176 | 177 | def single_forward(model, inp): 178 | """PyTorch model forward (single test), it is just a simple wrapper 179 | Args: 180 | model (PyTorch model) 181 | inp (Tensor): inputs defined by the model 182 | 183 | Returns: 184 | output (Tensor): outputs of the model. float, in CPU 185 | """ 186 | with torch.no_grad(): 187 | model_output = model(inp) 188 | if isinstance(model_output, list) or isinstance(model_output, tuple): 189 | output = model_output[0] 190 | else: 191 | output = model_output 192 | output = output.data.float().cpu() 193 | return output 194 | 195 | 196 | def flipx4_forward(model, inp): 197 | """Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W 198 | Args: 199 | model (PyTorch model) 200 | inp (Tensor): inputs defined by the model 201 | 202 | Returns: 203 | output (Tensor): outputs of the model. 
float, in CPU 204 | """ 205 | # normal 206 | output_f = single_forward(model, inp) 207 | 208 | # flip W 209 | output = single_forward(model, torch.flip(inp, (-1, ))) 210 | output_f = output_f + torch.flip(output, (-1, )) 211 | # flip H 212 | output = single_forward(model, torch.flip(inp, (-2, ))) 213 | output_f = output_f + torch.flip(output, (-2, )) 214 | # flip both H and W 215 | output = single_forward(model, torch.flip(inp, (-2, -1))) 216 | output_f = output_f + torch.flip(output, (-2, -1)) 217 | 218 | return output_f / 4 219 | 220 | 221 | #################### 222 | # metric 223 | #################### 224 | 225 | 226 | def calculate_psnr(img1, img2): 227 | # img1 and img2 have range [0, 255] 228 | img1 = img1.astype(np.float64) 229 | img2 = img2.astype(np.float64) 230 | mse = np.mean((img1 - img2)**2) 231 | if mse == 0: 232 | return float('inf') 233 | return 20 * math.log10(255.0 / math.sqrt(mse)) 234 | 235 | 236 | def ssim(img1, img2): 237 | C1 = (0.01 * 255)**2 238 | C2 = (0.03 * 255)**2 239 | 240 | img1 = img1.astype(np.float64) 241 | img2 = img2.astype(np.float64) 242 | kernel = cv2.getGaussianKernel(11, 1.5) 243 | window = np.outer(kernel, kernel.transpose()) 244 | 245 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 246 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 247 | mu1_sq = mu1**2 248 | mu2_sq = mu2**2 249 | mu1_mu2 = mu1 * mu2 250 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 251 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 252 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 253 | 254 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 255 | (sigma1_sq + sigma2_sq + C2)) 256 | return ssim_map.mean() 257 | 258 | 259 | def calculate_ssim(img1, img2): 260 | '''calculate SSIM 261 | the same outputs as MATLAB's 262 | img1, img2: [0, 255] 263 | ''' 264 | if not img1.shape == img2.shape: 265 | raise ValueError('Input images must have the same dimensions.') 266 | if img1.ndim == 2: 267 | return ssim(img1, img2) 268 | elif img1.ndim == 3: 269 | if img1.shape[2] == 3: 270 | ssims = [] 271 | for i in range(3): 272 | ssims.append(ssim(img1[..., i], img2[..., i])) # SSIM per channel, then averaged 273 | return np.array(ssims).mean() 274 | elif img1.shape[2] == 1: 275 | return ssim(np.squeeze(img1), np.squeeze(img2)) 276 | else: 277 | raise ValueError('Wrong input image dimensions.') 278 | 279 | 280 | class ProgressBar(object): 281 | '''A progress bar which can print the progress 282 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 283 | ''' 284 | 285 | def __init__(self, task_num=0, bar_width=50, start=True): 286 | self.task_num = task_num 287 | max_bar_width = self._get_max_bar_width() 288 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) 289 | self.completed = 0 290 | if start: 291 | self.start() 292 | 293 | def _get_max_bar_width(self): 294 | terminal_width, _ = get_terminal_size() 295 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 296 | if max_bar_width < 10: 297 | print('terminal width is too small ({}), please consider widening the terminal for better ' 298 | 'progressbar visualization'.format(terminal_width)) 299 | max_bar_width = 10 300 | return max_bar_width 301 | 302 | def start(self): 303 | if self.task_num > 0: 304 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( 305 | ' ' * self.bar_width, self.task_num, 'Start...')) 306 | else: 307 | sys.stdout.write('completed: 0, elapsed: 0s') 308 | sys.stdout.flush() 309 | 
self.start_time = time.time() 310 | 311 | def update(self, msg='In progress...'): 312 | self.completed += 1 313 | elapsed = time.time() - self.start_time 314 | fps = self.completed / elapsed 315 | if self.task_num > 0: 316 | percentage = self.completed / float(self.task_num) 317 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 318 | mark_width = int(self.bar_width * percentage) 319 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) 320 | sys.stdout.write('\033[2F') # cursor up 2 lines 321 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) 322 | sys.stdout.write('[{}] {}/{}, {:.1f} tasks/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( 323 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) 324 | else: 325 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( 326 | self.completed, int(elapsed + 0.5), fps)) 327 | sys.stdout.flush() 328 | -------------------------------------------------------------------------------- /experiments/pretrain_models/csrnet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/experiments/pretrain_models/csrnet.pth -------------------------------------------------------------------------------- /figures/csrnet_fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/figures/csrnet_fig1.png -------------------------------------------------------------------------------- /figures/csrnet_fig6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/figures/csrnet_fig6.png --------------------------------------------------------------------------------
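For a quick sanity check of the metric helpers in `codes/utils/util.py`, here is a minimal usage sketch; the image paths are hypothetical placeholders, and it assumes it is run from the `codes/` directory so that `utils.util` resolves:

```python
# Minimal sketch exercising calculate_psnr / calculate_ssim from codes/utils/util.py.
# Both image paths below are hypothetical placeholders; the helpers expect two
# images of identical shape with pixel values in [0, 255].
import cv2
import utils.util as util

sr_img = cv2.imread('../results/csrnet/a0001.png')          # retouched result (assumed path)
gt_img = cv2.imread('../datasets/expert_C_test/a0001.png')  # expert-retouched GT (assumed path)

print('PSNR: {:.4f} dB'.format(util.calculate_psnr(sr_img, gt_img)))
print('SSIM: {:.4f}'.format(util.calculate_ssim(sr_img, gt_img)))
```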