├── .idea
│   └── workspace.xml
├── README.md
├── codes
│   ├── calculate_metrics.py
│   ├── data
│   │   ├── LQGT_dataset.py
│   │   ├── LQGT_enhance_dataset.py
│   │   ├── LQ_dataset.py
│   │   ├── REDS_dataset.py
│   │   ├── Vimeo90K_dataset.py
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   ├── util.py
│   │   └── video_test_dataset.py
│   ├── data_scripts
│   │   ├── create_lmdb.py
│   │   ├── extract_subimages.py
│   │   ├── generate_LR_Vimeo90K.m
│   │   ├── generate_mod_LR_bic.m
│   │   ├── generate_mod_LR_bic.py
│   │   ├── prepare_DIV2K_x4_dataset.sh
│   │   ├── regroup_REDS.py
│   │   ├── rename.py
│   │   └── test_dataloader.py
│   ├── metrics
│   │   ├── calculate_PSNR_SSIM.m
│   │   └── calculate_PSNR_SSIM.py
│   ├── models
│   │   ├── SRGAN_model.py
│   │   ├── SR_model.py
│   │   ├── Video_base_model.py
│   │   ├── __init__.py
│   │   ├── archs
│   │   │   ├── CSRNet_arch.py
│   │   │   ├── DUF_arch.py
│   │   │   ├── EDVR_arch.py
│   │   │   ├── RRDBNet_arch.py
│   │   │   ├── SRResNet_arch.py
│   │   │   ├── TOF_arch.py
│   │   │   ├── __init__.py
│   │   │   ├── arch_util.py
│   │   │   ├── dcn
│   │   │   │   ├── __init__.py
│   │   │   │   ├── deform_conv.py
│   │   │   │   ├── setup.py
│   │   │   │   └── src
│   │   │   │       ├── deform_conv_cuda.cpp
│   │   │   │       └── deform_conv_cuda_kernel.cu
│   │   │   └── discriminator_vgg_arch.py
│   │   ├── base_model.py
│   │   ├── loss.py
│   │   ├── lr_scheduler.py
│   │   └── networks.py
│   ├── options
│   │   ├── __init__.py
│   │   ├── options.py
│   │   ├── test
│   │   │   ├── test_ESRGAN.yml
│   │   │   ├── test_Enhance.yml
│   │   │   ├── test_SRGAN.yml
│   │   │   └── test_SRResNet.yml
│   │   └── train
│   │       ├── train_EDVR_M.yml
│   │       ├── train_EDVR_woTSA_M.yml
│   │       ├── train_ESRGAN.yml
│   │       ├── train_Enhance.yml
│   │       ├── train_SRGAN.yml
│   │       └── train_SRResNet.yml
│   ├── run_scripts.sh
│   ├── scripts
│   │   └── transfer_params_MSRResNet.py
│   ├── test.py
│   ├── test_CSRNet.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       └── util.py
├── experiments
│   └── pretrain_models
│       └── csrnet.pth
└── figures
    ├── csrnet_fig1.png
    └── csrnet_fig6.png
/README.md:
--------------------------------------------------------------------------------
1 | # Conditional Sequential Modulation for Efficient Global Image Retouching [Paper Link](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123580664.pdf)
2 | By Jingwen He*, Yihao Liu*, [Yu Qiao](http://mmlab.siat.ac.cn/yuqiao/), and [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ&hl=en) (* indicates equal contribution)
3 |
4 |
5 | ![overview](figures/csrnet_fig1.png)
6 |
7 |
8 | Left: Compared with existing state-of-the-art methods, our method achieves
9 | superior performance with extremely few parameters (1/13 of HDRNet and 1/250
10 | of White-Box). The diameter of each circle represents the number of trainable
11 | parameters. Right: Image retouching examples.
12 |
13 |
14 |
15 |
16 | ![interpolation](figures/csrnet_fig6.png)
17 |
18 |
19 |
20 |
21 | The first row shows smooth transition effects between different styles (expert A
22 | to expert B) obtained by image interpolation. In the second row, we use image interpolation to control
23 | the retouching strength from the input image to the automatically retouched result. We denote
24 | the interpolation coefficient α for each image.
25 |
26 | ### BibTex
27 | @article{he2020conditional,
28 | title={Conditional Sequential Modulation for Efficient Global Image Retouching},
29 | author={He, Jingwen and Liu, Yihao and Qiao, Yu and Dong, Chao},
30 | journal={arXiv preprint arXiv:2009.10390},
31 | year={2020}
32 | }
33 |
34 |
35 | ## Dependencies and Installation
36 |
37 | - Python 3 ([Anaconda](https://www.anaconda.com/download/#linux) is recommended)
38 | - [PyTorch >= 1.0](https://pytorch.org/)
39 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads)
40 | - Python packages: `pip install numpy opencv-python lmdb pyyaml`
41 | - TensorBoard:
42 | - PyTorch >= 1.1: `pip install tb-nightly future`
43 | - PyTorch == 1.0: `pip install tensorboardX`
44 |
45 |
46 | ## Datasets
47 |
48 | Here, we provide the preprocessed [MIT-Adobe FiveK dataset](https://drive.google.com/drive/folders/1qrGLFzW7RBlBO1FqgrLPrq9p2_p11ZFs?usp=sharing), which contains both training and testing pairs.
49 | - training pairs: {GT: expert_C_train; Input: raw_input_train}
50 | - testing pairs: {GT: expert_C_test; Input: raw_input_test}
51 |
52 | ## How to Test
53 | 1. Modify the configuration file [`options/test/test_Enhance.yml`](codes/options/test/test_Enhance.yml), e.g., `dataroot_GT`, `dataroot_LQ`, and `pretrain_model_G`.
54 | (A pretrained model is provided in [`experiments/pretrain_models/csrnet.pth`](experiments/pretrain_models/).)
55 | 1. Run command:
56 | ```bash
57 | python test_CSRNet.py -opt options/test/test_Enhance.yml
58 | ```
59 | 1. Modify the Python file [`calculate_metrics.py`](codes/calculate_metrics.py): set `input_path` and `GT_path` (lines 139-140). Then run:
60 | ```bash
61 | python calculate_metrics.py
62 | ```
63 |
64 | ## How to Train
65 | 1. Modify the configuration file [`options/train/train_Enhance.yml`](codes/options/train/train_Enhance.yml), e.g., `dataroot_GT` and `dataroot_LQ`.
66 | 1. Run command:
67 | ```bash
68 | python train.py -opt options/train/train_Enhance.yml
69 | ```
70 |
71 | ## Acknowledgement
72 |
73 | - This code is based on [mmsr](https://github.com/open-mmlab/mmsr).
74 | - Thanks to Yihao Liu for his contributions to part of this work.
75 |
--------------------------------------------------------------------------------
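Before the code files below: the test/train steps above both point at yml option files that are not reproduced in this dump. A hypothetical excerpt of the dataset section of `options/test/test_Enhance.yml`, using the FiveK folders from the Datasets section; the key names `mode`, `dataroot_GT`, and `dataroot_LQ` come from `codes/data`, while the surrounding layout and values are illustrative only:

```yaml
datasets:
  test:
    name: FiveK_test
    mode: LQGT_enhance        # handled by codes/data/LQGT_enhance_dataset.py
    dataroot_GT: /path/to/expert_C_test
    dataroot_LQ: /path/to/raw_input_test
```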
/codes/calculate_metrics.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import cv2
3 | from PIL import Image
4 | import numpy as np
5 | import math
6 | import os
7 | import tifffile as tiff
8 | from skimage import color
9 |
10 | def ProPhotoRGB2XYZ(pp_rgb,reverse=False):
11 | if not reverse:
12 | M = [[0.7976749, 0.1351917, 0.0313534], \
13 | [0.2880402, 0.7118741, 0.0000857], \
14 | [0.0000000, 0.0000000, 0.8252100]]
15 | else:
16 | M = [[ 1.34594337, -0.25560752, -0.05111183],\
17 | [-0.54459882, 1.5081673, 0.02053511],\
18 | [ 0, 0, 1.21181275]]
19 | M = np.array(M)
20 | sp = pp_rgb.shape
21 | xyz = np.transpose(np.dot(M, np.transpose(pp_rgb.reshape((sp[0] * sp[1], sp[2])))))
22 | return xyz.reshape((sp[0], sp[1], 3))
23 |
24 | def linearize_ProPhotoRGB(pp_rgb, reverse=False):
25 | if not reverse:
26 | gamma = 1.8
27 | else:
28 | gamma = 1.0/1.8
29 | pp_rgb = np.power(pp_rgb, gamma)
30 | return pp_rgb
31 |
32 | def XYZ_chromatic_adapt(xyz, src_white='D65', dest_white='D50'):
33 | if src_white == 'D65' and dest_white == 'D50':
34 | M = [[1.0478112, 0.0228866, -0.0501270], \
35 | [0.0295424, 0.9904844, -0.0170491], \
36 | [-0.0092345, 0.0150436, 0.7521316]]
37 | elif src_white == 'D50' and dest_white == 'D65':
38 | M = [[0.9555766, -0.0230393, 0.0631636], \
39 | [-0.0282895, 1.0099416, 0.0210077], \
40 | [0.0122982, -0.0204830, 1.3299098]]
41 | else:
42 |         raise ValueError('invalid pair of source and destination white references %s, %s'
43 |                          % (src_white, dest_white))
44 | M = np.array(M)
45 | sp = xyz.shape
46 | assert sp[2] == 3
47 | xyz = np.transpose(np.dot(M, np.transpose(xyz.reshape((sp[0] * sp[1], 3)))))
48 | return xyz.reshape((sp[0], sp[1], 3))
49 |
50 | def read_tiff_16bit_img_into_XYZ(tiff_fn, exposure=0):
51 | pp_rgb = tiff.imread(tiff_fn)
52 | pp_rgb = np.float64(pp_rgb) / (2 ** 16 - 1.0)
53 | if not pp_rgb.shape[2] == 3:
54 | print('pp_rgb shape',pp_rgb.shape)
55 |         raise ValueError('image channel number is not 3')
56 | pp_rgb = linearize_ProPhotoRGB(pp_rgb)
57 | pp_rgb *= np.power(2, exposure)
58 | xyz = ProPhotoRGB2XYZ(pp_rgb)
59 | xyz = XYZ_chromatic_adapt(xyz, src_white='D50', dest_white='D65')
60 | return xyz
61 |
62 | def read_tiff_16bit_img_into_LAB(tiff_fn, exposure=0, normalize_Lab=False):
63 | xyz = read_tiff_16bit_img_into_XYZ(tiff_fn, exposure)
64 | lab = color.xyz2lab(xyz)
65 | if normalize_Lab:
66 |         raise NotImplementedError('normalize_Lab_image is not provided in this script.')
67 | return lab
68 |
69 |
70 |
71 | def calculate_Lab_RMSE(img1, img2):
72 |     # img1 and img2 are OpenCV 8-bit Lab images
73 |     img1 = img1.astype(np.float64)  # cast first: uint8 arithmetic would wrap around
74 |     img2 = img2.astype(np.float64)
75 |     num_pix = img1.shape[0] * img1.shape[1]
76 |
77 |     Lab_RMSE = np.mean(np.sqrt(np.sum((img1 - img2)**2, axis=2)))  # mean per-pixel Lab distance
78 |     # equivalent: np.sum(np.sqrt(np.sum((img1 - img2) ** 2, axis=2))) / num_pix
79 |
80 |     # note: np.sqrt(np.sum((img1 - img2) ** 2) / num_pix) is a slightly different quantity
81 |
82 |     return Lab_RMSE
83 |
84 | def calculate_psnr(img1, img2):
85 | # img1 and img2 have range [0, 255]
86 | img1 = img1.astype(np.float64)
87 | img2 = img2.astype(np.float64)
88 | mse = np.mean((img1 - img2)**2)
89 | if mse == 0:
90 | return float('inf')
91 | return 20 * math.log10(255.0 / math.sqrt(mse))
92 |
93 |
94 | def ssim_my(img1, img2):
95 | C1 = (0.01 * 255)**2
96 | C2 = (0.03 * 255)**2
97 |
98 | img1 = img1.astype(np.float64)
99 | img2 = img2.astype(np.float64)
100 | kernel = cv2.getGaussianKernel(11, 1.5)
101 | window = np.outer(kernel, kernel.transpose())
102 |
103 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
104 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
105 | mu1_sq = mu1**2
106 | mu2_sq = mu2**2
107 | mu1_mu2 = mu1 * mu2
108 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
109 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
110 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
111 |
112 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
113 | (sigma1_sq + sigma2_sq + C2))
114 | return ssim_map.mean()
115 |
116 |
117 | def calculate_ssim(img1, img2):
118 | '''calculate SSIM
119 | the same outputs as MATLAB's
120 | img1, img2: [0, 255]
121 | '''
122 | if not img1.shape == img2.shape:
123 | raise ValueError('Input images must have the same dimensions.')
124 | if img1.ndim == 2:
125 | return ssim_my(img1, img2)
126 | elif img1.ndim == 3:
127 | if img1.shape[2] == 3:
128 |             ssims = []
129 |             for i in range(3):
130 |                 ssims.append(ssim_my(img1[:, :, i], img2[:, :, i]))  # per-channel SSIM
131 | return np.array(ssims).mean()
132 | elif img1.shape[2] == 1:
133 | return ssim_my(np.squeeze(img1), np.squeeze(img2))
134 | else:
135 | raise ValueError('Wrong input image dimensions.')
136 |
137 | # ##########################################################
138 | # Please specify the paths for input dir and ground truth dir.
139 | input_path = ''  # directory containing the model's output images
140 | GT_path = ''  # directory containing the ground-truth images
141 |
142 | input_fname_list = os.listdir(input_path)
143 | input_fname_list.sort()
144 | input_path_list = [os.path.join(input_path, fname) for fname in input_fname_list]
145 |
146 | GT_fname_list = os.listdir(GT_path)
147 | GT_fname_list.sort()
148 | GT_path_list = [os.path.join(GT_path, fname) for fname in GT_fname_list]
149 |
150 | assert len(input_path_list) == len(GT_path_list)
151 | print(len(input_path_list))
152 |
153 |
154 | psnr_list = []
155 | ssim_list = []
156 | Lab_RMSE_list = []
157 | for i in range(len(input_path_list)):
158 | assert input_fname_list[i].split('.')[0] == GT_fname_list[i].split('.')[0]
159 | img1 = cv2.imread(input_path_list[i], cv2.IMREAD_COLOR)
160 | img2 = cv2.imread(GT_path_list[i], cv2.IMREAD_COLOR)
161 |
162 |
163 | img1_rgb = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
164 | img2_rgb = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
165 |
166 | img1_lab = cv2.cvtColor(img1, cv2.COLOR_BGR2Lab)
167 | img2_lab = cv2.cvtColor(img2, cv2.COLOR_BGR2Lab)
168 |
169 |
170 |
171 | psnr = calculate_psnr(img1_rgb, img2_rgb)
172 | ssim = calculate_ssim(img1_rgb, img2_rgb)
173 |
174 | Lab_RMSE = calculate_Lab_RMSE(img1_lab, img2_lab)
175 |
176 | print('img: {} PSNR: {} SSIM: {} Lab_RMSE: {}'.format(input_fname_list[i].split('.')[0], psnr, ssim, Lab_RMSE))
177 |
178 | psnr_list.append(psnr)
179 | ssim_list.append(ssim)
180 | Lab_RMSE_list.append(Lab_RMSE)
181 |
182 | print('Average PSNR: {} SSIM: {} Lab_RMSE: {} Total image: {}'.format(np.mean(psnr_list), np.mean(ssim_list), np.mean(Lab_RMSE_list), len(psnr_list)))
--------------------------------------------------------------------------------
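A quick, self-contained sanity check of the metric helpers above (a sketch: it assumes `calculate_psnr` and `calculate_ssim` are copied into scope, since importing the module as-is would also execute the path-dependent script section at the bottom):

```python
import numpy as np

rng = np.random.default_rng(0)
clean = rng.integers(0, 256, size=(64, 64, 3)).astype(np.float64)
noisy = np.clip(clean + rng.normal(0.0, 5.0, clean.shape), 0, 255)

print(calculate_psnr(clean, clean))  # inf: identical images give zero MSE
print(calculate_psnr(clean, noisy))  # roughly 34 dB for sigma = 5 Gaussian noise
print(calculate_ssim(clean, noisy))  # close to, but below, 1.0
```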
/codes/data/LQGT_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import cv2
4 | import lmdb
5 | import torch
6 | import torch.utils.data as data
7 | import data.util as util
8 |
9 |
10 | class LQGTDataset(data.Dataset):
11 | """
12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, etc) and GT image pairs.
13 | If only GT images are provided, generate LQ images on-the-fly.
14 | """
15 |
16 | def __init__(self, opt):
17 | super(LQGTDataset, self).__init__()
18 | self.opt = opt
19 | self.data_type = self.opt['data_type']
20 | self.paths_LQ, self.paths_GT = None, None
21 | self.sizes_LQ, self.sizes_GT = None, None
22 | self.LQ_env, self.GT_env = None, None # environments for lmdb
23 |
24 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
25 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
26 | assert self.paths_GT, 'Error: GT path is empty.'
27 | if self.paths_LQ and self.paths_GT:
28 | assert len(self.paths_LQ) == len(
29 | self.paths_GT
30 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
31 | len(self.paths_LQ), len(self.paths_GT))
32 | self.random_scale_list = [1]
33 |
34 | def _init_lmdb(self):
35 | # https://github.com/chainer/chainermn/issues/129
36 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
37 | meminit=False)
38 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
39 | meminit=False)
40 |
41 | def __getitem__(self, index):
42 | if self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
43 | self._init_lmdb()
44 | GT_path, LQ_path = None, None
45 | scale = self.opt['scale']
46 | GT_size = self.opt['GT_size']
47 |
48 | # get GT image
49 | GT_path = self.paths_GT[index]
50 | resolution = [int(s) for s in self.sizes_GT[index].split('_')
51 | ] if self.data_type == 'lmdb' else None
52 | img_GT = util.read_img(self.GT_env, GT_path, resolution)
53 | if self.opt['phase'] != 'train': # modcrop in the validation / test phase
54 | img_GT = util.modcrop(img_GT, scale)
55 | #### downsample in base network
56 | img_GT = util.modcrop(img_GT, 2)
57 | if self.opt['color']: # change color space if necessary
58 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
59 |
60 | # get LQ image
61 | if self.paths_LQ:
62 | LQ_path = self.paths_LQ[index]
63 | resolution = [int(s) for s in self.sizes_LQ[index].split('_')
64 | ] if self.data_type == 'lmdb' else None
65 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
66 | #### downsample in base network
67 | img_LQ = util.modcrop(img_LQ, 2)
68 | else: # down-sampling on-the-fly
69 | # randomly scale during training
70 | if self.opt['phase'] == 'train':
71 | random_scale = random.choice(self.random_scale_list)
72 | H_s, W_s, _ = img_GT.shape
73 |
74 | def _mod(n, random_scale, scale, thres):
75 | rlt = int(n * random_scale)
76 | rlt = (rlt // scale) * scale
77 | return thres if rlt < thres else rlt
78 |
79 | H_s = _mod(H_s, random_scale, scale, GT_size)
80 | W_s = _mod(W_s, random_scale, scale, GT_size)
81 | img_GT = cv2.resize(img_GT, (W_s, H_s), interpolation=cv2.INTER_LINEAR)
82 | if img_GT.ndim == 2:
83 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR)
84 |
85 | H, W, _ = img_GT.shape
86 | # using matlab imresize
87 | img_LQ = util.imresize_np(img_GT, 1 / scale, True)
88 | if img_LQ.ndim == 2:
89 | img_LQ = np.expand_dims(img_LQ, axis=2)
90 |
91 | if self.opt['phase'] == 'train':
92 | # if the image size is too small
93 | H, W, _ = img_GT.shape
94 | if H < GT_size or W < GT_size:
95 | img_GT = cv2.resize(img_GT, (GT_size, GT_size), interpolation=cv2.INTER_LINEAR)
96 | # using matlab imresize
97 | img_LQ = util.imresize_np(img_GT, 1 / scale, True)
98 | if img_LQ.ndim == 2:
99 | img_LQ = np.expand_dims(img_LQ, axis=2)
100 |
101 | H, W, C = img_LQ.shape
102 | LQ_size = GT_size // scale
103 |
104 | # randomly crop
105 | rnd_h = random.randint(0, max(0, H - LQ_size))
106 | rnd_w = random.randint(0, max(0, W - LQ_size))
107 | img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :]
108 | rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale)
109 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :]
110 |
111 | # augmentation - flip, rotate
112 | img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'],
113 | self.opt['use_rot'])
114 |
115 | if self.opt['color']: # change color space if necessary
116 | img_LQ = util.channel_convert(C, self.opt['color'],
117 |                                           [img_LQ])[0]  # TODO: C is undefined when the train branch above is skipped
118 |
119 | # BGR to RGB, HWC to CHW, numpy to tensor
120 | if img_GT.shape[2] == 3:
121 | img_GT = img_GT[:, :, [2, 1, 0]]
122 | img_LQ = img_LQ[:, :, [2, 1, 0]]
123 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
124 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
125 |
126 | if LQ_path is None:
127 | LQ_path = GT_path
128 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
129 |
130 | def __len__(self):
131 | return len(self.paths_GT)
132 |
--------------------------------------------------------------------------------
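A minimal sketch of driving `LQGTDataset` directly (run from `codes/`), assuming an image-folder layout; the option keys are the ones read in `__init__`/`__getitem__`, while the paths and sizes are illustrative:

```python
from data.LQGT_dataset import LQGTDataset

opt = {
    'data_type': 'img',            # image folders rather than lmdb
    'dataroot_GT': '/path/to/GT',
    'dataroot_LQ': '/path/to/LQ',
    'phase': 'train',
    'scale': 4,
    'GT_size': 128,                # training crop size on the GT side
    'use_flip': True,
    'use_rot': True,
    'color': None,
}
dataset = LQGTDataset(opt)
sample = dataset[0]
print(sample['LQ'].shape, sample['GT'].shape)  # torch.Size([3, 32, 32]) torch.Size([3, 128, 128])
```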
/codes/data/LQGT_enhance_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data as data
4 | import data.util as util
5 |
6 | class LQGT_enhance_dataset(data.Dataset):
7 | def __init__(self, opt):
8 | super(LQGT_enhance_dataset, self).__init__()
9 | self.opt = opt
10 | self.data_type = self.opt['data_type']
11 | self.paths_LQ, self.paths_GT = None, None
12 | self.sizes_LQ, self.sizes_GT = None, None
13 | self.LQ_env, self.GT_env = None, None # environments for lmdb
14 |
15 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT'])
16 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
17 | assert self.paths_GT, 'Error: GT path is empty.'
18 | if self.paths_LQ and self.paths_GT:
19 | assert len(self.paths_LQ) == len(
20 | self.paths_GT
21 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format(
22 | len(self.paths_LQ), len(self.paths_GT))
23 |
24 | def __getitem__(self, index):
25 | GT_path, LQ_path = None, None
26 |
27 | # get GT image
28 | GT_path = self.paths_GT[index]
29 | LQ_path = self.paths_LQ[index]
30 | img_GT = util.read_img(self.GT_env, GT_path)
31 | img_LQ = util.read_img(self.LQ_env, LQ_path)
32 |
33 | if self.opt['color']:
34 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0]
35 | img_LQ = util.channel_convert(img_LQ.shape[2], self.opt['color'], [img_LQ])[0]
36 |
37 | # BGR to RGB, HWC to CHW, numpy to tensor
38 | if img_GT.shape[2] == 3:
39 | img_GT = img_GT[:, :, [2, 1, 0]]
40 | img_LQ = img_LQ[:, :, [2, 1, 0]]
41 |
42 | H, W, _ = img_LQ.shape
43 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
44 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
45 |
46 | if LQ_path is None:
47 | LQ_path = GT_path
48 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path}
49 |
50 | def __len__(self):
51 | return len(self.paths_GT)
52 |
--------------------------------------------------------------------------------
/codes/data/LQ_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import lmdb
3 | import torch
4 | import torch.utils.data as data
5 | import data.util as util
6 |
7 |
8 | class LQDataset(data.Dataset):
9 | '''Read LQ images only in the test phase.'''
10 |
11 | def __init__(self, opt):
12 | super(LQDataset, self).__init__()
13 | self.opt = opt
14 | self.data_type = self.opt['data_type']
15 | self.paths_LQ, self.paths_GT = None, None
16 | self.LQ_env = None # environment for lmdb
17 |
18 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ'])
19 | assert self.paths_LQ, 'Error: LQ paths are empty.'
20 |
21 | def _init_lmdb(self):
22 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
23 | meminit=False)
24 |
25 | def __getitem__(self, index):
26 | if self.data_type == 'lmdb' and self.LQ_env is None:
27 | self._init_lmdb()
28 | LQ_path = None
29 |
30 | # get LQ image
31 | LQ_path = self.paths_LQ[index]
32 | resolution = [int(s) for s in self.sizes_LQ[index].split('_')
33 | ] if self.data_type == 'lmdb' else None
34 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution)
35 | H, W, C = img_LQ.shape
36 |
37 | if self.opt['color']: # change color space if necessary
38 | img_LQ = util.channel_convert(C, self.opt['color'], [img_LQ])[0]
39 |
40 | # BGR to RGB, HWC to CHW, numpy to tensor
41 | if img_LQ.shape[2] == 3:
42 | img_LQ = img_LQ[:, :, [2, 1, 0]]
43 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float()
44 |
45 | return {'LQ': img_LQ, 'LQ_path': LQ_path}
46 |
47 | def __len__(self):
48 | return len(self.paths_LQ)
49 |
--------------------------------------------------------------------------------
/codes/data/REDS_dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | REDS dataset
3 | support reading images from lmdb, image folder and memcached
4 | '''
5 | import os.path as osp
6 | import random
7 | import pickle
8 | import logging
9 | import numpy as np
10 | import cv2
11 | import lmdb
12 | import torch
13 | import torch.utils.data as data
14 | import data.util as util
15 | try:
16 | import mc # import memcached
17 | except ImportError:
18 | pass
19 |
20 | logger = logging.getLogger('base')
21 |
22 |
23 | class REDSDataset(data.Dataset):
24 | '''
25 | Reading the training REDS dataset
26 | key example: 000_00000000
27 | GT: Ground-Truth;
28 | LQ: Low-Quality, e.g., low-resolution/blurry/noisy/compressed frames
29 | support reading N LQ frames, N = 1, 3, 5, 7
30 | '''
31 |
32 | def __init__(self, opt):
33 | super(REDSDataset, self).__init__()
34 | self.opt = opt
35 | # temporal augmentation
36 | self.interval_list = opt['interval_list']
37 | self.random_reverse = opt['random_reverse']
38 | logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
39 | ','.join(str(x) for x in opt['interval_list']), self.random_reverse))
40 |
41 | self.half_N_frames = opt['N_frames'] // 2
42 | self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
43 | self.data_type = self.opt['data_type']
44 | self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs
45 | #### directly load image keys
46 | if self.data_type == 'lmdb':
47 | self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
48 | logger.info('Using lmdb meta info for cache keys.')
49 | elif opt['cache_keys']:
50 | logger.info('Using cache keys: {}'.format(opt['cache_keys']))
51 | self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
52 | else:
53 | raise ValueError(
54 | 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
55 |
56 | # remove the REDS4 for testing
57 | self.paths_GT = [
58 | v for v in self.paths_GT if v.split('_')[0] not in ['000', '011', '015', '020']
59 | ]
60 | assert self.paths_GT, 'Error: GT path is empty.'
61 |
62 | if self.data_type == 'lmdb':
63 | self.GT_env, self.LQ_env = None, None
64 | elif self.data_type == 'mc': # memcached
65 | self.mclient = None
66 | elif self.data_type == 'img':
67 | pass
68 | else:
69 | raise ValueError('Wrong data type: {}'.format(self.data_type))
70 |
71 | def _init_lmdb(self):
72 | # https://github.com/chainer/chainermn/issues/129
73 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
74 | meminit=False)
75 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
76 | meminit=False)
77 |
78 | def _ensure_memcached(self):
79 | if self.mclient is None:
80 | # specify the config files
81 | server_list_config_file = None
82 | client_config_file = None
83 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
84 | client_config_file)
85 |
86 | def _read_img_mc(self, path):
87 | ''' Return BGR, HWC, [0, 255], uint8'''
88 | value = mc.pyvector()
89 | self.mclient.Get(path, value)
90 | value_buf = mc.ConvertBuffer(value)
91 | img_array = np.frombuffer(value_buf, np.uint8)
92 | img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
93 | return img
94 |
95 | def _read_img_mc_BGR(self, path, name_a, name_b):
96 |         ''' Read the B, G, R channels separately and merge them, to stay under the cluster's 1M object-size limit'''
97 | img_B = self._read_img_mc(osp.join(path + '_B', name_a, name_b + '.png'))
98 | img_G = self._read_img_mc(osp.join(path + '_G', name_a, name_b + '.png'))
99 | img_R = self._read_img_mc(osp.join(path + '_R', name_a, name_b + '.png'))
100 | img = cv2.merge((img_B, img_G, img_R))
101 | return img
102 |
103 | def __getitem__(self, index):
104 | if self.data_type == 'mc':
105 | self._ensure_memcached()
106 | elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
107 | self._init_lmdb()
108 |
109 | scale = self.opt['scale']
110 | GT_size = self.opt['GT_size']
111 | key = self.paths_GT[index]
112 | name_a, name_b = key.split('_')
113 | center_frame_idx = int(name_b)
114 |
115 | #### determine the neighbor frames
116 | interval = random.choice(self.interval_list)
117 | if self.opt['border_mode']:
118 | direction = 1 # 1: forward; 0: backward
119 | N_frames = self.opt['N_frames']
120 | if self.random_reverse and random.random() < 0.5:
121 | direction = random.choice([0, 1])
122 | if center_frame_idx + interval * (N_frames - 1) > 99:
123 | direction = 0
124 | elif center_frame_idx - interval * (N_frames - 1) < 0:
125 | direction = 1
126 | # get the neighbor list
127 | if direction == 1:
128 | neighbor_list = list(
129 | range(center_frame_idx, center_frame_idx + interval * N_frames, interval))
130 | else:
131 | neighbor_list = list(
132 | range(center_frame_idx, center_frame_idx - interval * N_frames, -interval))
133 | name_b = '{:08d}'.format(neighbor_list[0])
134 | else:
135 | # ensure not exceeding the borders
136 | while (center_frame_idx + self.half_N_frames * interval >
137 | 99) or (center_frame_idx - self.half_N_frames * interval < 0):
138 | center_frame_idx = random.randint(0, 99)
139 | # get the neighbor list
140 | neighbor_list = list(
141 | range(center_frame_idx - self.half_N_frames * interval,
142 | center_frame_idx + self.half_N_frames * interval + 1, interval))
143 | if self.random_reverse and random.random() < 0.5:
144 | neighbor_list.reverse()
145 | name_b = '{:08d}'.format(neighbor_list[self.half_N_frames])
146 |
147 | assert len(
148 | neighbor_list) == self.opt['N_frames'], 'Wrong length of neighbor list: {}'.format(
149 | len(neighbor_list))
150 |
151 | #### get the GT image (as the center frame)
152 | if self.data_type == 'mc':
153 | img_GT = self._read_img_mc_BGR(self.GT_root, name_a, name_b)
154 | img_GT = img_GT.astype(np.float32) / 255.
155 | elif self.data_type == 'lmdb':
156 | img_GT = util.read_img(self.GT_env, key, (3, 720, 1280))
157 | else:
158 | img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b + '.png'))
159 |
160 | #### get LQ images
161 | LQ_size_tuple = (3, 180, 320) if self.LR_input else (3, 720, 1280)
162 | img_LQ_l = []
163 | for v in neighbor_list:
164 | img_LQ_path = osp.join(self.LQ_root, name_a, '{:08d}.png'.format(v))
165 | if self.data_type == 'mc':
166 | if self.LR_input:
167 | img_LQ = self._read_img_mc(img_LQ_path)
168 | else:
169 | img_LQ = self._read_img_mc_BGR(self.LQ_root, name_a, '{:08d}'.format(v))
170 | img_LQ = img_LQ.astype(np.float32) / 255.
171 | elif self.data_type == 'lmdb':
172 | img_LQ = util.read_img(self.LQ_env, '{}_{:08d}'.format(name_a, v), LQ_size_tuple)
173 | else:
174 | img_LQ = util.read_img(None, img_LQ_path)
175 | img_LQ_l.append(img_LQ)
176 |
177 | if self.opt['phase'] == 'train':
178 | C, H, W = LQ_size_tuple # LQ size
179 | # randomly crop
180 | if self.LR_input:
181 | LQ_size = GT_size // scale
182 | rnd_h = random.randint(0, max(0, H - LQ_size))
183 | rnd_w = random.randint(0, max(0, W - LQ_size))
184 | img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
185 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
186 | img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
187 | else:
188 | rnd_h = random.randint(0, max(0, H - GT_size))
189 | rnd_w = random.randint(0, max(0, W - GT_size))
190 | img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
191 | img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
192 |
193 | # augmentation - flip, rotate
194 | img_LQ_l.append(img_GT)
195 | rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
196 | img_LQ_l = rlt[0:-1]
197 | img_GT = rlt[-1]
198 |
199 | # stack LQ images to NHWC, N is the frame number
200 | img_LQs = np.stack(img_LQ_l, axis=0)
201 | # BGR to RGB, HWC to CHW, numpy to tensor
202 | img_GT = img_GT[:, :, [2, 1, 0]]
203 | img_LQs = img_LQs[:, :, :, [2, 1, 0]]
204 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
205 | img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
206 | (0, 3, 1, 2)))).float()
207 | return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
208 |
209 | def __len__(self):
210 | return len(self.paths_GT)
211 |
--------------------------------------------------------------------------------
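The center-mode neighbor selection in `__getitem__` above, extracted into a standalone sketch (REDS frame indices run from 0 to 99 within each clip):

```python
import random

def center_neighbors(center_idx, n_frames=5, interval=1, max_idx=99):
    half = n_frames // 2
    # re-draw the center until the temporal window fits inside [0, max_idx]
    while (center_idx + half * interval > max_idx) or (center_idx - half * interval < 0):
        center_idx = random.randint(0, max_idx)
    return list(range(center_idx - half * interval,
                      center_idx + half * interval + 1, interval))

print(center_neighbors(50, n_frames=5, interval=2))  # [46, 48, 50, 52, 54]
print(center_neighbors(2, n_frames=5, interval=1))   # [0, 1, 2, 3, 4]
```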
/codes/data/Vimeo90K_dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | Vimeo90K dataset
3 | support reading images from lmdb, image folder and memcached
4 | '''
5 | import os.path as osp
6 | import random
7 | import pickle
8 | import logging
9 | import numpy as np
10 | import cv2
11 | import lmdb
12 | import torch
13 | import torch.utils.data as data
14 | import data.util as util
15 | try:
16 | import mc # import memcached
17 | except ImportError:
18 | pass
19 | logger = logging.getLogger('base')
20 |
21 |
22 | class Vimeo90KDataset(data.Dataset):
23 | '''
24 | Reading the training Vimeo90K dataset
25 | key example: 00001_0001 (_1, ..., _7)
26 | GT (Ground-Truth): 4th frame;
27 | LQ (Low-Quality): support reading N LQ frames, N = 1, 3, 5, 7 centered with 4th frame
28 | '''
29 |
30 | def __init__(self, opt):
31 | super(Vimeo90KDataset, self).__init__()
32 | self.opt = opt
33 | # temporal augmentation
34 | self.interval_list = opt['interval_list']
35 | self.random_reverse = opt['random_reverse']
36 | logger.info('Temporal augmentation interval list: [{}], with random reverse is {}.'.format(
37 | ','.join(str(x) for x in opt['interval_list']), self.random_reverse))
38 |
39 | self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
40 | self.data_type = self.opt['data_type']
41 | self.LR_input = False if opt['GT_size'] == opt['LQ_size'] else True # low resolution inputs
42 |
43 | #### determine the LQ frame list
44 | '''
45 | N | frames
46 | 1 | 4
47 | 3 | 3,4,5
48 | 5 | 2,3,4,5,6
49 | 7 | 1,2,3,4,5,6,7
50 | '''
51 | self.LQ_frames_list = []
52 | for i in range(opt['N_frames']):
53 | self.LQ_frames_list.append(i + (9 - opt['N_frames']) // 2)
54 |
55 | #### directly load image keys
56 | if self.data_type == 'lmdb':
57 | self.paths_GT, _ = util.get_image_paths(self.data_type, opt['dataroot_GT'])
58 | logger.info('Using lmdb meta info for cache keys.')
59 | elif opt['cache_keys']:
60 | logger.info('Using cache keys: {}'.format(opt['cache_keys']))
61 | self.paths_GT = pickle.load(open(opt['cache_keys'], 'rb'))['keys']
62 | else:
63 | raise ValueError(
64 | 'Need to create cache keys (meta_info.pkl) by running [create_lmdb.py]')
65 | assert self.paths_GT, 'Error: GT path is empty.'
66 |
67 | if self.data_type == 'lmdb':
68 | self.GT_env, self.LQ_env = None, None
69 | elif self.data_type == 'mc': # memcached
70 | self.mclient = None
71 | elif self.data_type == 'img':
72 | pass
73 | else:
74 | raise ValueError('Wrong data type: {}'.format(self.data_type))
75 |
76 | def _init_lmdb(self):
77 | # https://github.com/chainer/chainermn/issues/129
78 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False,
79 | meminit=False)
80 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False,
81 | meminit=False)
82 |
83 | def _ensure_memcached(self):
84 | if self.mclient is None:
85 | # specify the config files
86 | server_list_config_file = None
87 | client_config_file = None
88 | self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
89 | client_config_file)
90 |
91 | def _read_img_mc(self, path):
92 | ''' Return BGR, HWC, [0, 255], uint8'''
93 | value = mc.pyvector()
94 | self.mclient.Get(path, value)
95 | value_buf = mc.ConvertBuffer(value)
96 | img_array = np.frombuffer(value_buf, np.uint8)
97 | img = cv2.imdecode(img_array, cv2.IMREAD_UNCHANGED)
98 | return img
99 |
100 | def __getitem__(self, index):
101 | if self.data_type == 'mc':
102 | self._ensure_memcached()
103 | elif self.data_type == 'lmdb' and (self.GT_env is None or self.LQ_env is None):
104 | self._init_lmdb()
105 |
106 | scale = self.opt['scale']
107 | GT_size = self.opt['GT_size']
108 | key = self.paths_GT[index]
109 | name_a, name_b = key.split('_')
110 | #### get the GT image (as the center frame)
111 | if self.data_type == 'mc':
112 | img_GT = self._read_img_mc(osp.join(self.GT_root, name_a, name_b, '4.png'))
113 | img_GT = img_GT.astype(np.float32) / 255.
114 | elif self.data_type == 'lmdb':
115 | img_GT = util.read_img(self.GT_env, key + '_4', (3, 256, 448))
116 | else:
117 | img_GT = util.read_img(None, osp.join(self.GT_root, name_a, name_b, 'im4.png'))
118 |
119 | #### get LQ images
120 | LQ_size_tuple = (3, 64, 112) if self.LR_input else (3, 256, 448)
121 | img_LQ_l = []
122 | for v in self.LQ_frames_list:
123 | if self.data_type == 'mc':
124 | img_LQ = self._read_img_mc(
125 | osp.join(self.LQ_root, name_a, name_b, '{}.png'.format(v)))
126 | img_LQ = img_LQ.astype(np.float32) / 255.
127 | elif self.data_type == 'lmdb':
128 | img_LQ = util.read_img(self.LQ_env, key + '_{}'.format(v), LQ_size_tuple)
129 | else:
130 | img_LQ = util.read_img(None,
131 | osp.join(self.LQ_root, name_a, name_b, 'im{}.png'.format(v)))
132 | img_LQ_l.append(img_LQ)
133 |
134 | if self.opt['phase'] == 'train':
135 | C, H, W = LQ_size_tuple # LQ size
136 | # randomly crop
137 | if self.LR_input:
138 | LQ_size = GT_size // scale
139 | rnd_h = random.randint(0, max(0, H - LQ_size))
140 | rnd_w = random.randint(0, max(0, W - LQ_size))
141 | img_LQ_l = [v[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] for v in img_LQ_l]
142 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale)
143 | img_GT = img_GT[rnd_h_HR:rnd_h_HR + GT_size, rnd_w_HR:rnd_w_HR + GT_size, :]
144 | else:
145 | rnd_h = random.randint(0, max(0, H - GT_size))
146 | rnd_w = random.randint(0, max(0, W - GT_size))
147 | img_LQ_l = [v[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :] for v in img_LQ_l]
148 | img_GT = img_GT[rnd_h:rnd_h + GT_size, rnd_w:rnd_w + GT_size, :]
149 |
150 | # augmentation - flip, rotate
151 | img_LQ_l.append(img_GT)
152 | rlt = util.augment(img_LQ_l, self.opt['use_flip'], self.opt['use_rot'])
153 | img_LQ_l = rlt[0:-1]
154 | img_GT = rlt[-1]
155 |
156 | # stack LQ images to NHWC, N is the frame number
157 | img_LQs = np.stack(img_LQ_l, axis=0)
158 | # BGR to RGB, HWC to CHW, numpy to tensor
159 | img_GT = img_GT[:, :, [2, 1, 0]]
160 | img_LQs = img_LQs[:, :, :, [2, 1, 0]]
161 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float()
162 | img_LQs = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQs,
163 | (0, 3, 1, 2)))).float()
164 | return {'LQs': img_LQs, 'GT': img_GT, 'key': key}
165 |
166 | def __len__(self):
167 | return len(self.paths_GT)
168 |
--------------------------------------------------------------------------------
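The `LQ_frames_list` arithmetic from `__init__`, spelled out; it reproduces the table in the docstring (im1..im7 are the seven frames of a septuplet, with im4 as the GT):

```python
for n_frames in (1, 3, 5, 7):
    print(n_frames, [i + (9 - n_frames) // 2 for i in range(n_frames)])
# 1 [4]
# 3 [3, 4, 5]
# 5 [2, 3, 4, 5, 6]
# 7 [1, 2, 3, 4, 5, 6, 7]
```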
/codes/data/__init__.py:
--------------------------------------------------------------------------------
1 | """create dataset and dataloader"""
2 | import logging
3 | import torch
4 | import torch.utils.data
5 |
6 |
7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
8 | phase = dataset_opt['phase']
9 | if phase == 'train':
10 | if opt['dist']:
11 | world_size = torch.distributed.get_world_size()
12 | num_workers = dataset_opt['n_workers']
13 | assert dataset_opt['batch_size'] % world_size == 0
14 | batch_size = dataset_opt['batch_size'] // world_size
15 | shuffle = False
16 | else:
17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
18 | batch_size = dataset_opt['batch_size']
19 | shuffle = True
20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
21 | num_workers=num_workers, sampler=sampler, drop_last=True,
22 | pin_memory=False)
23 | else:
24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
25 | pin_memory=False)
26 |
27 |
28 | def create_dataset(dataset_opt):
29 | mode = dataset_opt['mode']
30 | # datasets for image restoration and image enhancement
31 | if mode == 'LQ':
32 | from data.LQ_dataset import LQDataset as D
33 | elif mode == 'LQGT':
34 | from data.LQGT_dataset import LQGTDataset as D
35 | elif mode == 'LQGT_cond':
36 |         from data.LQGT_cond_dataset import LQGT_cond_Dataset as D  # note: this module is not included in this repository
37 | elif mode == 'LQGT_enhance':
38 | from data.LQGT_enhance_dataset import LQGT_enhance_dataset as D
39 | # datasets for video restoration
40 | elif mode == 'REDS':
41 | from data.REDS_dataset import REDSDataset as D
42 | elif mode == 'Vimeo90K':
43 | from data.Vimeo90K_dataset import Vimeo90KDataset as D
44 | elif mode == 'video_test':
45 | from data.video_test_dataset import VideoTestDataset as D
46 | else:
47 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
48 | dataset = D(dataset_opt)
49 |
50 | logger = logging.getLogger('base')
51 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
52 | dataset_opt['name']))
53 | return dataset
54 |
--------------------------------------------------------------------------------
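A sketch of wiring the two factories together for the FiveK enhancement data; the keys mirror what `create_dataset`/`create_dataloader` read above, and the paths are placeholders. Since `LQGT_enhance_dataset` does no cropping and FiveK images vary in size, `batch_size` is kept at 1 here:

```python
from data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'FiveK_train',
    'phase': 'train',
    'mode': 'LQGT_enhance',          # selects LQGT_enhance_dataset
    'data_type': 'img',
    'dataroot_GT': '/path/to/expert_C_train',
    'dataroot_LQ': '/path/to/raw_input_train',
    'color': None,
    'n_workers': 4,
    'batch_size': 1,                 # variable image sizes -> no batching
}
opt = {'dist': False, 'gpu_ids': [0]}  # single-GPU, non-distributed

train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(train_set, dataset_opt, opt)
batch = next(iter(train_loader))
print(batch['LQ'].shape, batch['GT'].shape)
```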
/codes/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from torch.utils.data.distributed.DistributedSampler
3 | Support enlarging the dataset for *iteration-oriented* training, to save time when restarting the
4 | dataloader after each epoch
5 | """
6 | import math
7 | import torch
8 | from torch.utils.data.sampler import Sampler
9 | import torch.distributed as dist
10 |
11 |
12 | class DistIterSampler(Sampler):
13 | """Sampler that restricts data loading to a subset of the dataset.
14 |
15 | It is especially useful in conjunction with
16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
17 | process can pass a DistributedSampler instance as a DataLoader sampler,
18 | and load a subset of the original dataset that is exclusive to it.
19 |
20 | .. note::
21 | Dataset is assumed to be of constant size.
22 |
23 | Arguments:
24 | dataset: Dataset used for sampling.
25 | num_replicas (optional): Number of processes participating in
26 | distributed training.
27 | rank (optional): Rank of the current process within num_replicas.
28 | """
29 |
30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
31 | if num_replicas is None:
32 | if not dist.is_available():
33 | raise RuntimeError("Requires distributed package to be available")
34 | num_replicas = dist.get_world_size()
35 | if rank is None:
36 | if not dist.is_available():
37 | raise RuntimeError("Requires distributed package to be available")
38 | rank = dist.get_rank()
39 | self.dataset = dataset
40 | self.num_replicas = num_replicas
41 | self.rank = rank
42 | self.epoch = 0
43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
44 | self.total_size = self.num_samples * self.num_replicas
45 |
46 | def __iter__(self):
47 | # deterministically shuffle based on epoch
48 | g = torch.Generator()
49 | g.manual_seed(self.epoch)
50 | indices = torch.randperm(self.total_size, generator=g).tolist()
51 |
52 | dsize = len(self.dataset)
53 | indices = [v % dsize for v in indices]
54 |
55 | # subsample
56 | indices = indices[self.rank:self.total_size:self.num_replicas]
57 | assert len(indices) == self.num_samples
58 |
59 | return iter(indices)
60 |
61 | def __len__(self):
62 | return self.num_samples
63 |
64 | def set_epoch(self, epoch):
65 | self.epoch = epoch
66 |
--------------------------------------------------------------------------------
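The enlargement trick in `DistIterSampler`, by the numbers (illustrative values): each process draws `ceil(len(dataset) * ratio / num_replicas)` indices per "epoch" and folds them back into range with a modulo, so the dataloader restarts roughly `ratio` times less often during iteration-based training:

```python
import math

dataset_len, num_replicas, ratio = 1000, 8, 100
num_samples = math.ceil(dataset_len * ratio / num_replicas)
total_size = num_samples * num_replicas
print(num_samples, total_size)  # 12500 100000

# indices are drawn from [0, total_size) and mapped back via idx % dataset_len
print(99999 % dataset_len)  # 999
```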
/codes/data/video_test_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import torch
3 | import torch.utils.data as data
4 | import data.util as util
5 |
6 |
7 | class VideoTestDataset(data.Dataset):
8 | """
9 | A video test dataset. Supports:
10 | Vid4
11 | REDS4
12 | Vimeo90K-Test
13 |
14 | no need to prepare LMDB files
15 | """
16 |
17 | def __init__(self, opt):
18 | super(VideoTestDataset, self).__init__()
19 | self.opt = opt
20 | self.cache_data = opt['cache_data']
21 | self.half_N_frames = opt['N_frames'] // 2
22 | self.GT_root, self.LQ_root = opt['dataroot_GT'], opt['dataroot_LQ']
23 | self.data_type = self.opt['data_type']
24 | self.data_info = {'path_LQ': [], 'path_GT': [], 'folder': [], 'idx': [], 'border': []}
25 | if self.data_type == 'lmdb':
26 | raise ValueError('No need to use LMDB during validation/test.')
27 | #### Generate data info and cache data
28 | self.imgs_LQ, self.imgs_GT = {}, {}
29 | if opt['name'].lower() in ['vid4', 'reds4']:
30 | subfolders_LQ = util.glob_file_list(self.LQ_root)
31 | subfolders_GT = util.glob_file_list(self.GT_root)
32 | for subfolder_LQ, subfolder_GT in zip(subfolders_LQ, subfolders_GT):
33 | subfolder_name = osp.basename(subfolder_GT)
34 | img_paths_LQ = util.glob_file_list(subfolder_LQ)
35 | img_paths_GT = util.glob_file_list(subfolder_GT)
36 | max_idx = len(img_paths_LQ)
37 | assert max_idx == len(
38 | img_paths_GT), 'Different number of images in LQ and GT folders'
39 | self.data_info['path_LQ'].extend(img_paths_LQ)
40 | self.data_info['path_GT'].extend(img_paths_GT)
41 | self.data_info['folder'].extend([subfolder_name] * max_idx)
42 | for i in range(max_idx):
43 | self.data_info['idx'].append('{}/{}'.format(i, max_idx))
44 | border_l = [0] * max_idx
45 | for i in range(self.half_N_frames):
46 | border_l[i] = 1
47 | border_l[max_idx - i - 1] = 1
48 | self.data_info['border'].extend(border_l)
49 |
50 | if self.cache_data:
51 | self.imgs_LQ[subfolder_name] = util.read_img_seq(img_paths_LQ)
52 | self.imgs_GT[subfolder_name] = util.read_img_seq(img_paths_GT)
53 | elif opt['name'].lower() in ['vimeo90k-test']:
54 | pass # TODO
55 | else:
56 | raise ValueError(
57 |                 'Unsupported video test dataset. Supported: Vid4, REDS4, Vimeo90K-Test.')
58 |
59 | def __getitem__(self, index):
60 | # path_LQ = self.data_info['path_LQ'][index]
61 | # path_GT = self.data_info['path_GT'][index]
62 | folder = self.data_info['folder'][index]
63 | idx, max_idx = self.data_info['idx'][index].split('/')
64 | idx, max_idx = int(idx), int(max_idx)
65 | border = self.data_info['border'][index]
66 |
67 | if self.cache_data:
68 | select_idx = util.index_generation(idx, max_idx, self.opt['N_frames'],
69 | padding=self.opt['padding'])
70 | imgs_LQ = self.imgs_LQ[folder].index_select(0, torch.LongTensor(select_idx))
71 | img_GT = self.imgs_GT[folder][idx]
72 | else:
73 |             raise NotImplementedError('cache_data=False is not implemented; imgs_LQ/img_GT would be undefined below.')  # TODO
74 |
75 | return {
76 | 'LQs': imgs_LQ,
77 | 'GT': img_GT,
78 | 'folder': folder,
79 | 'idx': self.data_info['idx'][index],
80 | 'border': border
81 | }
82 |
83 | def __len__(self):
84 | return len(self.data_info['path_GT'])
85 |
--------------------------------------------------------------------------------
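A simplified stand-in for `util.index_generation` used in `__getitem__` above (hypothetical: the real helper in `codes/data/util.py` also supports several padding modes). It shows how a window of `N_frames` indices is built around frame `idx`, clamping at clip borders:

```python
def window_indices(idx, max_idx, n_frames):
    half = n_frames // 2
    # clamp-to-border ("replicate") padding at the ends of the clip
    return [min(max(i, 0), max_idx - 1) for i in range(idx - half, idx + half + 1)]

print(window_indices(0, 30, 5))   # [0, 0, 0, 1, 2]
print(window_indices(29, 30, 5))  # [27, 28, 29, 29, 29]
```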
/codes/data_scripts/extract_subimages.py:
--------------------------------------------------------------------------------
1 | """A multi-thread tool to crop large images to sub-images for faster IO."""
2 | import os
3 | import os.path as osp
4 | import sys
5 | from multiprocessing import Pool
6 | import numpy as np
7 | import cv2
8 | from PIL import Image
9 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
10 | from utils.util import ProgressBar # noqa: E402
11 | import data.util as data_util # noqa: E402
12 |
13 |
14 | def main():
15 | mode = 'pair' # single (one input folder) | pair (extract corresponding GT and LR pairs)
16 | opt = {}
17 | opt['n_thread'] = 20
18 | opt['compression_level'] = 3 # 3 is the default value in cv2
19 |     # CV_IMWRITE_PNG_COMPRESSION ranges from 0 to 9. A higher value means a smaller size but longer
20 |     # compression time. If raw images are read during training, use 0 for faster IO.
21 | if mode == 'single':
22 | opt['input_folder'] = '../../datasets/DIV2K_train_HR'
23 | opt['save_folder'] = '../../datasets/DIV2K_sub'
24 | opt['crop_sz'] = 480 # the size of each sub-image
25 | opt['step'] = 240 # step of the sliding crop window
26 | opt['thres_sz'] = 48 # size threshold
27 |         extract_single(opt)
28 | elif mode == 'pair':
29 | GT_folder = '../../datasets/DIV2K_train_HR'
30 | LR_folder = '../../datasets/DIV2K_train_LR_bicubic/X4'
31 | save_GT_folder = '../../datasets/DIV2K_sub'
32 | save_LR_folder = '../../datasets/DIV2K800_sub_bicLRx4'
33 | scale_ratio = 4
34 | crop_sz = 480 # the size of each sub-image (GT)
35 | step = 240 # step of the sliding crop window (GT)
36 | thres_sz = 48 # size threshold
37 | ########################################################################
38 | # check that all the GT and LR images have correct scale ratio
39 | img_GT_list = data_util._get_paths_from_images(GT_folder)
40 | img_LR_list = data_util._get_paths_from_images(LR_folder)
41 | assert len(img_GT_list) == len(img_LR_list), 'different length of GT_folder and LR_folder.'
42 | for path_GT, path_LR in zip(img_GT_list, img_LR_list):
43 | img_GT = Image.open(path_GT)
44 | img_LR = Image.open(path_LR)
45 | w_GT, h_GT = img_GT.size
46 | w_LR, h_LR = img_LR.size
47 |             assert w_GT / w_LR == scale_ratio, 'GT width [{:d}] is not {:d}X as LR width [{:d}] for {:s}.'.format(  # noqa: E501
48 |                 w_GT, scale_ratio, w_LR, path_GT)
49 |             assert h_GT / h_LR == scale_ratio, 'GT height [{:d}] is not {:d}X as LR height [{:d}] for {:s}.'.format(  # noqa: E501
50 |                 h_GT, scale_ratio, h_LR, path_GT)
51 | # check crop size, step and threshold size
52 | assert crop_sz % scale_ratio == 0, 'crop size is not {:d}X multiplication.'.format(
53 | scale_ratio)
54 | assert step % scale_ratio == 0, 'step is not {:d}X multiplication.'.format(scale_ratio)
55 | assert thres_sz % scale_ratio == 0, 'thres_sz is not {:d}X multiplication.'.format(
56 | scale_ratio)
57 | print('process GT...')
58 | opt['input_folder'] = GT_folder
59 | opt['save_folder'] = save_GT_folder
60 | opt['crop_sz'] = crop_sz
61 | opt['step'] = step
62 | opt['thres_sz'] = thres_sz
63 |         extract_single(opt)
64 | print('process LR...')
65 | opt['input_folder'] = LR_folder
66 | opt['save_folder'] = save_LR_folder
67 | opt['crop_sz'] = crop_sz // scale_ratio
68 | opt['step'] = step // scale_ratio
69 | opt['thres_sz'] = thres_sz // scale_ratio
70 |         extract_single(opt)
71 | assert len(data_util._get_paths_from_images(save_GT_folder)) == len(
72 | data_util._get_paths_from_images(
73 | save_LR_folder)), 'different length of save_GT_folder and save_LR_folder.'
74 | else:
75 | raise ValueError('Wrong mode.')
76 |
77 |
78 | def extract_single(opt):
79 | input_folder = opt['input_folder']
80 | save_folder = opt['save_folder']
81 | if not osp.exists(save_folder):
82 | os.makedirs(save_folder)
83 | print('mkdir [{:s}] ...'.format(save_folder))
84 | else:
85 | print('Folder [{:s}] already exists. Exit...'.format(save_folder))
86 | sys.exit(1)
87 | img_list = data_util._get_paths_from_images(input_folder)
88 |
89 | def update(arg):
90 | pbar.update(arg)
91 |
92 | pbar = ProgressBar(len(img_list))
93 |
94 | pool = Pool(opt['n_thread'])
95 | for path in img_list:
96 | pool.apply_async(worker, args=(path, opt), callback=update)
97 | pool.close()
98 | pool.join()
99 | print('All subprocesses done.')
100 |
101 |
102 | def worker(path, opt):
103 | crop_sz = opt['crop_sz']
104 | step = opt['step']
105 | thres_sz = opt['thres_sz']
106 | img_name = osp.basename(path)
107 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
108 |
109 | n_channels = len(img.shape)
110 | if n_channels == 2:
111 | h, w = img.shape
112 | elif n_channels == 3:
113 | h, w, c = img.shape
114 | else:
115 | raise ValueError('Wrong image shape - {}'.format(n_channels))
116 |
117 | h_space = np.arange(0, h - crop_sz + 1, step)
118 | if h - (h_space[-1] + crop_sz) > thres_sz:
119 | h_space = np.append(h_space, h - crop_sz)
120 | w_space = np.arange(0, w - crop_sz + 1, step)
121 | if w - (w_space[-1] + crop_sz) > thres_sz:
122 | w_space = np.append(w_space, w - crop_sz)
123 |
124 | index = 0
125 | for x in h_space:
126 | for y in w_space:
127 | index += 1
128 | if n_channels == 2:
129 | crop_img = img[x:x + crop_sz, y:y + crop_sz]
130 | else:
131 | crop_img = img[x:x + crop_sz, y:y + crop_sz, :]
132 | crop_img = np.ascontiguousarray(crop_img)
133 | cv2.imwrite(
134 | osp.join(opt['save_folder'],
135 | img_name.replace('.png', '_s{:03d}.png'.format(index))), crop_img,
136 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
137 | return 'Processing {:s} ...'.format(img_name)
138 |
139 |
140 | if __name__ == '__main__':
141 | main()
142 |
--------------------------------------------------------------------------------
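The sliding-window grid built inside `worker()` above, shown standalone for one axis (a 1080-pixel height with the default 480/240/48 settings):

```python
import numpy as np

h, crop_sz, step, thres_sz = 1080, 480, 240, 48
h_space = np.arange(0, h - crop_sz + 1, step)  # [0, 240, 480]
if h - (h_space[-1] + crop_sz) > thres_sz:     # 120 leftover pixels > 48
    h_space = np.append(h_space, h - crop_sz)  # add a final window at 600
print(h_space)  # [  0 240 480 600]
```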
/codes/data_scripts/generate_LR_Vimeo90K.m:
--------------------------------------------------------------------------------
1 | function generate_LR_Vimeo90K()
2 | %% matlab code to generate bicubic-downsampled LR images for the Vimeo90K dataset
3 |
4 | up_scale = 4;
5 | mod_scale = 4;
6 | idx = 0;
7 | filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
8 | for i = 1 : length(filepaths)
9 | [~,imname,ext] = fileparts(filepaths(i).name);
10 | folder_path = filepaths(i).folder;
11 | save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
12 | if ~exist(save_LR_folder, 'dir')
13 | mkdir(save_LR_folder);
14 | end
15 | if isempty(imname)
16 | disp('Ignore . folder.');
17 | elseif strcmp(imname, '.')
18 | disp('Ignore .. folder.');
19 | else
20 | idx = idx + 1;
21 | str_rlt = sprintf('%d\t%s.\n', idx, imname);
22 | fprintf(str_rlt);
23 | % read image
24 | img = imread(fullfile(folder_path, [imname, ext]));
25 | img = im2double(img);
26 | % modcrop
27 | img = modcrop(img, mod_scale);
28 | % LR
29 | im_LR = imresize(img, 1/up_scale, 'bicubic');
30 | if exist('save_LR_folder', 'var')
31 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
32 | end
33 | end
34 | end
35 | end
36 |
37 | %% modcrop
38 | function img = modcrop(img, modulo)
39 | if size(img,3) == 1
40 | sz = size(img);
41 | sz = sz - mod(sz, modulo);
42 | img = img(1:sz(1), 1:sz(2));
43 | else
44 | tmpsz = size(img);
45 | sz = tmpsz(1:2);
46 | sz = sz - mod(sz, modulo);
47 | img = img(1:sz(1), 1:sz(2),:);
48 | end
49 | end
50 |
--------------------------------------------------------------------------------
/codes/data_scripts/generate_mod_LR_bic.m:
--------------------------------------------------------------------------------
1 | function generate_mod_LR_bic()
2 | %% matlab code to generate mod images, bicubic-downsampled LR, and bicubic-upsampled images.
3 |
4 | %% set parameters
5 | % comment the unnecessary line
6 | input_folder = '../../datasets/DIV2K/DIV2K800';
7 | % save_mod_folder = '';
8 | save_LR_folder = '../../datasets/DIV2K/DIV2K800_bicLRx4';
9 | % save_bic_folder = '';
10 |
11 | up_scale = 4;
12 | mod_scale = 4;
13 |
14 | if exist('save_mod_folder', 'var')
15 | if exist(save_mod_folder, 'dir')
16 | disp(['It will cover ', save_mod_folder]);
17 | else
18 | mkdir(save_mod_folder);
19 | end
20 | end
21 | if exist('save_LR_folder', 'var')
22 | if exist(save_LR_folder, 'dir')
23 | disp(['It will cover ', save_LR_folder]);
24 | else
25 | mkdir(save_LR_folder);
26 | end
27 | end
28 | if exist('save_bic_folder', 'var')
29 | if exist(save_bic_folder, 'dir')
30 | disp(['It will cover ', save_bic_folder]);
31 | else
32 | mkdir(save_bic_folder);
33 | end
34 | end
35 |
36 | idx = 0;
37 | filepaths = dir(fullfile(input_folder,'*.*'));
38 | for i = 1 : length(filepaths)
39 |     [~, imname, ext] = fileparts(filepaths(i).name);
40 | if isempty(imname)
41 | disp('Ignore . folder.');
42 | elseif strcmp(imname, '.')
43 | disp('Ignore .. folder.');
44 | else
45 | idx = idx + 1;
46 | str_rlt = sprintf('%d\t%s.\n', idx, imname);
47 | fprintf(str_rlt);
48 | % read image
49 | img = imread(fullfile(input_folder, [imname, ext]));
50 | img = im2double(img);
51 | % modcrop
52 | img = modcrop(img, mod_scale);
53 | if exist('save_mod_folder', 'var')
54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png']));
55 | end
56 | % LR
57 | im_LR = imresize(img, 1/up_scale, 'bicubic');
58 | if exist('save_LR_folder', 'var')
59 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
60 | end
61 | % Bicubic
62 | if exist('save_bic_folder', 'var')
63 | im_B = imresize(im_LR, up_scale, 'bicubic');
64 | imwrite(im_B, fullfile(save_bic_folder, [imname, '.png']));
65 | end
66 | end
67 | end
68 | end
69 |
70 | %% modcrop
71 | function img = modcrop(img, modulo)
72 | if size(img,3) == 1
73 | sz = size(img);
74 | sz = sz - mod(sz, modulo);
75 | img = img(1:sz(1), 1:sz(2));
76 | else
77 | tmpsz = size(img);
78 | sz = tmpsz(1:2);
79 | sz = sz - mod(sz, modulo);
80 | img = img(1:sz(1), 1:sz(2),:);
81 | end
82 | end
83 |
--------------------------------------------------------------------------------
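Both MATLAB scripts above share the same `modcrop`; for reference, a Python sketch of the equivalent logic (similar in spirit to the `modcrop` helper used by `codes/data`):

```python
import numpy as np

def modcrop(img, modulo):
    # crop height and width down to the nearest multiple of `modulo`
    h, w = img.shape[:2]
    return img[:h - h % modulo, :w - w % modulo, ...]

img = np.zeros((1021, 1918, 3))
print(modcrop(img, 4).shape)  # (1020, 1916, 3)
```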
/codes/data_scripts/generate_mod_LR_bic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import cv2
4 | import numpy as np
5 |
6 | try:
7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8 | from data.util import imresize_np
9 | except ImportError:
10 |     raise ImportError('Cannot import imresize_np; make sure codes/data/util.py exists.')
11 |
12 |
13 | def generate_mod_LR_bic():
14 | # set parameters
15 | up_scale = 4
16 | mod_scale = 4
17 | # set data dir
18 | sourcedir = '/data/datasets/img'
19 | savedir = '/data/datasets/mod'
20 |
21 | saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
22 | saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
23 | saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
24 |
25 | if not os.path.isdir(sourcedir):
26 | print('Error: No source data found')
27 |         sys.exit(1)  # non-zero exit status on error
28 | if not os.path.isdir(savedir):
29 | os.mkdir(savedir)
30 |
31 | if not os.path.isdir(os.path.join(savedir, 'HR')):
32 | os.mkdir(os.path.join(savedir, 'HR'))
33 | if not os.path.isdir(os.path.join(savedir, 'LR')):
34 | os.mkdir(os.path.join(savedir, 'LR'))
35 | if not os.path.isdir(os.path.join(savedir, 'Bic')):
36 | os.mkdir(os.path.join(savedir, 'Bic'))
37 |
38 | if not os.path.isdir(saveHRpath):
39 | os.mkdir(saveHRpath)
40 | else:
41 | print('It will cover ' + str(saveHRpath))
42 |
43 | if not os.path.isdir(saveLRpath):
44 | os.mkdir(saveLRpath)
45 | else:
46 | print('It will cover ' + str(saveLRpath))
47 |
48 | if not os.path.isdir(saveBicpath):
49 | os.mkdir(saveBicpath)
50 | else:
51 | print('It will cover ' + str(saveBicpath))
52 |
53 | filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')]
54 | num_files = len(filepaths)
55 |
56 |     # process each image: modcrop, then generate the bicubic LR and Bic (re-upsampled) versions
57 | for i in range(num_files):
58 | filename = filepaths[i]
59 | print('No.{} -- Processing {}'.format(i, filename))
60 | # read image
61 | image = cv2.imread(os.path.join(sourcedir, filename))
62 |
63 | width = int(np.floor(image.shape[1] / mod_scale))
64 | height = int(np.floor(image.shape[0] / mod_scale))
65 | # modcrop
66 | if len(image.shape) == 3:
67 | image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
68 | else:
69 | image_HR = image[0:mod_scale * height, 0:mod_scale * width]
70 | # LR
71 | image_LR = imresize_np(image_HR, 1 / up_scale, True)
72 | # bic
73 | image_Bic = imresize_np(image_LR, up_scale, True)
74 |
75 | cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
76 | cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
77 | cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
78 |
79 |
80 | if __name__ == "__main__":
81 | generate_mod_LR_bic()
82 |
--------------------------------------------------------------------------------
/codes/data_scripts/prepare_DIV2K_x4_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "Prepare DIV2K X4 datasets..."
4 | cd ../../datasets
5 | mkdir DIV2K
6 | cd DIV2K
7 |
8 | #### Step 1
9 | echo "Step 1: Download the datasets: [DIV2K_train_HR] and [DIV2K_train_LR_bicubic_X4]..."
10 | # GT
11 | FOLDER=DIV2K_train_HR
12 | FILE=DIV2K_train_HR.zip
13 | if [ ! -d "$FOLDER" ]; then
14 | if [ ! -f "$FILE" ]; then
15 | echo "Downloading $FILE..."
16 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
17 | fi
18 | unzip $FILE
19 | fi
20 | # LR
21 | FOLDER=DIV2K_train_LR_bicubic
22 | FILE=DIV2K_train_LR_bicubic_X4.zip
23 | if [ ! -d "$FOLDER" ]; then
24 | if [ ! -f "$FILE" ]; then
25 | echo "Downloading $FILE..."
26 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/$FILE
27 | fi
28 | unzip $FILE
29 | fi
30 |
31 | #### Step 2
32 | echo "Step 2: Rename the LR images..."
33 | cd ../../codes/data_scripts
34 | python rename.py
35 |
36 | #### Step 3
37 | echo "Step 3: Crop to sub-images..."
38 | python extract_subimages.py
39 |
40 | #### Step 4
41 | echo "Step 4: Create LMDB files..."
42 | python create_lmdb.py
43 |
--------------------------------------------------------------------------------
/codes/data_scripts/regroup_REDS.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 | train_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic/X4'
5 | val_path = '/home/xtwang/datasets/REDS/val_sharp_bicubic/X4'
6 |
7 | # copy the val clips into the train set, renumbered to follow the 240 train clips
8 | val_folders = glob.glob(os.path.join(val_path, '*'))
9 | for folder in val_folders:
10 | new_folder_idx = '{:03d}'.format(int(folder.split('/')[-1]) + 240)
11 | os.system('cp -r {} {}'.format(folder, os.path.join(train_path, new_folder_idx)))
12 |
--------------------------------------------------------------------------------
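REDS ships 240 training clips (000-239) and 30 validation clips (000-029); the script above merges the val clips into the train set by shifting their indices by 240, so val clip 005 becomes train clip 245:

```python
# The renumbering rule used above.
print('{:03d}'.format(int('005') + 240))  # -> 245
```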
/codes/data_scripts/rename.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 |
4 |
5 | def main():
6 | folder = '../../datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
7 | DIV2K(folder)
8 | print('Finished.')
9 |
10 |
11 | def DIV2K(path):
12 | img_path_l = glob.glob(os.path.join(path, '*'))
13 | for img_path in img_path_l:
14 | new_path = img_path.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
15 | os.rename(img_path, new_path)
16 |
17 |
18 | if __name__ == "__main__":
19 | main()
--------------------------------------------------------------------------------
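DIV2K LR files carry a scale suffix (e.g. `0801x4.png`) while the GT files do not; stripping the suffix makes LR and GT names match, which the paired LQGT dataset expects. The rule in miniature:

```python
name = '0801x4.png'
print(name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', ''))  # 0801.png
```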
/codes/data_scripts/test_dataloader.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os.path as osp
3 | import math
4 | import torchvision.utils
5 |
6 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
7 | from data import create_dataloader, create_dataset # noqa: E402
8 | from utils import util # noqa: E402
9 |
10 |
11 | def main():
12 | dataset = 'DIV2K800_sub' # REDS | Vimeo90K | DIV2K800_sub
13 | opt = {}
14 | opt['dist'] = False
15 | opt['gpu_ids'] = [0]
16 | if dataset == 'REDS':
17 | opt['name'] = 'test_REDS'
18 | opt['dataroot_GT'] = '../../datasets/REDS/train_sharp_wval.lmdb'
19 | opt['dataroot_LQ'] = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
20 | opt['mode'] = 'REDS'
21 | opt['N_frames'] = 5
22 | opt['phase'] = 'train'
23 | opt['use_shuffle'] = True
24 | opt['n_workers'] = 8
25 | opt['batch_size'] = 16
26 | opt['GT_size'] = 256
27 | opt['LQ_size'] = 64
28 | opt['scale'] = 4
29 | opt['use_flip'] = True
30 | opt['use_rot'] = True
31 | opt['interval_list'] = [1]
32 | opt['random_reverse'] = False
33 | opt['border_mode'] = False
34 | opt['cache_keys'] = None
35 | opt['data_type'] = 'lmdb' # img | lmdb | mc
36 | elif dataset == 'Vimeo90K':
37 | opt['name'] = 'test_Vimeo90K'
38 | opt['dataroot_GT'] = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
39 | opt['dataroot_LQ'] = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
40 | opt['mode'] = 'Vimeo90K'
41 | opt['N_frames'] = 7
42 | opt['phase'] = 'train'
43 | opt['use_shuffle'] = True
44 | opt['n_workers'] = 8
45 | opt['batch_size'] = 16
46 | opt['GT_size'] = 256
47 | opt['LQ_size'] = 64
48 | opt['scale'] = 4
49 | opt['use_flip'] = True
50 | opt['use_rot'] = True
51 | opt['interval_list'] = [1]
52 | opt['random_reverse'] = False
53 | opt['border_mode'] = False
54 | opt['cache_keys'] = None
55 | opt['data_type'] = 'lmdb' # img | lmdb | mc
56 | elif dataset == 'DIV2K800_sub':
57 | opt['name'] = 'DIV2K800'
58 | opt['dataroot_GT'] = '../../datasets/DIV2K/DIV2K800_sub.lmdb'
59 | opt['dataroot_LQ'] = '../../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb'
60 | opt['mode'] = 'LQGT'
61 | opt['phase'] = 'train'
62 | opt['use_shuffle'] = True
63 | opt['n_workers'] = 8
64 | opt['batch_size'] = 16
65 | opt['GT_size'] = 128
66 | opt['scale'] = 4
67 | opt['use_flip'] = True
68 | opt['use_rot'] = True
69 | opt['color'] = 'RGB'
70 | opt['data_type'] = 'lmdb' # img | lmdb
71 | else:
72 | raise ValueError('Dataset [{:s}] is not supported; please add it yourself.'.format(dataset))
73 |
74 | util.mkdir('tmp')
75 | train_set = create_dataset(opt)
76 | train_loader = create_dataloader(train_set, opt, opt, None)
77 | nrow = int(math.sqrt(opt['batch_size']))
78 | padding = 2 if opt['phase'] == 'train' else 0
79 |
80 | print('start...')
81 | for i, data in enumerate(train_loader):
82 | if i > 5:
83 | break
84 | print(i)
85 | if dataset == 'REDS' or dataset == 'Vimeo90K':
86 | LQs = data['LQs']
87 | else:
88 | LQ = data['LQ']
89 | GT = data['GT']
90 |
91 | if dataset == 'REDS' or dataset == 'Vimeo90K':
92 | for j in range(LQs.size(1)):
93 | torchvision.utils.save_image(LQs[:, j, :, :, :],
94 | 'tmp/LQ_{:03d}_{}.png'.format(i, j), nrow=nrow,
95 | padding=padding, normalize=False)
96 | else:
97 | torchvision.utils.save_image(LQ, 'tmp/LQ_{:03d}.png'.format(i), nrow=nrow,
98 | padding=padding, normalize=False)
99 | torchvision.utils.save_image(GT, 'tmp/GT_{:03d}.png'.format(i), nrow=nrow, padding=padding,
100 | normalize=False)
101 |
102 |
103 | if __name__ == "__main__":
104 | main()
105 |
--------------------------------------------------------------------------------
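The loop above dumps each batch as an image grid for visual inspection, tiling `batch_size` images with `nrow = sqrt(batch_size)` per row. A self-contained sketch of that dump using random tensors:

```python
import math
import torch
import torchvision.utils

batch = torch.rand(16, 3, 64, 64)  # stand-in for data['GT']
torchvision.utils.save_image(batch, 'grid.png',
                             nrow=int(math.sqrt(batch.size(0))), padding=2)
```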
/codes/metrics/calculate_PSNR_SSIM.m:
--------------------------------------------------------------------------------
1 | function calculate_PSNR_SSIM()
2 |
3 | % GT and SR folder
4 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5';
5 | folder_SR = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5';
6 | scale = 4;
7 | suffix = ''; % suffix for SR images
8 | test_Y = 1; % 1 for test Y channel only; 0 for test RGB channels
9 | if test_Y
10 | fprintf('Testing Y channel.\n');
11 | else
12 | fprintf('Testing RGB channels.\n');
13 | end
14 | filepaths = dir(fullfile(folder_GT, '*.png'));
15 | PSNR_all = zeros(1, length(filepaths));
16 | SSIM_all = zeros(1, length(filepaths));
17 |
18 | for idx_im = 1:length(filepaths)
19 | im_name = filepaths(idx_im).name;
20 | im_GT = imread(fullfile(folder_GT, im_name));
21 | im_SR = imread(fullfile(folder_SR, [im_name(1:end-4), suffix, '.png']));
22 |
23 | if test_Y % evaluate on Y channel in YCbCr color space
24 | if size(im_GT, 3) == 3
25 | im_GT_YCbCr = rgb2ycbcr(im2double(im_GT));
26 | im_GT_in = im_GT_YCbCr(:,:,1);
27 | im_SR_YCbCr = rgb2ycbcr(im2double(im_SR));
28 | im_SR_in = im_SR_YCbCr(:,:,1);
29 | else
30 | im_GT_in = im2double(im_GT);
31 | im_SR_in = im2double(im_SR);
32 | end
33 | else % evaluate on RGB channels
34 | im_GT_in = im2double(im_GT);
35 | im_SR_in = im2double(im_SR);
36 | end
37 |
38 | % calculate PSNR and SSIM
39 | PSNR_all(idx_im) = calculate_PSNR(im_GT_in * 255, im_SR_in * 255, scale);
40 | SSIM_all(idx_im) = calculate_SSIM(im_GT_in * 255, im_SR_in * 255, scale);
41 | fprintf('%d.(X%d)%20s: \tPSNR = %f \tSSIM = %f\n', idx_im, scale, im_name(1:end-4), PSNR_all(idx_im), SSIM_all(idx_im));
42 | end
43 |
44 | fprintf('\n%26s: \tPSNR = %f \tSSIM = %f\n', '####Average', mean(PSNR_all), mean(SSIM_all));
45 | end
46 |
47 | function res = calculate_PSNR(GT, SR, border)
48 | % remove border
49 | GT = GT(border+1:end-border, border+1:end-border, :);
50 | SR = SR(border+1:end-border, border+1:end-border, :);
51 | % calculate PSNR (assume values in [0,255])
52 | error = GT(:) - SR(:);
53 | mse = mean(error.^2);
54 | res = 10 * log10(255^2/mse);
55 | end
56 |
57 | function res = calculate_SSIM(GT, SR, border)
58 | GT = GT(border+1:end-border, border+1:end-border, :);
59 | SR = SR(border+1:end-border, border+1:end-border, :);
60 | % calculate SSIM
61 | mssim = zeros(1, size(SR, 3));
62 | for i = 1:size(SR,3)
63 | [mssim(i), ~] = ssim_index(GT(:,:,i), SR(:,:,i));
64 | end
65 | res = mean(mssim);
66 | end
67 |
68 | function [mssim, ssim_map] = ssim_index(img1, img2, K, window, L)
69 |
70 | %========================================================================
71 | %SSIM Index, Version 1.0
72 | %Copyright(c) 2003 Zhou Wang
73 | %All Rights Reserved.
74 | %
75 | %The author is with Howard Hughes Medical Institute, and Laboratory
76 | %for Computational Vision at Center for Neural Science and Courant
77 | %Institute of Mathematical Sciences, New York University.
78 | %
79 | %----------------------------------------------------------------------
80 | %Permission to use, copy, or modify this software and its documentation
81 | %for educational and research purposes only and without fee is hereby
82 | %granted, provided that this copyright notice and the original authors'
83 | %names appear on all copies and supporting documentation. This program
84 | %shall not be used, rewritten, or adapted as the basis of a commercial
85 | %software or hardware product without first obtaining permission of the
86 | %authors. The authors make no representations about the suitability of
87 | %this software for any purpose. It is provided "as is" without express
88 | %or implied warranty.
89 | %----------------------------------------------------------------------
90 | %
91 | %This is an implementation of the algorithm for calculating the
92 | %Structural SIMilarity (SSIM) index between two images. Please refer
93 | %to the following paper:
94 | %
95 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
96 | %quality assessment: From error visibility to structural similarity,"
97 | %IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, Apr. 2004.
98 | %
99 | %Kindly report any suggestions or corrections to zhouwang@ieee.org
100 | %
101 | %----------------------------------------------------------------------
102 | %
103 | %Input : (1) img1: the first image being compared
104 | % (2) img2: the second image being compared
105 | % (3) K: constants in the SSIM index formula (see the above
106 | % reference). default value: K = [0.01 0.03]
107 | % (4) window: local window for statistics (see the above
108 | % reference). default window is Gaussian given by
109 | % window = fspecial('gaussian', 11, 1.5);
110 | % (5) L: dynamic range of the images. default: L = 255
111 | %
112 | %Output: (1) mssim: the mean SSIM index value between 2 images.
113 | % If one of the images being compared is regarded as
114 | % perfect quality, then mssim can be considered as the
115 | % quality measure of the other image.
116 | % If img1 = img2, then mssim = 1.
117 | % (2) ssim_map: the SSIM index map of the test image. The map
118 | % has a smaller size than the input images. The actual size:
119 | % size(img1) - size(window) + 1.
120 | %
121 | %Default Usage:
122 | % Given 2 test images img1 and img2, whose dynamic range is 0-255
123 | %
124 | % [mssim ssim_map] = ssim_index(img1, img2);
125 | %
126 | %Advanced Usage:
127 | % User defined parameters. For example
128 | %
129 | % K = [0.05 0.05];
130 | % window = ones(8);
131 | % L = 100;
132 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
133 | %
134 | %See the results:
135 | %
136 | % mssim %Gives the mssim value
137 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map
138 | %
139 | %========================================================================
140 |
141 |
142 | if (nargin < 2 || nargin > 5)
143 | mssim = -Inf;
144 | ssim_map = -Inf;
145 | return;
146 | end
147 |
148 | if (~isequal(size(img1), size(img2)))
149 | mssim = -Inf;
150 | ssim_map = -Inf;
151 | return;
152 | end
153 |
154 | [M, N] = size(img1);
155 |
156 | if (nargin == 2)
157 | if ((M < 11) || (N < 11))
158 | mssim = -Inf;
159 | ssim_map = -Inf;
160 | return
161 | end
162 | window = fspecial('gaussian', 11, 1.5); %
163 | K(1) = 0.01; % default settings
164 | K(2) = 0.03; %
165 | L = 255; %
166 | end
167 |
168 | if (nargin == 3)
169 | if ((M < 11) || (N < 11))
170 | mssim = -Inf;
171 | ssim_map = -Inf;
172 | return
173 | end
174 | window = fspecial('gaussian', 11, 1.5);
175 | L = 255;
176 | if (length(K) == 2)
177 | if (K(1) < 0 || K(2) < 0)
178 | mssim = -Inf;
179 | ssim_map = -Inf;
180 | return;
181 | end
182 | else
183 | mssim = -Inf;
184 | ssim_map = -Inf;
185 | return;
186 | end
187 | end
188 |
189 | if (nargin == 4)
190 | [H, W] = size(window);
191 | if ((H*W) < 4 || (H > M) || (W > N))
192 | mssim = -Inf;
193 | ssim_map = -Inf;
194 | return
195 | end
196 | L = 255;
197 | if (length(K) == 2)
198 | if (K(1) < 0 || K(2) < 0)
199 | mssim = -Inf;
200 | ssim_map = -Inf;
201 | return;
202 | end
203 | else
204 | mssim = -Inf;
205 | ssim_map = -Inf;
206 | return;
207 | end
208 | end
209 |
210 | if (nargin == 5)
211 | [H, W] = size(window);
212 | if ((H*W) < 4 || (H > M) || (W > N))
213 | mssim = -Inf;
214 | ssim_map = -Inf;
215 | return
216 | end
217 | if (length(K) == 2)
218 | if (K(1) < 0 || K(2) < 0)
219 | mssim = -Inf;
220 | ssim_map = -Inf;
221 | return;
222 | end
223 | else
224 | mssim = -Inf;
225 | ssim_map = -Inf;
226 | return;
227 | end
228 | end
229 |
230 | C1 = (K(1)*L)^2;
231 | C2 = (K(2)*L)^2;
232 | window = window/sum(sum(window));
233 | img1 = double(img1);
234 | img2 = double(img2);
235 |
236 | mu1 = filter2(window, img1, 'valid');
237 | mu2 = filter2(window, img2, 'valid');
238 | mu1_sq = mu1.*mu1;
239 | mu2_sq = mu2.*mu2;
240 | mu1_mu2 = mu1.*mu2;
241 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
242 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
243 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
244 |
245 | if (C1 > 0 && C2 > 0)
246 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
247 | else
248 | numerator1 = 2*mu1_mu2 + C1;
249 | numerator2 = 2*sigma12 + C2;
250 | denominator1 = mu1_sq + mu2_sq + C1;
251 | denominator2 = sigma1_sq + sigma2_sq + C2;
252 | ssim_map = ones(size(mu1));
253 | index = (denominator1.*denominator2 > 0);
254 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
255 | index = (denominator1 ~= 0) & (denominator2 == 0);
256 | ssim_map(index) = numerator1(index)./denominator1(index);
257 | end
258 |
259 | mssim = mean2(ssim_map);
260 |
261 | end
262 |
--------------------------------------------------------------------------------
/codes/metrics/calculate_PSNR_SSIM.py:
--------------------------------------------------------------------------------
1 | '''
2 | Calculate PSNR and SSIM.
3 | The results match MATLAB's implementation.
4 | '''
5 | import os
6 | import math
7 | import numpy as np
8 | import cv2
9 | import glob
10 |
11 |
12 | def main():
13 | # Configurations
14 |
15 | # GT - Ground-truth;
16 | # Gen: Generated / Restored / Recovered images
17 | folder_GT = '/mnt/SSD/xtwang/BasicSR_datasets/val_set5/Set5'
18 | folder_Gen = '/home/xtwang/Projects/BasicSR/results/RRDB_PSNR_x4/set5'
19 |
20 | crop_border = 4
21 | suffix = '' # suffix for Gen images
22 | test_Y = False # True: test Y channel only; False: test RGB channels
23 |
24 | PSNR_all = []
25 | SSIM_all = []
26 | img_list = sorted(glob.glob(folder_GT + '/*'))
27 |
28 | if test_Y:
29 | print('Testing Y channel.')
30 | else:
31 | print('Testing RGB channels.')
32 |
33 | for i, img_path in enumerate(img_list):
34 | base_name = os.path.splitext(os.path.basename(img_path))[0]
35 | im_GT = cv2.imread(img_path) / 255.
36 | im_Gen = cv2.imread(os.path.join(folder_Gen, base_name + suffix + '.png')) / 255.
37 |
38 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
39 | im_GT_in = bgr2ycbcr(im_GT)
40 | im_Gen_in = bgr2ycbcr(im_Gen)
41 | else:
42 | im_GT_in = im_GT
43 | im_Gen_in = im_Gen
44 |
45 | # crop borders
46 | if im_GT_in.ndim == 3:
47 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :]
48 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :]
49 | elif im_GT_in.ndim == 2:
50 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border]
51 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border]
52 | else:
53 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT_in.ndim))
54 |
55 | # calculate PSNR and SSIM
56 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255)
57 |
58 | SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255)
59 | print('{:3d} - {:25}. \tPSNR: {:.6f} dB, \tSSIM: {:.6f}'.format(
60 | i + 1, base_name, PSNR, SSIM))
61 | PSNR_all.append(PSNR)
62 | SSIM_all.append(SSIM)
63 | print('Average: PSNR: {:.6f} dB, SSIM: {:.6f}'.format(
64 | sum(PSNR_all) / len(PSNR_all),
65 | sum(SSIM_all) / len(SSIM_all)))
66 |
67 |
68 | def calculate_psnr(img1, img2):
69 | # img1 and img2 have range [0, 255]
70 | img1 = img1.astype(np.float64)
71 | img2 = img2.astype(np.float64)
72 | mse = np.mean((img1 - img2)**2)
73 | if mse == 0:
74 | return float('inf')
75 | return 20 * math.log10(255.0 / math.sqrt(mse))
76 |
77 |
78 | def ssim(img1, img2):
79 | C1 = (0.01 * 255)**2
80 | C2 = (0.03 * 255)**2
81 |
82 | img1 = img1.astype(np.float64)
83 | img2 = img2.astype(np.float64)
84 | kernel = cv2.getGaussianKernel(11, 1.5)
85 | window = np.outer(kernel, kernel.transpose())
86 |
87 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
88 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
89 | mu1_sq = mu1**2
90 | mu2_sq = mu2**2
91 | mu1_mu2 = mu1 * mu2
92 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
93 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
94 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
95 |
96 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
97 | (sigma1_sq + sigma2_sq + C2))
98 | return ssim_map.mean()
99 |
100 |
101 | def calculate_ssim(img1, img2):
102 | '''calculate SSIM
103 | the same outputs as MATLAB's
104 | img1, img2: [0, 255]
105 | '''
106 | if not img1.shape == img2.shape:
107 | raise ValueError('Input images must have the same dimensions.')
108 | if img1.ndim == 2:
109 | return ssim(img1, img2)
110 | elif img1.ndim == 3:
111 | if img1.shape[2] == 3:
112 | ssims = []
113 | for i in range(3):
114 | ssims.append(ssim(img1[..., i], img2[..., i]))
115 | return np.array(ssims).mean()
116 | elif img1.shape[2] == 1:
117 | return ssim(np.squeeze(img1), np.squeeze(img2))
118 | else:
119 | raise ValueError('Wrong input image dimensions.')
120 |
121 |
122 | def bgr2ycbcr(img, only_y=True):
123 | '''same as matlab rgb2ycbcr
124 | only_y: only return Y channel
125 | Input:
126 | uint8, [0, 255]
127 | float, [0, 1]
128 | '''
129 | in_img_type = img.dtype
130 | img = img.astype(np.float32)
131 | if in_img_type != np.uint8:
132 | img *= 255.
133 | # convert
134 | if only_y:
135 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
136 | else:
137 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
138 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
139 | if in_img_type == np.uint8:
140 | rlt = rlt.round()
141 | else:
142 | rlt /= 255.
143 | return rlt.astype(in_img_type)
144 |
145 |
146 | if __name__ == '__main__':
147 | main()
148 |
--------------------------------------------------------------------------------
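A quick numerical check of `calculate_psnr` above: identical images return inf, and a uniform error of one gray level gives 20*log10(255/1) ≈ 48.13 dB:

```python
import math
import numpy as np

a = np.zeros((8, 8))
b = np.ones((8, 8))
mse = np.mean((a - b) ** 2)                     # 1.0
print(20 * math.log10(255.0 / math.sqrt(mse)))  # 48.1308...
```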
/codes/models/SRGAN_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parallel import DataParallel, DistributedDataParallel
6 | import models.networks as networks
7 | import models.lr_scheduler as lr_scheduler
8 | from .base_model import BaseModel
9 | from models.loss import GANLoss
10 |
11 | logger = logging.getLogger('base')
12 |
13 |
14 | class SRGANModel(BaseModel):
15 | def __init__(self, opt):
16 | super(SRGANModel, self).__init__(opt)
17 | if opt['dist']:
18 | self.rank = torch.distributed.get_rank()
19 | else:
20 | self.rank = -1 # non dist training
21 | train_opt = opt['train']
22 |
23 | # define networks and load pretrained models
24 | self.netG = networks.define_G(opt).to(self.device)
25 | if opt['dist']:
26 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
27 | else:
28 | self.netG = DataParallel(self.netG)
29 | if self.is_train:
30 | self.netD = networks.define_D(opt).to(self.device)
31 | if opt['dist']:
32 | self.netD = DistributedDataParallel(self.netD,
33 | device_ids=[torch.cuda.current_device()])
34 | else:
35 | self.netD = DataParallel(self.netD)
36 |
37 | self.netG.train()
38 | self.netD.train()
39 |
40 | # define losses, optimizer and scheduler
41 | if self.is_train:
42 | # G pixel loss
43 | if train_opt['pixel_weight'] > 0:
44 | l_pix_type = train_opt['pixel_criterion']
45 | if l_pix_type == 'l1':
46 | self.cri_pix = nn.L1Loss().to(self.device)
47 | elif l_pix_type == 'l2':
48 | self.cri_pix = nn.MSELoss().to(self.device)
49 | else:
50 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type))
51 | self.l_pix_w = train_opt['pixel_weight']
52 | else:
53 | logger.info('Remove pixel loss.')
54 | self.cri_pix = None
55 |
56 | # G feature loss
57 | if train_opt['feature_weight'] > 0:
58 | l_fea_type = train_opt['feature_criterion']
59 | if l_fea_type == 'l1':
60 | self.cri_fea = nn.L1Loss().to(self.device)
61 | elif l_fea_type == 'l2':
62 | self.cri_fea = nn.MSELoss().to(self.device)
63 | else:
64 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type))
65 | self.l_fea_w = train_opt['feature_weight']
66 | else:
67 | logger.info('Remove feature loss.')
68 | self.cri_fea = None
69 | if self.cri_fea: # load VGG perceptual loss
70 | self.netF = networks.define_F(opt, use_bn=False).to(self.device)
71 | if opt['dist']:
72 | self.netF = DistributedDataParallel(self.netF,
73 | device_ids=[torch.cuda.current_device()])
74 | else:
75 | self.netF = DataParallel(self.netF)
76 |
77 | # GD gan loss
78 | self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device)
79 | self.l_gan_w = train_opt['gan_weight']
80 | # D_update_ratio and D_init_iters
81 | self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1
82 | self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0
83 |
84 | # optimizers
85 | # G
86 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
87 | optim_params = []
88 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model
89 | if v.requires_grad:
90 | optim_params.append(v)
91 | else:
92 | if self.rank <= 0:
93 | logger.warning('Params [{:s}] will not optimize.'.format(k))
94 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
95 | weight_decay=wd_G,
96 | betas=(train_opt['beta1_G'], train_opt['beta2_G']))
97 | self.optimizers.append(self.optimizer_G)
98 | # D
99 | wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0
100 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'],
101 | weight_decay=wd_D,
102 | betas=(train_opt['beta1_D'], train_opt['beta2_D']))
103 | self.optimizers.append(self.optimizer_D)
104 |
105 | # schedulers
106 | if train_opt['lr_scheme'] == 'MultiStepLR':
107 | for optimizer in self.optimizers:
108 | self.schedulers.append(
109 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
110 | restarts=train_opt['restarts'],
111 | weights=train_opt['restart_weights'],
112 | gamma=train_opt['lr_gamma'],
113 | clear_state=train_opt['clear_state']))
114 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
115 | for optimizer in self.optimizers:
116 | self.schedulers.append(
117 | lr_scheduler.CosineAnnealingLR_Restart(
118 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
119 | restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
120 | else:
121 | raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart are supported.')
122 |
123 | self.log_dict = OrderedDict()
124 |
125 | self.print_network() # print network
126 | self.load() # load G and D if needed
127 |
128 | def feed_data(self, data, need_GT=True):
129 | self.var_L = data['LQ'].to(self.device) # LQ
130 | if need_GT:
131 | self.var_H = data['GT'].to(self.device) # GT
132 | input_ref = data['ref'] if 'ref' in data else data['GT']
133 | self.var_ref = input_ref.to(self.device)
134 |
135 | def optimize_parameters(self, step):
136 | # G
137 | for p in self.netD.parameters():
138 | p.requires_grad = False
139 |
140 | self.optimizer_G.zero_grad()
141 | self.fake_H = self.netG(self.var_L)
142 |
143 | l_g_total = 0
144 | if step % self.D_update_ratio == 0 and step > self.D_init_iters:
145 | if self.cri_pix: # pixel loss
146 | l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H)
147 | l_g_total += l_g_pix
148 | if self.cri_fea: # feature loss
149 | real_fea = self.netF(self.var_H).detach()
150 | fake_fea = self.netF(self.fake_H)
151 | l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea)
152 | l_g_total += l_g_fea
153 |
154 | pred_g_fake = self.netD(self.fake_H)
155 | if self.opt['train']['gan_type'] == 'gan':
156 | l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True)
157 | elif self.opt['train']['gan_type'] == 'ragan':
158 | pred_d_real = self.netD(self.var_ref).detach()
159 | l_g_gan = self.l_gan_w * (
160 | self.cri_gan(pred_d_real - torch.mean(pred_g_fake), False) +
161 | self.cri_gan(pred_g_fake - torch.mean(pred_d_real), True)) / 2
162 | l_g_total += l_g_gan
163 |
164 | l_g_total.backward()
165 | self.optimizer_G.step()
166 |
167 | # D
168 | for p in self.netD.parameters():
169 | p.requires_grad = True
170 |
171 | self.optimizer_D.zero_grad()
172 | l_d_total = 0
173 | pred_d_real = self.netD(self.var_ref)
174 | pred_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G
175 | if self.opt['train']['gan_type'] == 'gan':
176 | l_d_real = self.cri_gan(pred_d_real, True)
177 | l_d_fake = self.cri_gan(pred_d_fake, False)
178 | l_d_total = l_d_real + l_d_fake
179 | elif self.opt['train']['gan_type'] == 'ragan':
180 | l_d_real = self.cri_gan(pred_d_real - torch.mean(pred_d_fake), True)
181 | l_d_fake = self.cri_gan(pred_d_fake - torch.mean(pred_d_real), False)
182 | l_d_total = (l_d_real + l_d_fake) / 2
183 |
184 | l_d_total.backward()
185 | self.optimizer_D.step()
186 |
187 | # set log
188 | if step % self.D_update_ratio == 0 and step > self.D_init_iters:
189 | if self.cri_pix:
190 | self.log_dict['l_g_pix'] = l_g_pix.item()
191 | if self.cri_fea:
192 | self.log_dict['l_g_fea'] = l_g_fea.item()
193 | self.log_dict['l_g_gan'] = l_g_gan.item()
194 |
195 | self.log_dict['l_d_real'] = l_d_real.item()
196 | self.log_dict['l_d_fake'] = l_d_fake.item()
197 | self.log_dict['D_real'] = torch.mean(pred_d_real.detach())
198 | self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach())
199 |
200 | def test(self):
201 | self.netG.eval()
202 | with torch.no_grad():
203 | self.fake_H = self.netG(self.var_L)
204 | self.netG.train()
205 |
206 | def get_current_log(self):
207 | return self.log_dict
208 |
209 | def get_current_visuals(self, need_GT=True):
210 | out_dict = OrderedDict()
211 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
212 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
213 | if need_GT:
214 | out_dict['GT'] = self.var_H.detach()[0].float().cpu()
215 | return out_dict
216 |
217 | def print_network(self):
218 | # Generator
219 | s, n = self.get_network_description(self.netG)
220 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
221 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
222 | self.netG.module.__class__.__name__)
223 | else:
224 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
225 | if self.rank <= 0:
226 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
227 | logger.info(s)
228 | if self.is_train:
229 | # Discriminator
230 | s, n = self.get_network_description(self.netD)
231 | if isinstance(self.netD, nn.DataParallel) or isinstance(self.netD,
232 | DistributedDataParallel):
233 | net_struc_str = '{} - {}'.format(self.netD.__class__.__name__,
234 | self.netD.module.__class__.__name__)
235 | else:
236 | net_struc_str = '{}'.format(self.netD.__class__.__name__)
237 | if self.rank <= 0:
238 | logger.info('Network D structure: {}, with parameters: {:,d}'.format(
239 | net_struc_str, n))
240 | logger.info(s)
241 |
242 | if self.cri_fea: # F, Perceptual Network
243 | s, n = self.get_network_description(self.netF)
244 | if isinstance(self.netF, nn.DataParallel) or isinstance(
245 | self.netF, DistributedDataParallel):
246 | net_struc_str = '{} - {}'.format(self.netF.__class__.__name__,
247 | self.netF.module.__class__.__name__)
248 | else:
249 | net_struc_str = '{}'.format(self.netF.__class__.__name__)
250 | if self.rank <= 0:
251 | logger.info('Network F structure: {}, with parameters: {:,d}'.format(
252 | net_struc_str, n))
253 | logger.info(s)
254 |
255 | def load(self):
256 | load_path_G = self.opt['path']['pretrain_model_G']
257 | if load_path_G is not None:
258 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
259 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
260 | load_path_D = self.opt['path']['pretrain_model_D']
261 | if self.opt['is_train'] and load_path_D is not None:
262 | logger.info('Loading model for D [{:s}] ...'.format(load_path_D))
263 | self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])
264 |
265 | def save(self, iter_step):
266 | self.save_network(self.netG, 'G', iter_step)
267 | self.save_network(self.netD, 'D', iter_step)
268 |
--------------------------------------------------------------------------------
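The 'ragan' branches above use the relativistic average GAN loss (as in ESRGAN): each logit is judged relative to the mean logit of the opposite set, for both G and D. A condensed sketch of the discriminator side, assuming `cri_gan` is the BCE-with-logits `GANLoss` constructed above:

```python
import torch

def d_ragan_loss(cri_gan, pred_real, pred_fake):
    # Real samples should score above the average fake, and vice versa.
    l_real = cri_gan(pred_real - torch.mean(pred_fake), True)
    l_fake = cri_gan(pred_fake - torch.mean(pred_real), False)
    return (l_real + l_fake) / 2
```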
/codes/models/SR_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.parallel import DataParallel, DistributedDataParallel
7 | import models.networks as networks
8 | import models.lr_scheduler as lr_scheduler
9 | from .base_model import BaseModel
10 | from models.loss import CharbonnierLoss
11 |
12 | logger = logging.getLogger('base')
13 |
14 |
15 | class SRModel(BaseModel):
16 | def __init__(self, opt):
17 | super(SRModel, self).__init__(opt)
18 |
19 | if opt['dist']:
20 | self.rank = torch.distributed.get_rank()
21 | else:
22 | self.rank = -1 # non dist training
23 | train_opt = opt['train']
24 |
25 | # define network and load pretrained models
26 | self.netG = networks.define_G(opt).to(self.device)
27 | if opt['dist']:
28 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
29 | else:
30 | self.netG = DataParallel(self.netG)
31 | # print network
32 | self.print_network()
33 | self.load()
34 |
35 | if self.is_train:
36 | self.netG.train()
37 |
38 | # loss
39 | loss_type = train_opt['pixel_criterion']
40 | if loss_type == 'l1':
41 | self.cri_pix = nn.L1Loss().to(self.device)
42 | elif loss_type == 'l2':
43 | self.cri_pix = nn.MSELoss().to(self.device)
44 | elif loss_type == 'cb':
45 | self.cri_pix = CharbonnierLoss().to(self.device)
46 | else:
47 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
48 | self.l_pix_w = train_opt['pixel_weight']
49 |
50 | # optimizers
51 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
52 | optim_params = []
53 | if train_opt['finetune_adafm']:
54 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model
55 | v.requires_grad = False
56 | if k.find('adafm') >= 0:
57 | v.requires_grad = True
58 | optim_params.append(v)
59 | logger.info('Params [{:s}] will optimize.'.format(k))
60 | else:
61 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model
62 | if v.requires_grad:
63 | optim_params.append(v)
64 | else:
65 | if self.rank <= 0:
66 | logger.warning('Params [{:s}] will not optimize.'.format(k))
67 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
68 | weight_decay=wd_G,
69 | betas=(train_opt['beta1'], train_opt['beta2']))
70 | self.optimizers.append(self.optimizer_G)
71 |
72 | # schedulers
73 | if train_opt['lr_scheme'] == 'MultiStepLR':
74 | for optimizer in self.optimizers:
75 | self.schedulers.append(
76 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
77 | restarts=train_opt['restarts'],
78 | weights=train_opt['restart_weights'],
79 | gamma=train_opt['lr_gamma'],
80 | clear_state=train_opt['clear_state']))
81 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
82 | for optimizer in self.optimizers:
83 | self.schedulers.append(
84 | lr_scheduler.CosineAnnealingLR_Restart(
85 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
86 | restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
87 | else:
88 | raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart are supported.')
89 |
90 | self.log_dict = OrderedDict()
91 |
92 | def feed_data(self, data, need_GT=True, need_cond=False):
93 | self.var_L = data['LQ'].to(self.device) # LQ
94 | if need_GT:
95 | self.real_H = data['GT'].to(self.device) # GT
96 | if need_cond:
97 | self.cond = data['cond'].to(self.device) # cond
98 | self.input = [self.var_L, self.cond]
99 | else:
100 | self.input = self.var_L
101 |
102 | def optimize_parameters(self, step):
103 | self.optimizer_G.zero_grad()
104 | self.fake_H = self.netG(self.input)
105 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
106 | l_pix.backward()
107 | self.optimizer_G.step()
108 |
109 | # set log
110 | self.log_dict['l_pix'] = l_pix.item()
111 |
112 | def test(self):
113 | self.netG.eval()
114 | with torch.no_grad():
115 | self.fake_H = self.netG(self.input)
116 | self.netG.train()
117 |
118 | def test_x8(self):
119 | # from https://github.com/thstkdgus35/EDSR-PyTorch
120 | self.netG.eval()
121 |
122 | def _transform(v, op):
123 | # if self.precision != 'single': v = v.float()
124 | v2np = v.data.cpu().numpy()
125 | if op == 'v':
126 | tfnp = v2np[:, :, :, ::-1].copy()
127 | elif op == 'h':
128 | tfnp = v2np[:, :, ::-1, :].copy()
129 | elif op == 't':
130 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
131 |
132 | ret = torch.Tensor(tfnp).to(self.device)
133 | # if self.precision == 'half': ret = ret.half()
134 |
135 | return ret
136 |
137 | lr_list = [self.var_L]
138 | for tf in 'v', 'h', 't':
139 | lr_list.extend([_transform(t, tf) for t in lr_list])
140 | with torch.no_grad():
141 | sr_list = [self.netG(aug) for aug in lr_list]
142 | for i in range(len(sr_list)):
143 | if i > 3:
144 | sr_list[i] = _transform(sr_list[i], 't')
145 | if i % 4 > 1:
146 | sr_list[i] = _transform(sr_list[i], 'h')
147 | if (i % 4) % 2 == 1:
148 | sr_list[i] = _transform(sr_list[i], 'v')
149 |
150 | output_cat = torch.cat(sr_list, dim=0)
151 | self.fake_H = output_cat.mean(dim=0, keepdim=True)
152 | self.netG.train()
153 |
154 | def get_current_log(self):
155 | return self.log_dict
156 |
157 | def get_current_visuals(self, need_GT=True):
158 | out_dict = OrderedDict()
159 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
160 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
161 | if need_GT:
162 | out_dict['GT'] = self.real_H.detach()[0].float().cpu()
163 | return out_dict
164 |
165 | def print_network(self):
166 | s, n = self.get_network_description(self.netG)
167 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
168 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
169 | self.netG.module.__class__.__name__)
170 | else:
171 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
172 | if self.rank <= 0:
173 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
174 | logger.info(s)
175 |
176 | def load(self):
177 | load_path_G = self.opt['path']['pretrain_model_G']
178 | if load_path_G is not None:
179 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
180 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
181 |
182 | def update(self, new_model_dict):
183 | if isinstance(self.netG, nn.DataParallel):
184 | network = self.netG.module
185 | else:
186 | network = self.netG
187 | network.load_state_dict(new_model_dict)
186 |
187 | def save(self, iter_label):
188 | self.save_network(self.netG, 'G', iter_label)
189 |
--------------------------------------------------------------------------------
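`test_x8` above is the standard self-ensemble trick: run the network on the 8 dihedral transforms of the input (flips, transpose, and their compositions), undo each transform on the output, and average. The same `_transform` can be reused for the inverse because each primitive is an involution:

```python
import numpy as np

x = np.arange(12).reshape(1, 1, 3, 4)
v = x[:, :, :, ::-1]                          # the 'v' flip
assert (v[:, :, :, ::-1] == x).all()          # applying it twice restores x
t = x.transpose(0, 1, 3, 2)                   # the 't' transpose
assert (t.transpose(0, 1, 3, 2) == x).all()
```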
/codes/models/Video_base_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.parallel import DataParallel, DistributedDataParallel
7 | import models.networks as networks
8 | import models.lr_scheduler as lr_scheduler
9 | from .base_model import BaseModel
10 | from models.loss import CharbonnierLoss
11 |
12 | logger = logging.getLogger('base')
13 |
14 |
15 | class VideoBaseModel(BaseModel):
16 | def __init__(self, opt):
17 | super(VideoBaseModel, self).__init__(opt)
18 |
19 | if opt['dist']:
20 | self.rank = torch.distributed.get_rank()
21 | else:
22 | self.rank = -1 # non dist training
23 | train_opt = opt['train']
24 |
25 | # define network and load pretrained models
26 | self.netG = networks.define_G(opt).to(self.device)
27 | if opt['dist']:
28 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
29 | else:
30 | self.netG = DataParallel(self.netG)
31 | # print network
32 | self.print_network()
33 | self.load()
34 |
35 | if self.is_train:
36 | self.netG.train()
37 |
38 | #### loss
39 | loss_type = train_opt['pixel_criterion']
40 | if loss_type == 'l1':
41 | self.cri_pix = nn.L1Loss(reduction='sum').to(self.device)
42 | elif loss_type == 'l2':
43 | self.cri_pix = nn.MSELoss(reduction='sum').to(self.device)
44 | elif loss_type == 'cb':
45 | self.cri_pix = CharbonnierLoss().to(self.device)
46 | else:
47 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
48 | self.l_pix_w = train_opt['pixel_weight']
49 |
50 | #### optimizers
51 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
52 | if train_opt['ft_tsa_only']:
53 | normal_params = []
54 | tsa_fusion_params = []
55 | for k, v in self.netG.named_parameters():
56 | if v.requires_grad:
57 | if 'tsa_fusion' in k:
58 | tsa_fusion_params.append(v)
59 | else:
60 | normal_params.append(v)
61 | else:
62 | if self.rank <= 0:
63 | logger.warning('Params [{:s}] will not optimize.'.format(k))
64 | optim_params = [
65 | { # add normal params first
66 | 'params': normal_params,
67 | 'lr': train_opt['lr_G']
68 | },
69 | {
70 | 'params': tsa_fusion_params,
71 | 'lr': train_opt['lr_G']
72 | },
73 | ]
74 | else:
75 | optim_params = []
76 | for k, v in self.netG.named_parameters():
77 | if v.requires_grad:
78 | optim_params.append(v)
79 | else:
80 | if self.rank <= 0:
81 | logger.warning('Params [{:s}] will not optimize.'.format(k))
82 |
83 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
84 | weight_decay=wd_G,
85 | betas=(train_opt['beta1'], train_opt['beta2']))
86 | self.optimizers.append(self.optimizer_G)
87 |
88 | #### schedulers
89 | if train_opt['lr_scheme'] == 'MultiStepLR':
90 | for optimizer in self.optimizers:
91 | self.schedulers.append(
92 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
93 | restarts=train_opt['restarts'],
94 | weights=train_opt['restart_weights'],
95 | gamma=train_opt['lr_gamma'],
96 | clear_state=train_opt['clear_state']))
97 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
98 | for optimizer in self.optimizers:
99 | self.schedulers.append(
100 | lr_scheduler.CosineAnnealingLR_Restart(
101 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
102 | restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
103 | else:
104 | raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart are supported.')
105 |
106 | self.log_dict = OrderedDict()
107 |
108 | def feed_data(self, data, need_GT=True):
109 | self.var_L = data['LQs'].to(self.device)
110 | if need_GT:
111 | self.real_H = data['GT'].to(self.device)
112 |
113 | def set_params_lr_zero(self):
114 | # fix normal module
115 | self.optimizers[0].param_groups[0]['lr'] = 0
116 |
117 | def optimize_parameters(self, step):
118 | if self.opt['train']['ft_tsa_only'] and step < self.opt['train']['ft_tsa_only']:
119 | self.set_params_lr_zero()
120 |
121 | self.optimizer_G.zero_grad()
122 | self.fake_H = self.netG(self.var_L)
123 |
124 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H)
125 | l_pix.backward()
126 | self.optimizer_G.step()
127 |
128 | # set log
129 | self.log_dict['l_pix'] = l_pix.item()
130 |
131 | def test(self):
132 | self.netG.eval()
133 | with torch.no_grad():
134 | self.fake_H = self.netG(self.var_L)
135 | self.netG.train()
136 |
137 | def get_current_log(self):
138 | return self.log_dict
139 |
140 | def get_current_visuals(self, need_GT=True):
141 | out_dict = OrderedDict()
142 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
143 | out_dict['rlt'] = self.fake_H.detach()[0].float().cpu()
144 | if need_GT:
145 | out_dict['GT'] = self.real_H.detach()[0].float().cpu()
146 | return out_dict
147 |
148 | def print_network(self):
149 | s, n = self.get_network_description(self.netG)
150 | if isinstance(self.netG, nn.DataParallel):
151 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
152 | self.netG.module.__class__.__name__)
153 | else:
154 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
155 | if self.rank <= 0:
156 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
157 | logger.info(s)
158 |
159 | def load(self):
160 | load_path_G = self.opt['path']['pretrain_model_G']
161 | if load_path_G is not None:
162 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
163 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
164 |
165 | def save(self, iter_label):
166 | self.save_network(self.netG, 'G', iter_label)
167 |
--------------------------------------------------------------------------------
/codes/models/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger('base')
3 |
4 |
5 | def create_model(opt):
6 | model = opt['model']
7 | # image restoration
8 | if model == 'sr': # PSNR-oriented super resolution
9 | from .SR_model import SRModel as M
10 | elif model == 'srgan': # GAN-based super resolution, SRGAN / ESRGAN
11 | from .SRGAN_model import SRGANModel as M
12 | # video restoration
13 | elif model == 'video_base':
14 | from .Video_base_model import VideoBaseModel as M
15 | else:
16 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
17 | m = M(opt)
18 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
19 | return m
20 |
--------------------------------------------------------------------------------
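`create_model` above is a simple string-keyed factory. A minimal usage sketch, assuming the BasicSR-style `parse` loader in codes/options/options.py:

```python
from options.options import parse
from models import create_model

opt = parse('options/train/train_Enhance.yml', is_train=True)
model = create_model(opt)  # opt['model'] picks the class: 'sr' -> SRModel, 'srgan' -> SRGANModel
```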
/codes/models/archs/CSRNet_arch.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import math
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class Condition(nn.Module):
9 | def __init__(self, in_nc=3, nf=32):
10 | super(Condition, self).__init__()
11 | stride = 2
12 | pad = 0
13 | self.pad = nn.ZeroPad2d(1)
14 | self.conv1 = nn.Conv2d(in_nc, nf, 7, stride, pad, bias=True)
15 | self.conv2 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True)
16 | self.conv3 = nn.Conv2d(nf, nf, 3, stride, pad, bias=True)
17 | self.act = nn.ReLU(inplace=True)
18 |
19 | def forward(self, x):
20 | conv1_out = self.act(self.conv1(self.pad(x)))
21 | conv2_out = self.act(self.conv2(self.pad(conv1_out)))
22 | conv3_out = self.act(self.conv3(self.pad(conv2_out)))
23 | out = torch.mean(conv3_out, dim=[2, 3], keepdim=False)
24 |
25 | return out
26 |
27 |
28 | # 3-layer base network with conditional (global feature modulation) control
29 | class CSRNet(nn.Module):
30 | def __init__(self, in_nc=3, out_nc=3, base_nf=64, cond_nf=32):
31 | super(CSRNet, self).__init__()
32 |
33 | self.base_nf = base_nf
34 | self.out_nc = out_nc
35 |
36 | self.cond_net = Condition(in_nc=in_nc, nf=cond_nf)
37 |
38 | self.cond_scale1 = nn.Linear(cond_nf, base_nf, bias=True)
39 | self.cond_scale2 = nn.Linear(cond_nf, base_nf, bias=True)
40 | self.cond_scale3 = nn.Linear(cond_nf, 3, bias=True)
41 |
42 | self.cond_shift1 = nn.Linear(cond_nf, base_nf, bias=True)
43 | self.cond_shift2 = nn.Linear(cond_nf, base_nf, bias=True)
44 | self.cond_shift3 = nn.Linear(cond_nf, 3, bias=True)
45 |
46 | self.conv1 = nn.Conv2d(in_nc, base_nf, 1, 1, bias=True)
47 | self.conv2 = nn.Conv2d(base_nf, base_nf, 1, 1, bias=True)
48 | self.conv3 = nn.Conv2d(base_nf, out_nc, 1, 1, bias=True)
49 |
50 | self.act = nn.ReLU(inplace=True)
51 |
52 |
53 | def forward(self, x):
54 | cond = self.cond_net(x)
55 |
56 | scale1 = self.cond_scale1(cond)
57 | shift1 = self.cond_shift1(cond)
58 |
59 | scale2 = self.cond_scale2(cond)
60 | shift2 = self.cond_shift2(cond)
61 |
62 | scale3 = self.cond_scale3(cond)
63 | shift3 = self.cond_shift3(cond)
64 |
65 | out = self.conv1(x)
66 | out = out * scale1.view(-1, self.base_nf, 1, 1) + shift1.view(-1, self.base_nf, 1, 1) + out
67 | out = self.act(out)
68 |
69 |
70 | out = self.conv2(out)
71 | out = out * scale2.view(-1, self.base_nf, 1, 1) + shift2.view(-1, self.base_nf, 1, 1) + out
72 | out = self.act(out)
73 |
74 | out = self.conv3(out)
75 | out = out * scale3.view(-1, self.out_nc, 1, 1) + shift3.view(-1, self.out_nc, 1, 1) + out
76 | return out
--------------------------------------------------------------------------------
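Each modulated layer above computes `out * scale + shift + out`, i.e. the affine global feature modulation `(1 + scale) * out + shift`, with per-channel scale and shift predicted from the condition vector. A quick equivalence check:

```python
import torch

out = torch.randn(2, 64, 8, 8)
scale = torch.randn(2, 64, 1, 1)
shift = torch.randn(2, 64, 1, 1)
a = out * scale + shift + out    # as written in forward()
b = (1 + scale) * out + shift    # equivalent affine form
assert torch.allclose(a, b, atol=1e-5)
```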
/codes/models/archs/EDVR_arch.py:
--------------------------------------------------------------------------------
1 | ''' network architecture for EDVR '''
2 | import functools
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import models.archs.arch_util as arch_util
7 | try:
8 | from models.archs.dcn.deform_conv import ModulatedDeformConvPack as DCN
9 | except ImportError:
10 | raise ImportError('Failed to import DCNv2 module.')
11 |
12 |
13 | class Predeblur_ResNet_Pyramid(nn.Module):
14 | def __init__(self, nf=128, HR_in=False):
15 | '''
16 | HR_in: True if the inputs have a high spatial resolution
17 | '''
18 |
19 | super(Predeblur_ResNet_Pyramid, self).__init__()
20 | self.HR_in = True if HR_in else False
21 | if self.HR_in:
22 | self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
23 | self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
24 | self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
25 | else:
26 | self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
27 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
28 | self.RB_L1_1 = basic_block()
29 | self.RB_L1_2 = basic_block()
30 | self.RB_L1_3 = basic_block()
31 | self.RB_L1_4 = basic_block()
32 | self.RB_L1_5 = basic_block()
33 | self.RB_L2_1 = basic_block()
34 | self.RB_L2_2 = basic_block()
35 | self.RB_L3_1 = basic_block()
36 | self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
37 | self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
38 |
39 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
40 |
41 | def forward(self, x):
42 | if self.HR_in:
43 | L1_fea = self.lrelu(self.conv_first_1(x))
44 | L1_fea = self.lrelu(self.conv_first_2(L1_fea))
45 | L1_fea = self.lrelu(self.conv_first_3(L1_fea))
46 | else:
47 | L1_fea = self.lrelu(self.conv_first(x))
48 | L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea))
49 | L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea))
50 | L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2, mode='bilinear',
51 | align_corners=False)
52 | L2_fea = self.RB_L2_1(L2_fea) + L3_fea
53 | L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2, mode='bilinear',
54 | align_corners=False)
55 | L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea
56 | out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea)))
57 | return out
58 |
59 |
60 | class PCD_Align(nn.Module):
61 | ''' Alignment module using Pyramid, Cascading and Deformable convolution
62 | with 3 pyramid levels.
63 | '''
64 |
65 | def __init__(self, nf=64, groups=8):
66 | super(PCD_Align, self).__init__()
67 | # L3: level 3, 1/4 spatial size
68 | self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
69 | self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
70 | self.L3_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
71 | extra_offset_mask=True)
72 | # L2: level 2, 1/2 spatial size
73 | self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
74 | self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
75 | self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
76 | self.L2_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
77 | extra_offset_mask=True)
78 | self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
79 | # L1: level 1, original spatial size
80 | self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
81 | self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for offset
82 | self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
83 | self.L1_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
84 | extra_offset_mask=True)
85 | self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for fea
86 | # Cascading DCN
87 | self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True) # concat for diff
88 | self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
89 |
90 | self.cas_dcnpack = DCN(nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups,
91 | extra_offset_mask=True)
92 |
93 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
94 |
95 | def forward(self, nbr_fea_l, ref_fea_l):
96 | '''align other neighboring frames to the reference frame in the feature level
97 | nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features
98 | '''
99 | # L3
100 | L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
101 | L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
102 | L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
103 | L3_fea = self.lrelu(self.L3_dcnpack([nbr_fea_l[2], L3_offset]))
104 | # L2
105 | L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
106 | L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
107 | L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear', align_corners=False)
108 | L2_offset = self.lrelu(self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
109 | L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
110 | L2_fea = self.L2_dcnpack([nbr_fea_l[1], L2_offset])
111 | L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear', align_corners=False)
112 | L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
113 | # L1
114 | L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
115 | L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
116 | L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear', align_corners=False)
117 | L1_offset = self.lrelu(self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
118 | L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
119 | L1_fea = self.L1_dcnpack([nbr_fea_l[0], L1_offset])
120 | L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear', align_corners=False)
121 | L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
122 | # Cascading
123 | offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
124 | offset = self.lrelu(self.cas_offset_conv1(offset))
125 | offset = self.lrelu(self.cas_offset_conv2(offset))
126 | L1_fea = self.lrelu(self.cas_dcnpack([L1_fea, offset]))
127 |
128 | return L1_fea
129 |
130 |
131 | class TSA_Fusion(nn.Module):
132 | ''' Temporal Spatial Attention fusion module
133 | Temporal: correlation;
134 | Spatial: 3 pyramid levels.
135 | '''
136 |
137 | def __init__(self, nf=64, nframes=5, center=2):
138 | super(TSA_Fusion, self).__init__()
139 | self.center = center
140 | # temporal attention (before fusion conv)
141 | self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
142 | self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
143 |
144 | # fusion conv: using 1x1 to save parameters and computation
145 | self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
146 |
147 | # spatial attention (after fusion conv)
148 | self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
149 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
150 | self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
151 | self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
152 | self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
153 | self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
154 | self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
155 | self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
156 | self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
157 | self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
158 | self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
159 | self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)
160 |
161 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
162 |
163 | def forward(self, aligned_fea):
164 | B, N, C, H, W = aligned_fea.size() # N video frames
165 | #### temporal attention
166 | emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
167 | emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W) # [B, N, C(nf), H, W]
168 |
169 | cor_l = []
170 | for i in range(N):
171 | emb_nbr = emb[:, i, :, :, :]
172 | cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1) # B, 1, H, W
173 | cor_l.append(cor_tmp)
174 | cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1)) # B, N, H, W
175 | cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)
176 | aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob
177 |
178 | #### fusion
179 | fea = self.lrelu(self.fea_fusion(aligned_fea))
180 |
181 | #### spatial attention
182 | att = self.lrelu(self.sAtt_1(aligned_fea))
183 | att_max = self.maxpool(att)
184 | att_avg = self.avgpool(att)
185 | att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))
186 | # pyramid levels
187 | att_L = self.lrelu(self.sAtt_L1(att))
188 | att_max = self.maxpool(att_L)
189 | att_avg = self.avgpool(att_L)
190 | att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
191 | att_L = self.lrelu(self.sAtt_L3(att_L))
192 | att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False)
193 |
194 | att = self.lrelu(self.sAtt_3(att))
195 | att = att + att_L
196 | att = self.lrelu(self.sAtt_4(att))
197 | att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False)
198 | att = self.sAtt_5(att)
199 | att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
200 | att = torch.sigmoid(att)
201 |
202 | fea = fea * att * 2 + att_add
203 | return fea
204 |
205 |
206 | class EDVR(nn.Module):
207 | def __init__(self, nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10, center=None,
208 | predeblur=False, HR_in=False, w_TSA=True):
209 | super(EDVR, self).__init__()
210 | self.nf = nf
211 | self.center = nframes // 2 if center is None else center
212 | self.is_predeblur = True if predeblur else False
213 | self.HR_in = True if HR_in else False
214 | self.w_TSA = w_TSA
215 | ResidualBlock_noBN_f = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
216 |
217 | #### extract features (for each frame)
218 | if self.is_predeblur:
219 | self.pre_deblur = Predeblur_ResNet_Pyramid(nf=nf, HR_in=self.HR_in)
220 | self.conv_1x1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
221 | else:
222 | if self.HR_in:
223 | self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
224 | self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
225 | self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
226 | else:
227 | self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
228 | self.feature_extraction = arch_util.make_layer(ResidualBlock_noBN_f, front_RBs)
229 | self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
230 | self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
231 | self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
232 | self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
233 |
234 | self.pcd_align = PCD_Align(nf=nf, groups=groups)
235 | if self.w_TSA:
236 | self.tsa_fusion = TSA_Fusion(nf=nf, nframes=nframes, center=self.center)
237 | else:
238 | self.tsa_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
239 |
240 | #### reconstruction
241 | self.recon_trunk = arch_util.make_layer(ResidualBlock_noBN_f, back_RBs)
242 | #### upsampling
243 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
244 |         self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)  # HR branch is fixed at 64 channels
245 | self.pixel_shuffle = nn.PixelShuffle(2)
246 | self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
247 | self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
248 |
249 | #### activation function
250 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
251 |
252 | def forward(self, x):
253 | B, N, C, H, W = x.size() # N video frames
254 | x_center = x[:, self.center, :, :, :].contiguous()
255 |
256 | #### extract LR features
257 | # L1
258 | if self.is_predeblur:
259 | L1_fea = self.pre_deblur(x.view(-1, C, H, W))
260 | L1_fea = self.conv_1x1(L1_fea)
261 | if self.HR_in:
262 | H, W = H // 4, W // 4
263 | else:
264 | if self.HR_in:
265 | L1_fea = self.lrelu(self.conv_first_1(x.view(-1, C, H, W)))
266 | L1_fea = self.lrelu(self.conv_first_2(L1_fea))
267 | L1_fea = self.lrelu(self.conv_first_3(L1_fea))
268 | H, W = H // 4, W // 4
269 | else:
270 | L1_fea = self.lrelu(self.conv_first(x.view(-1, C, H, W)))
271 | L1_fea = self.feature_extraction(L1_fea)
272 | # L2
273 | L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
274 | L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea))
275 | # L3
276 | L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))
277 | L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea))
278 |
279 | L1_fea = L1_fea.view(B, N, -1, H, W)
280 | L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2)
281 | L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4)
282 |
283 | #### pcd align
284 | # ref feature list
285 | ref_fea_l = [
286 | L1_fea[:, self.center, :, :, :].clone(), L2_fea[:, self.center, :, :, :].clone(),
287 | L3_fea[:, self.center, :, :, :].clone()
288 | ]
289 | aligned_fea = []
290 | for i in range(N):
291 | nbr_fea_l = [
292 | L1_fea[:, i, :, :, :].clone(), L2_fea[:, i, :, :, :].clone(),
293 | L3_fea[:, i, :, :, :].clone()
294 | ]
295 | aligned_fea.append(self.pcd_align(nbr_fea_l, ref_fea_l))
296 | aligned_fea = torch.stack(aligned_fea, dim=1) # [B, N, C, H, W]
297 |
298 | if not self.w_TSA:
299 | aligned_fea = aligned_fea.view(B, -1, H, W)
300 | fea = self.tsa_fusion(aligned_fea)
301 |
302 | out = self.recon_trunk(fea)
303 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
304 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
305 | out = self.lrelu(self.HRconv(out))
306 | out = self.conv_last(out)
307 | if self.HR_in:
308 | base = x_center
309 | else:
310 | base = F.interpolate(x_center, scale_factor=4, mode='bilinear', align_corners=False)
311 | out += base
312 | return out
313 |
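A hedged smoke test (not part of the repo): EDVR depends on the compiled DCN extension in `codes/models/archs/dcn` and a CUDA device; run from `codes/` so the `models.*` imports resolve, and keep H and W divisible by 4 so the three pyramid levels reshape cleanly.

```python
import torch
from models.archs.EDVR_arch import EDVR

model = EDVR(nf=64, nframes=5, groups=8, front_RBs=5, back_RBs=10).cuda().eval()
x = torch.rand(1, 5, 3, 64, 64).cuda()  # [B, N, C, H, W], N = 5 video frames
with torch.no_grad():
    out = model(x)
print(out.shape)  # torch.Size([1, 3, 256, 256]): fixed 4x upscaling plus a bilinear skip
```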
--------------------------------------------------------------------------------
/codes/models/archs/RRDBNet_arch.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import models.archs.arch_util as arch_util
6 |
7 |
8 | class ResidualDenseBlock_5C(nn.Module):
9 | def __init__(self, nf=64, gc=32, bias=True):
10 | super(ResidualDenseBlock_5C, self).__init__()
11 | # gc: growth channel, i.e. intermediate channels
12 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
13 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
14 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
15 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
16 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
17 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
18 |
19 | # initialization
20 | arch_util.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5],
21 | 0.1)
22 |
23 | def forward(self, x):
24 | x1 = self.lrelu(self.conv1(x))
25 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
26 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
27 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
28 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
29 | return x5 * 0.2 + x
30 |
31 |
32 | class RRDB(nn.Module):
33 | '''Residual in Residual Dense Block'''
34 |
35 | def __init__(self, nf, gc=32):
36 | super(RRDB, self).__init__()
37 | self.RDB1 = ResidualDenseBlock_5C(nf, gc)
38 | self.RDB2 = ResidualDenseBlock_5C(nf, gc)
39 | self.RDB3 = ResidualDenseBlock_5C(nf, gc)
40 |
41 | def forward(self, x):
42 | out = self.RDB1(x)
43 | out = self.RDB2(out)
44 | out = self.RDB3(out)
45 | return out * 0.2 + x
46 |
47 |
48 | class RRDBNet(nn.Module):
49 | def __init__(self, in_nc, out_nc, nf, nb, gc=32):
50 | super(RRDBNet, self).__init__()
51 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
52 |
53 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
54 | self.RRDB_trunk = arch_util.make_layer(RRDB_block_f, nb)
55 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
56 | #### upsampling
57 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
58 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
59 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
60 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
61 |
62 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
63 |
64 | def forward(self, x):
65 | fea = self.conv_first(x)
66 | trunk = self.trunk_conv(self.RRDB_trunk(fea))
67 | fea = fea + trunk
68 |
69 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
70 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
71 | out = self.conv_last(self.lrelu(self.HRconv(fea)))
72 |
73 | return out
74 |
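A minimal usage sketch, assuming `codes/` is the working directory; `nb=23` matches test_ESRGAN.yml. The two nearest-neighbor upsamples in `forward` make the scale a fixed 4x regardless of the config's `upscale` value.

```python
import torch
from models.archs.RRDBNet_arch import RRDBNet

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23).eval()
lr = torch.rand(1, 3, 32, 32)
with torch.no_grad():
    sr = net(lr)
print(sr.shape)  # torch.Size([1, 3, 128, 128])
```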
--------------------------------------------------------------------------------
/codes/models/archs/SRResNet_arch.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import models.archs.arch_util as arch_util
5 |
6 |
7 | class MSRResNet(nn.Module):
8 | ''' modified SRResNet'''
9 |
10 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4):
11 | super(MSRResNet, self).__init__()
12 | self.upscale = upscale
13 |
14 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
15 | basic_block = functools.partial(arch_util.ResidualBlock_noBN, nf=nf)
16 | self.recon_trunk = arch_util.make_layer(basic_block, nb)
17 |
18 | # upsampling
19 | if self.upscale == 2:
20 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
21 | self.pixel_shuffle = nn.PixelShuffle(2)
22 | elif self.upscale == 3:
23 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True)
24 | self.pixel_shuffle = nn.PixelShuffle(3)
25 | elif self.upscale == 4:
26 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
27 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
28 | self.pixel_shuffle = nn.PixelShuffle(2)
29 |
30 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
31 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
32 |
33 | # activation function
34 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
35 |
36 | # initialization
37 | arch_util.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last],
38 | 0.1)
39 | if self.upscale == 4:
40 | arch_util.initialize_weights(self.upconv2, 0.1)
41 |
42 | def forward(self, x):
43 | fea = self.lrelu(self.conv_first(x))
44 | out = self.recon_trunk(fea)
45 |
46 | if self.upscale == 4:
47 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
48 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out)))
49 | elif self.upscale == 3 or self.upscale == 2:
50 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out)))
51 |
52 | out = self.conv_last(self.lrelu(self.HRconv(out)))
53 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
54 | out += base
55 | return out
56 |
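A minimal sketch, assuming `codes/` is on the path. Note the bilinear skip connection: the network only learns a residual on top of the interpolated base image.

```python
import torch
from models.archs.SRResNet_arch import MSRResNet

net = MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=4).eval()
with torch.no_grad():
    sr = net(torch.rand(1, 3, 32, 32))
print(sr.shape)  # torch.Size([1, 3, 128, 128]); upscale=2 or 3 uses one PixelShuffle stage
```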
--------------------------------------------------------------------------------
/codes/models/archs/TOF_arch.py:
--------------------------------------------------------------------------------
1 | '''PyTorch implementation of TOFlow
2 | Paper: Xue et al., Video Enhancement with Task-Oriented Flow, IJCV 2018
3 | Code reference:
4 | 1. https://github.com/anchen1011/toflow
5 | 2. https://github.com/Coldog2333/pytoflow
6 | '''
7 |
8 | import torch
9 | import torch.nn as nn
10 | from .arch_util import flow_warp
11 |
12 |
13 | def normalize(x):
14 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
15 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
16 | return (x - mean) / std
17 |
18 |
19 | def denormalize(x):
20 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).type_as(x)
21 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).type_as(x)
22 | return x * std + mean
23 |
24 |
25 | class SpyNet_Block(nn.Module):
26 | '''A submodule of SpyNet.'''
27 |
28 | def __init__(self):
29 | super(SpyNet_Block, self).__init__()
30 |
31 | self.block = nn.Sequential(
32 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
33 | nn.BatchNorm2d(32), nn.ReLU(inplace=True),
34 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
35 | nn.BatchNorm2d(64), nn.ReLU(inplace=True),
36 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
37 | nn.BatchNorm2d(32), nn.ReLU(inplace=True),
38 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
39 | nn.BatchNorm2d(16), nn.ReLU(inplace=True),
40 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
41 |
42 | def forward(self, x):
43 | '''
44 | input: x: [ref im, nbr im, initial flow] - (B, 8, H, W)
45 | output: estimated flow - (B, 2, H, W)
46 | '''
47 | return self.block(x)
48 |
49 |
50 | class SpyNet(nn.Module):
51 | '''SpyNet for estimating optical flow
52 | Ranjan et al., Optical Flow Estimation using a Spatial Pyramid Network, 2016'''
53 |
54 | def __init__(self):
55 | super(SpyNet, self).__init__()
56 |
57 | self.blocks = nn.ModuleList([SpyNet_Block() for _ in range(4)])
58 |
59 | def forward(self, ref, nbr):
60 |         '''Estimate optical flow at the coarsest level, then upsample and refine at finer levels
61 | input: ref: reference image - [B, 3, H, W]
62 | nbr: the neighboring image to be warped - [B, 3, H, W]
63 | output: estimated optical flow - [B, 2, H, W]
64 | '''
65 | B, C, H, W = ref.size()
66 | ref = [ref]
67 | nbr = [nbr]
68 |
69 | for _ in range(3):
70 | ref.insert(
71 | 0,
72 | nn.functional.avg_pool2d(input=ref[0], kernel_size=2, stride=2,
73 | count_include_pad=False))
74 | nbr.insert(
75 | 0,
76 | nn.functional.avg_pool2d(input=nbr[0], kernel_size=2, stride=2,
77 | count_include_pad=False))
78 |
79 | flow = torch.zeros(B, 2, H // 16, W // 16).type_as(ref[0])
80 |
81 | for i in range(4):
82 | flow_up = nn.functional.interpolate(input=flow, scale_factor=2, mode='bilinear',
83 | align_corners=True) * 2.0
84 | flow = flow_up + self.blocks[i](torch.cat(
85 | [ref[i], flow_warp(nbr[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
86 | return flow
87 |
88 |
89 | class TOFlow(nn.Module):
90 | def __init__(self, adapt_official=False):
91 | super(TOFlow, self).__init__()
92 |
93 | self.SpyNet = SpyNet()
94 |
95 | self.conv_3x7_64_9x9 = nn.Conv2d(3 * 7, 64, 9, 1, 4)
96 | self.conv_64_64_9x9 = nn.Conv2d(64, 64, 9, 1, 4)
97 | self.conv_64_64_1x1 = nn.Conv2d(64, 64, 1)
98 | self.conv_64_3_1x1 = nn.Conv2d(64, 3, 1)
99 |
100 | self.relu = nn.ReLU(inplace=True)
101 |
102 |         self.adapt_official = adapt_official  # True when using weights translated from the official Torch code
103 |
104 | def forward(self, x):
105 | """
106 | input: x: input frames - [B, 7, 3, H, W]
107 | output: SR reference frame - [B, 3, H, W]
108 | """
109 |
110 | B, T, C, H, W = x.size()
111 | x = normalize(x.view(-1, C, H, W)).view(B, T, C, H, W)
112 |
113 | ref_idx = 3
114 | x_ref = x[:, ref_idx, :, :, :]
115 |
116 | # In the official torch code, the 0-th frame is the reference frame
117 | if self.adapt_official:
118 | x = x[:, [3, 0, 1, 2, 4, 5, 6], :, :, :]
119 | ref_idx = 0
120 |
121 | x_warped = []
122 | for i in range(7):
123 | if i == ref_idx:
124 | x_warped.append(x_ref)
125 | else:
126 | x_nbr = x[:, i, :, :, :]
127 | flow = self.SpyNet(x_ref, x_nbr).permute(0, 2, 3, 1)
128 | x_warped.append(flow_warp(x_nbr, flow))
129 | x_warped = torch.stack(x_warped, dim=1)
130 |
131 | x = x_warped.view(B, -1, H, W)
132 | x = self.relu(self.conv_3x7_64_9x9(x))
133 | x = self.relu(self.conv_64_64_9x9(x))
134 | x = self.relu(self.conv_64_64_1x1(x))
135 | x = self.conv_64_3_1x1(x) + x_ref
136 |
137 | return denormalize(x)
138 |
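A hedged sketch: TOFlow expects exactly 7 input frames (frame 3 is the reference), and H and W should be divisible by 16 so the SpyNet pyramid downsamples without rounding.

```python
import torch
from models.archs.TOF_arch import TOFlow

net = TOFlow().eval()
x = torch.rand(1, 7, 3, 64, 64)  # [B, 7, 3, H, W]
with torch.no_grad():
    out = net(x)
print(out.shape)  # torch.Size([1, 3, 64, 64])
```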
--------------------------------------------------------------------------------
/codes/models/archs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/codes/models/archs/__init__.py
--------------------------------------------------------------------------------
/codes/models/archs/arch_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 |
6 |
7 | def initialize_weights(net_l, scale=1):
8 | if not isinstance(net_l, list):
9 | net_l = [net_l]
10 | for net in net_l:
11 | for m in net.modules():
12 | if isinstance(m, nn.Conv2d):
13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14 | m.weight.data *= scale # for residual block
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.Linear):
18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | m.weight.data *= scale
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 | elif isinstance(m, nn.BatchNorm2d):
23 | init.constant_(m.weight, 1)
24 | init.constant_(m.bias.data, 0.0)
25 |
26 |
27 | def make_layer(block, n_layers):
28 | layers = []
29 | for _ in range(n_layers):
30 | layers.append(block())
31 | return nn.Sequential(*layers)
32 |
33 |
34 | class ResidualBlock_noBN(nn.Module):
35 | '''Residual block w/o BN
36 | ---Conv-ReLU-Conv-+-
37 | |________________|
38 | '''
39 |
40 | def __init__(self, nf=64):
41 | super(ResidualBlock_noBN, self).__init__()
42 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
43 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
44 |
45 | # initialization
46 | initialize_weights([self.conv1, self.conv2], 0.1)
47 |
48 | def forward(self, x):
49 | identity = x
50 | out = F.relu(self.conv1(x), inplace=True)
51 | out = self.conv2(out)
52 | return identity + out
53 |
54 |
55 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
56 | """Warp an image or feature map with optical flow
57 | Args:
58 | x (Tensor): size (N, C, H, W)
59 | flow (Tensor): size (N, H, W, 2), normal value
60 | interp_mode (str): 'nearest' or 'bilinear'
61 | padding_mode (str): 'zeros' or 'border' or 'reflection'
62 |
63 | Returns:
64 | Tensor: warped image or feature map
65 | """
66 | assert x.size()[-2:] == flow.size()[1:3]
67 | B, C, H, W = x.size()
68 | # mesh grid
69 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
70 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
71 | grid.requires_grad = False
72 | grid = grid.type_as(x)
73 | vgrid = grid + flow
74 | # scale grid to [-1,1]
75 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
76 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
77 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
78 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
79 | return output
80 |
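A quick sanity check for `flow_warp` (a sketch, not from the repo): a zero flow should reproduce the input. The flow layout is (N, H, W, 2), with (x, y) displacements in pixels.

```python
import torch
from models.archs.arch_util import flow_warp

x = torch.arange(16.).view(1, 1, 4, 4)
flow = torch.zeros(1, 4, 4, 2)  # zero displacement everywhere
out = flow_warp(x, flow)
print(out.shape)  # torch.Size([1, 1, 4, 4])
# On PyTorch <= 1.2 this is an exact identity; newer grid_sample defaults to
# align_corners=False, so a zero flow is only approximately the identity.
```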
--------------------------------------------------------------------------------
/codes/models/archs/dcn/__init__.py:
--------------------------------------------------------------------------------
1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack,
2 | deform_conv, modulated_deform_conv)
3 |
4 | __all__ = [
5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
6 | 'modulated_deform_conv'
7 | ]
8 |
--------------------------------------------------------------------------------
/codes/models/archs/dcn/deform_conv.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Function
7 | from torch.autograd.function import once_differentiable
8 | from torch.nn.modules.utils import _pair
9 |
10 | from . import deform_conv_cuda
11 |
12 | logger = logging.getLogger('base')
13 |
14 |
15 | class DeformConvFunction(Function):
16 | @staticmethod
17 | def forward(ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1,
18 | deformable_groups=1, im2col_step=64):
19 | if input is not None and input.dim() != 4:
20 | raise ValueError("Expected 4D tensor as input, got {}D tensor instead.".format(
21 | input.dim()))
22 | ctx.stride = _pair(stride)
23 | ctx.padding = _pair(padding)
24 | ctx.dilation = _pair(dilation)
25 | ctx.groups = groups
26 | ctx.deformable_groups = deformable_groups
27 | ctx.im2col_step = im2col_step
28 |
29 | ctx.save_for_backward(input, offset, weight)
30 |
31 | output = input.new_empty(
32 | DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
33 |
34 | ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
35 |
36 | if not input.is_cuda:
37 | raise NotImplementedError
38 | else:
39 | cur_im2col_step = min(ctx.im2col_step, input.shape[0])
40 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
41 | deform_conv_cuda.deform_conv_forward_cuda(input, weight, offset, output,
42 | ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
43 | weight.size(2), ctx.stride[1], ctx.stride[0],
44 | ctx.padding[1], ctx.padding[0],
45 | ctx.dilation[1], ctx.dilation[0], ctx.groups,
46 | ctx.deformable_groups, cur_im2col_step)
47 | return output
48 |
49 | @staticmethod
50 | @once_differentiable
51 | def backward(ctx, grad_output):
52 | input, offset, weight = ctx.saved_tensors
53 |
54 | grad_input = grad_offset = grad_weight = None
55 |
56 | if not grad_output.is_cuda:
57 | raise NotImplementedError
58 | else:
59 | cur_im2col_step = min(ctx.im2col_step, input.shape[0])
60 | assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
61 |
62 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
63 | grad_input = torch.zeros_like(input)
64 | grad_offset = torch.zeros_like(offset)
65 | deform_conv_cuda.deform_conv_backward_input_cuda(
66 | input, offset, grad_output, grad_input, grad_offset, weight, ctx.bufs_[0],
67 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
68 | ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
69 | ctx.deformable_groups, cur_im2col_step)
70 |
71 | if ctx.needs_input_grad[2]:
72 | grad_weight = torch.zeros_like(weight)
73 | deform_conv_cuda.deform_conv_backward_parameters_cuda(
74 | input, offset, grad_output, grad_weight, ctx.bufs_[0], ctx.bufs_[1],
75 | weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
76 | ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
77 | ctx.deformable_groups, 1, cur_im2col_step)
78 |
79 | return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
80 |
81 | @staticmethod
82 | def _output_size(input, weight, padding, dilation, stride):
83 | channels = weight.size(0)
84 | output_size = (input.size(0), channels)
85 | for d in range(input.dim() - 2):
86 | in_size = input.size(d + 2)
87 | pad = padding[d]
88 | kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
89 | stride_ = stride[d]
90 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
91 | if not all(map(lambda s: s > 0, output_size)):
92 | raise ValueError("convolution input is too small (output would be {})".format('x'.join(
93 | map(str, output_size))))
94 | return output_size
95 |
96 |
97 | class ModulatedDeformConvFunction(Function):
98 | @staticmethod
99 | def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1,
100 | groups=1, deformable_groups=1):
101 | ctx.stride = stride
102 | ctx.padding = padding
103 | ctx.dilation = dilation
104 | ctx.groups = groups
105 | ctx.deformable_groups = deformable_groups
106 | ctx.with_bias = bias is not None
107 | if not ctx.with_bias:
108 | bias = input.new_empty(1) # fake tensor
109 | if not input.is_cuda:
110 | raise NotImplementedError
111 | if weight.requires_grad or mask.requires_grad or offset.requires_grad \
112 | or input.requires_grad:
113 | ctx.save_for_backward(input, offset, mask, weight, bias)
114 | output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
115 | ctx._bufs = [input.new_empty(0), input.new_empty(0)]
116 | deform_conv_cuda.modulated_deform_conv_cuda_forward(
117 | input, weight, bias, ctx._bufs[0], offset, mask, output, ctx._bufs[1], weight.shape[2],
118 | weight.shape[3], ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation,
119 | ctx.dilation, ctx.groups, ctx.deformable_groups, ctx.with_bias)
120 | return output
121 |
122 | @staticmethod
123 | @once_differentiable
124 | def backward(ctx, grad_output):
125 | if not grad_output.is_cuda:
126 | raise NotImplementedError
127 | input, offset, mask, weight, bias = ctx.saved_tensors
128 | grad_input = torch.zeros_like(input)
129 | grad_offset = torch.zeros_like(offset)
130 | grad_mask = torch.zeros_like(mask)
131 | grad_weight = torch.zeros_like(weight)
132 | grad_bias = torch.zeros_like(bias)
133 | deform_conv_cuda.modulated_deform_conv_cuda_backward(
134 | input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], grad_input, grad_weight,
135 | grad_bias, grad_offset, grad_mask, grad_output, weight.shape[2], weight.shape[3],
136 | ctx.stride, ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
137 | ctx.groups, ctx.deformable_groups, ctx.with_bias)
138 | if not ctx.with_bias:
139 | grad_bias = None
140 |
141 | return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None,
142 | None)
143 |
144 | @staticmethod
145 | def _infer_shape(ctx, input, weight):
146 | n = input.size(0)
147 | channels_out = weight.size(0)
148 | height, width = input.shape[2:4]
149 | kernel_h, kernel_w = weight.shape[2:4]
150 | height_out = (height + 2 * ctx.padding - (ctx.dilation *
151 | (kernel_h - 1) + 1)) // ctx.stride + 1
152 | width_out = (width + 2 * ctx.padding - (ctx.dilation *
153 | (kernel_w - 1) + 1)) // ctx.stride + 1
154 | return n, channels_out, height_out, width_out
155 |
156 |
157 | deform_conv = DeformConvFunction.apply
158 | modulated_deform_conv = ModulatedDeformConvFunction.apply
159 |
160 |
161 | class DeformConv(nn.Module):
162 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
163 | groups=1, deformable_groups=1, bias=False):
164 | super(DeformConv, self).__init__()
165 |
166 | assert not bias
167 |         assert in_channels % groups == 0, \
168 |             'in_channels {} is not divisible by groups {}'.format(
169 |                 in_channels, groups)
170 |         assert out_channels % groups == 0, \
171 |             'out_channels {} is not divisible by groups {}'.format(
172 |                 out_channels, groups)
173 |
174 | self.in_channels = in_channels
175 | self.out_channels = out_channels
176 | self.kernel_size = _pair(kernel_size)
177 | self.stride = _pair(stride)
178 | self.padding = _pair(padding)
179 | self.dilation = _pair(dilation)
180 | self.groups = groups
181 | self.deformable_groups = deformable_groups
182 |
183 | self.weight = nn.Parameter(
184 | torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
185 |
186 | self.reset_parameters()
187 |
188 | def reset_parameters(self):
189 | n = self.in_channels
190 | for k in self.kernel_size:
191 | n *= k
192 | stdv = 1. / math.sqrt(n)
193 | self.weight.data.uniform_(-stdv, stdv)
194 |
195 | def forward(self, x, offset):
196 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
197 | self.groups, self.deformable_groups)
198 |
199 |
200 | class DeformConvPack(DeformConv):
201 | def __init__(self, *args, **kwargs):
202 | super(DeformConvPack, self).__init__(*args, **kwargs)
203 |
204 | self.conv_offset = nn.Conv2d(
205 | self.in_channels,
206 | self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
207 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
208 | bias=True)
209 | self.init_offset()
210 |
211 | def init_offset(self):
212 | self.conv_offset.weight.data.zero_()
213 | self.conv_offset.bias.data.zero_()
214 |
215 | def forward(self, x):
216 | offset = self.conv_offset(x)
217 | return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation,
218 | self.groups, self.deformable_groups)
219 |
220 |
221 | class ModulatedDeformConv(nn.Module):
222 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
223 | groups=1, deformable_groups=1, bias=True):
224 | super(ModulatedDeformConv, self).__init__()
225 | self.in_channels = in_channels
226 | self.out_channels = out_channels
227 | self.kernel_size = _pair(kernel_size)
228 | self.stride = stride
229 | self.padding = padding
230 | self.dilation = dilation
231 | self.groups = groups
232 | self.deformable_groups = deformable_groups
233 | self.with_bias = bias
234 |
235 | self.weight = nn.Parameter(
236 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
237 | if bias:
238 | self.bias = nn.Parameter(torch.Tensor(out_channels))
239 | else:
240 | self.register_parameter('bias', None)
241 | self.reset_parameters()
242 |
243 | def reset_parameters(self):
244 | n = self.in_channels
245 | for k in self.kernel_size:
246 | n *= k
247 | stdv = 1. / math.sqrt(n)
248 | self.weight.data.uniform_(-stdv, stdv)
249 | if self.bias is not None:
250 | self.bias.data.zero_()
251 |
252 | def forward(self, x, offset, mask):
253 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
254 | self.padding, self.dilation, self.groups,
255 | self.deformable_groups)
256 |
257 |
258 | class ModulatedDeformConvPack(ModulatedDeformConv):
259 | def __init__(self, *args, extra_offset_mask=False, **kwargs):
260 | super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
261 |
262 | self.extra_offset_mask = extra_offset_mask
263 | self.conv_offset_mask = nn.Conv2d(
264 | self.in_channels,
265 | self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
266 | kernel_size=self.kernel_size, stride=_pair(self.stride), padding=_pair(self.padding),
267 | bias=True)
268 | self.init_offset()
269 |
270 | def init_offset(self):
271 | self.conv_offset_mask.weight.data.zero_()
272 | self.conv_offset_mask.bias.data.zero_()
273 |
274 | def forward(self, x):
275 | if self.extra_offset_mask:
276 | # x = [input, features]
277 | out = self.conv_offset_mask(x[1])
278 | x = x[0]
279 | else:
280 | out = self.conv_offset_mask(x)
281 | o1, o2, mask = torch.chunk(out, 3, dim=1)
282 | offset = torch.cat((o1, o2), dim=1)
283 | mask = torch.sigmoid(mask)
284 |
285 | offset_mean = torch.mean(torch.abs(offset))
286 | if offset_mean > 100:
287 | logger.warning('Offset mean is {}, larger than 100.'.format(offset_mean))
288 |
289 | return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride,
290 | self.padding, self.dilation, self.groups,
291 | self.deformable_groups)
292 |
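A hedged sketch of the "Pack" variant, which predicts its own offsets and masks from the input. It needs the `deform_conv_cuda` extension built via `setup.py` and a CUDA device; `deformable_groups=8` mirrors the EDVR default.

```python
import torch
from models.archs.dcn import ModulatedDeformConvPack

dcn = ModulatedDeformConvPack(64, 64, kernel_size=3, stride=1, padding=1,
                              deformable_groups=8).cuda()
x = torch.rand(2, 64, 16, 16).cuda()
# Offsets start at zero and masks at sigmoid(0) = 0.5, so at init this
# behaves like a plain 3x3 convolution scaled by 0.5.
out = dcn(x)
print(out.shape)  # torch.Size([2, 64, 16, 16])
```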
--------------------------------------------------------------------------------
/codes/models/archs/dcn/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 |
5 | def make_cuda_ext(name, sources):
6 |
7 | return CUDAExtension(
8 |         name=name, sources=list(sources), extra_compile_args={
9 | 'cxx': [],
10 | 'nvcc': [
11 | '-D__CUDA_NO_HALF_OPERATORS__',
12 | '-D__CUDA_NO_HALF_CONVERSIONS__',
13 | '-D__CUDA_NO_HALF2_OPERATORS__',
14 | ]
15 | })
16 |
17 |
18 | setup(
19 | name='deform_conv', ext_modules=[
20 | make_cuda_ext(name='deform_conv_cuda',
21 | sources=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu'])
22 | ], cmdclass={'build_ext': BuildExtension}, zip_safe=False)
23 |
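Building in place is the usual workflow for a `BuildExtension` setup script like this one: from `codes/models/archs/dcn/`, something like `python setup.py develop` (or `python setup.py build_ext --inplace`) compiles `deform_conv_cuda` against the local CUDA toolkit. This is the standard setuptools convention, not a command documented in this repo.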
--------------------------------------------------------------------------------
/codes/models/archs/discriminator_vgg_arch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 |
5 |
6 | class Discriminator_VGG_128(nn.Module):
7 | def __init__(self, in_nc, nf):
8 | super(Discriminator_VGG_128, self).__init__()
9 | # [64, 128, 128]
10 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
11 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False)
12 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True)
13 | # [64, 64, 64]
14 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False)
15 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True)
16 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False)
17 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True)
18 | # [128, 32, 32]
19 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False)
20 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True)
21 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False)
22 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True)
23 | # [256, 16, 16]
24 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False)
25 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True)
26 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
27 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True)
28 | # [512, 8, 8]
29 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False)
30 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True)
31 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False)
32 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True)
33 |
34 | self.linear1 = nn.Linear(512 * 4 * 4, 100)
35 | self.linear2 = nn.Linear(100, 1)
36 |
37 | # activation function
38 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
39 |
40 | def forward(self, x):
41 | fea = self.lrelu(self.conv0_0(x))
42 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea)))
43 |
44 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea)))
45 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea)))
46 |
47 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea)))
48 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea)))
49 |
50 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea)))
51 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea)))
52 |
53 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea)))
54 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea)))
55 |
56 | fea = fea.view(fea.size(0), -1)
57 | fea = self.lrelu(self.linear1(fea))
58 | out = self.linear2(fea)
59 | return out
60 |
61 |
62 | class VGGFeatureExtractor(nn.Module):
63 | def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True,
64 | device=torch.device('cpu')):
65 | super(VGGFeatureExtractor, self).__init__()
66 | self.use_input_norm = use_input_norm
67 | if use_bn:
68 | model = torchvision.models.vgg19_bn(pretrained=True)
69 | else:
70 | model = torchvision.models.vgg19(pretrained=True)
71 | if self.use_input_norm:
72 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device)
73 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1]
74 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device)
75 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1]
76 | self.register_buffer('mean', mean)
77 | self.register_buffer('std', std)
78 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)])
79 |         # Freeze the VGG parameters: no backprop into the feature extractor
80 | for k, v in self.features.named_parameters():
81 | v.requires_grad = False
82 |
83 | def forward(self, x):
84 | # Assume input range is [0, 1]
85 | if self.use_input_norm:
86 | x = (x - self.mean) / self.std
87 | output = self.features(x)
88 | return output
89 |
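A minimal sketch: the fully connected head hard-codes 512 * 4 * 4 input features, so with nf=64 the discriminator input must be 128x128 (five stride-2 convolutions reduce it to 4x4).

```python
import torch
from models.archs.discriminator_vgg_arch import Discriminator_VGG_128

netD = Discriminator_VGG_128(in_nc=3, nf=64)
logits = netD(torch.rand(4, 3, 128, 128))
print(logits.shape)  # torch.Size([4, 1]): raw scores, no sigmoid (pairs with BCEWithLogitsLoss)
```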
--------------------------------------------------------------------------------
/codes/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parallel import DistributedDataParallel
6 |
7 |
8 | class BaseModel():
9 | def __init__(self, opt):
10 | self.opt = opt
11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
12 | self.is_train = opt['is_train']
13 | self.schedulers = []
14 | self.optimizers = []
15 |
16 | def feed_data(self, data):
17 | pass
18 |
19 | def optimize_parameters(self):
20 | pass
21 |
22 | def get_current_visuals(self):
23 | pass
24 |
25 | def get_current_losses(self):
26 | pass
27 |
28 | def print_network(self):
29 | pass
30 |
31 | def save(self, label):
32 | pass
33 |
34 | def load(self):
35 | pass
36 |
37 | def _set_lr(self, lr_groups_l):
38 | """Set learning rate for warmup
39 | lr_groups_l: list for lr_groups. each for a optimizer"""
40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
41 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
42 | param_group['lr'] = lr
43 |
44 | def _get_init_lr(self):
45 | """Get the initial lr, which is set by the scheduler"""
46 | init_lr_groups_l = []
47 | for optimizer in self.optimizers:
48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
49 | return init_lr_groups_l
50 |
51 | def update_learning_rate(self, cur_iter, warmup_iter=-1):
52 | for scheduler in self.schedulers:
53 | scheduler.step()
54 | # set up warm-up learning rate
55 | if cur_iter < warmup_iter:
56 | # get initial lr for each group
57 | init_lr_g_l = self._get_init_lr()
58 | # modify warming-up learning rates
59 | warm_up_lr_l = []
60 | for init_lr_g in init_lr_g_l:
61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
62 | # set learning rate
63 | self._set_lr(warm_up_lr_l)
64 |
65 | def get_current_learning_rate(self):
66 | return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
67 |
68 | def get_network_description(self, network):
69 | """Get the string and total parameters of the network"""
70 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
71 | network = network.module
72 | return str(network), sum(map(lambda x: x.numel(), network.parameters()))
73 |
74 | def save_network(self, network, network_label, iter_label):
75 | save_filename = '{}_{}.pth'.format(iter_label, network_label)
76 | save_path = os.path.join(self.opt['path']['models'], save_filename)
77 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
78 | network = network.module
79 | state_dict = network.state_dict()
80 | for key, param in state_dict.items():
81 | state_dict[key] = param.cpu()
82 | torch.save(state_dict, save_path)
83 |
84 | def load_network(self, load_path, network, strict=True):
85 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
86 | network = network.module
87 | load_net = torch.load(load_path)
88 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
89 | for k, v in load_net.items():
90 | if k.startswith('module.'):
91 | load_net_clean[k[7:]] = v
92 | else:
93 | load_net_clean[k] = v
94 | network.load_state_dict(load_net_clean, strict=strict)
95 |
96 | def save_training_state(self, epoch, iter_step):
97 | """Save training state during training, which will be used for resuming"""
98 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
99 | for s in self.schedulers:
100 | state['schedulers'].append(s.state_dict())
101 | for o in self.optimizers:
102 | state['optimizers'].append(o.state_dict())
103 | save_filename = '{}.state'.format(iter_step)
104 | save_path = os.path.join(self.opt['path']['training_state'], save_filename)
105 | torch.save(state, save_path)
106 |
107 | def resume_training(self, resume_state):
108 | """Resume the optimizers and schedulers for training"""
109 | resume_optimizers = resume_state['optimizers']
110 | resume_schedulers = resume_state['schedulers']
111 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
112 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
113 | for i, o in enumerate(resume_optimizers):
114 | self.optimizers[i].load_state_dict(o)
115 | for i, s in enumerate(resume_schedulers):
116 | self.schedulers[i].load_state_dict(s)
117 |
--------------------------------------------------------------------------------
/codes/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class CharbonnierLoss(nn.Module):
6 | """Charbonnier Loss (L1)"""
7 |
8 | def __init__(self, eps=1e-6):
9 | super(CharbonnierLoss, self).__init__()
10 | self.eps = eps
11 |
12 | def forward(self, x, y):
13 | diff = x - y
14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps))
15 | return loss
16 |
17 |
18 | # Define GAN loss: [vanilla | lsgan | wgan-gp]
19 | class GANLoss(nn.Module):
20 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
21 | super(GANLoss, self).__init__()
22 | self.gan_type = gan_type.lower()
23 | self.real_label_val = real_label_val
24 | self.fake_label_val = fake_label_val
25 |
26 | if self.gan_type == 'gan' or self.gan_type == 'ragan':
27 | self.loss = nn.BCEWithLogitsLoss()
28 | elif self.gan_type == 'lsgan':
29 | self.loss = nn.MSELoss()
30 | elif self.gan_type == 'wgan-gp':
31 |
32 | def wgan_loss(input, target):
33 | # target is boolean
34 | return -1 * input.mean() if target else input.mean()
35 |
36 | self.loss = wgan_loss
37 | else:
38 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
39 |
40 | def get_target_label(self, input, target_is_real):
41 | if self.gan_type == 'wgan-gp':
42 | return target_is_real
43 | if target_is_real:
44 | return torch.empty_like(input).fill_(self.real_label_val)
45 | else:
46 | return torch.empty_like(input).fill_(self.fake_label_val)
47 |
48 | def forward(self, input, target_is_real):
49 | target_label = self.get_target_label(input, target_is_real)
50 | loss = self.loss(input, target_label)
51 | return loss
52 |
53 |
54 | class GradientPenaltyLoss(nn.Module):
55 | def __init__(self, device=torch.device('cpu')):
56 | super(GradientPenaltyLoss, self).__init__()
57 | self.register_buffer('grad_outputs', torch.Tensor())
58 | self.grad_outputs = self.grad_outputs.to(device)
59 |
60 | def get_grad_outputs(self, input):
61 | if self.grad_outputs.size() != input.size():
62 | self.grad_outputs.resize_(input.size()).fill_(1.0)
63 | return self.grad_outputs
64 |
65 | def forward(self, interp, interp_crit):
66 | grad_outputs = self.get_grad_outputs(interp_crit)
67 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
68 | grad_outputs=grad_outputs, create_graph=True,
69 | retain_graph=True, only_inputs=True)[0]
70 | grad_interp = grad_interp.view(grad_interp.size(0), -1)
71 | grad_interp_norm = grad_interp.norm(2, dim=1)
72 |
73 | loss = ((grad_interp_norm - 1)**2).mean()
74 | return loss
75 |
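A short usage sketch (assuming `codes/` is the working directory). Note that CharbonnierLoss reduces with a sum, not a mean, so its magnitude scales with tensor size.

```python
import torch
from models.loss import CharbonnierLoss, GANLoss

pix_crit = CharbonnierLoss()
x, y = torch.rand(2, 3, 16, 16), torch.rand(2, 3, 16, 16)
print(pix_crit(x, y))  # scalar, sum-reduced over all elements

gan_crit = GANLoss('gan', real_label_val=1.0, fake_label_val=0.0)
d_logits = torch.randn(4, 1)
print(gan_crit(d_logits, True))  # BCE-with-logits against an all-ones target
```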
--------------------------------------------------------------------------------
/codes/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from collections import defaultdict
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 |
7 |
8 | class MultiStepLR_Restart(_LRScheduler):
9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
10 | clear_state=False, last_epoch=-1):
11 | self.milestones = Counter(milestones)
12 | self.gamma = gamma
13 | self.clear_state = clear_state
14 | self.restarts = restarts if restarts else [0]
15 | self.restarts = [v + 1 for v in self.restarts]
16 | self.restart_weights = weights if weights else [1]
17 | assert len(self.restarts) == len(
18 | self.restart_weights), 'restarts and their weights do not match.'
19 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
20 |
21 | def get_lr(self):
22 | if self.last_epoch in self.restarts:
23 | if self.clear_state:
24 | self.optimizer.state = defaultdict(dict)
25 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
26 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
27 | if self.last_epoch not in self.milestones:
28 | return [group['lr'] for group in self.optimizer.param_groups]
29 | return [
30 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
31 | for group in self.optimizer.param_groups
32 | ]
33 |
34 |
35 | class CosineAnnealingLR_Restart(_LRScheduler):
36 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
37 | self.T_period = T_period
38 | self.T_max = self.T_period[0] # current T period
39 | self.eta_min = eta_min
40 | self.restarts = restarts if restarts else [0]
41 | self.restarts = [v + 1 for v in self.restarts]
42 | self.restart_weights = weights if weights else [1]
43 | self.last_restart = 0
44 | assert len(self.restarts) == len(
45 | self.restart_weights), 'restarts and their weights do not match.'
46 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
47 |
48 | def get_lr(self):
49 | if self.last_epoch == 0:
50 | return self.base_lrs
51 | elif self.last_epoch in self.restarts:
52 | self.last_restart = self.last_epoch
53 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
54 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
55 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
56 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
57 | return [
58 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
60 | ]
61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
63 | (group['lr'] - self.eta_min) + self.eta_min
64 | for group in self.optimizer.param_groups]
65 |
66 |
67 | if __name__ == "__main__":
68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
69 | betas=(0.9, 0.99))
70 | ##############################
71 | # MultiStepLR_Restart
72 | ##############################
73 |     ## Original (each block below overwrites the previous; the last assignment takes effect)
74 | lr_steps = [200000, 400000, 600000, 800000]
75 | restarts = None
76 | restart_weights = None
77 |
78 |     ## two cycles (one restart)
79 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
80 | restarts = [500000]
81 | restart_weights = [1]
82 |
83 |     ## four cycles (three restarts)
84 | lr_steps = [
85 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
86 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
87 | ]
88 | restarts = [250000, 500000, 750000]
89 | restart_weights = [1, 1, 1]
90 |
91 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
92 | clear_state=False)
93 |
94 | ##############################
95 | # Cosine Annealing Restart
96 | ##############################
97 |     ## two cycles (one restart)
98 | T_period = [500000, 500000]
99 | restarts = [500000]
100 | restart_weights = [1]
101 |
102 |     ## four cycles (three restarts)
103 | T_period = [250000, 250000, 250000, 250000]
104 | restarts = [250000, 500000, 750000]
105 | restart_weights = [1, 1, 1]
106 |
107 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
108 | weights=restart_weights)
109 |
110 | ##############################
111 | # Draw figure
112 | ##############################
113 | N_iter = 1000000
114 | lr_l = list(range(N_iter))
115 | for i in range(N_iter):
116 | scheduler.step()
117 | current_lr = optimizer.param_groups[0]['lr']
118 | lr_l[i] = current_lr
119 |
120 | import matplotlib as mpl
121 | from matplotlib import pyplot as plt
122 | import matplotlib.ticker as mtick
123 | mpl.style.use('default')
124 | import seaborn
125 | seaborn.set(style='whitegrid')
126 | seaborn.set_context('paper')
127 |
128 | plt.figure(1)
129 | plt.subplot(111)
130 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
131 | plt.title('Title', fontsize=16, color='k')
132 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
133 | legend = plt.legend(loc='upper right', shadow=False)
134 | ax = plt.gca()
135 | labels = ax.get_xticks().tolist()
136 | for k, v in enumerate(labels):
137 | labels[k] = str(int(v / 1000)) + 'K'
138 | ax.set_xticklabels(labels)
139 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
140 |
141 | ax.set_ylabel('Learning rate')
142 | ax.set_xlabel('Iteration')
143 | fig = plt.gcf()
144 | plt.show()
145 |
--------------------------------------------------------------------------------
/codes/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import models.archs.SRResNet_arch as SRResNet_arch
3 | import models.archs.discriminator_vgg_arch as SRGAN_arch
4 | import models.archs.RRDBNet_arch as RRDBNet_arch
5 | # import models.archs.EDVR_arch as EDVR_arch  # needs the compiled DCN extension; re-enable for the EDVR branch below
6 | import models.archs.CSRNet_arch as CSRNet_arch
7 |
8 |
9 | # Generator
10 | def define_G(opt):
11 | opt_net = opt['network_G']
12 | which_model = opt_net['which_model_G']
13 |
14 | # image restoration
15 | if which_model == 'MSRResNet':
16 | netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
17 | nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale'])
18 | elif which_model == 'RRDBNet':
19 | netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
20 | nf=opt_net['nf'], nb=opt_net['nb'])
21 |
22 |     elif which_model == 'AdaFMNet':  # NOTE: AdaFMNet_arch is not imported or shipped in this repo
23 | netG = AdaFMNet_arch.AdaFMNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
24 | nf=opt_net['nf'], nb=opt_net['nb'], adafm_ksize=opt_net['adafm_ksize'])
25 |
26 |     elif which_model == 'CResMDNet':  # NOTE: CResMDNet_arch is not imported or shipped in this repo
27 | netG = CResMDNet_arch.CResMDNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
28 | nf=opt_net['nf'], nb=opt_net['nb'], cond_dim=opt_net['cond_dim'])
29 | elif which_model == 'BaseNet':
30 | netG = CResMDNet_arch.BaseNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
31 | nf=opt_net['nf'], nb=opt_net['nb'])
32 | elif which_model == 'CondNet':
33 | netG = CResMDNet_arch.CondNet(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
34 |
35 | # image enhancement
36 | elif which_model == 'CSRNet':
37 | netG = CSRNet_arch.CSRNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], base_nf=opt_net['base_nf'],
38 | cond_nf=opt_net['cond_nf'])
39 |
40 | # video restoration
41 |     elif which_model == 'EDVR':  # requires re-enabling the EDVR_arch import above
42 | netG = EDVR_arch.EDVR(nf=opt_net['nf'], nframes=opt_net['nframes'],
43 | groups=opt_net['groups'], front_RBs=opt_net['front_RBs'],
44 | back_RBs=opt_net['back_RBs'], center=opt_net['center'],
45 | predeblur=opt_net['predeblur'], HR_in=opt_net['HR_in'],
46 | w_TSA=opt_net['w_TSA'])
47 | else:
48 | raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
49 |
50 | return netG
51 |
52 |
53 | # Discriminator
54 | def define_D(opt):
55 | opt_net = opt['network_D']
56 | which_model = opt_net['which_model_D']
57 |
58 | if which_model == 'discriminator_vgg_128':
59 | netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
60 | else:
61 | raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
62 | return netD
63 |
64 |
65 | # Define network used for perceptual loss
66 | def define_F(opt, use_bn=False):
67 | gpu_ids = opt['gpu_ids']
68 | device = torch.device('cuda' if gpu_ids else 'cpu')
69 | # PyTorch pretrained VGG19-54, before ReLU.
70 | if use_bn:
71 | feature_layer = 49
72 | else:
73 | feature_layer = 34
74 | netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
75 | use_input_norm=True, device=device)
76 | netF.eval() # No need to train
77 | return netF
78 |
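A hedged sketch of driving `define_G` directly with a minimal option dict mirroring options/test/test_Enhance.yml (run from `codes/`); normally `options.parse` builds this dict from the YAML.

```python
import models.networks as networks

opt = {'network_G': {'which_model_G': 'CSRNet', 'in_nc': 3, 'out_nc': 3,
                     'base_nf': 64, 'cond_nf': 32}}
netG = networks.define_G(opt)
print(sum(p.numel() for p in netG.parameters()))  # tiny: on the order of 36k parameters
```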
--------------------------------------------------------------------------------
/codes/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/codes/options/__init__.py
--------------------------------------------------------------------------------
/codes/options/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | import yaml
5 | from utils.util import OrderedYaml
6 | Loader, Dumper = OrderedYaml()
7 |
8 |
9 | def parse(opt_path, is_train=True):
10 | with open(opt_path, mode='r') as f:
11 | opt = yaml.load(f, Loader=Loader)
12 | # export CUDA_VISIBLE_DEVICES
13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
16 |
17 | opt['is_train'] = is_train
18 | if opt['distortion'] == 'sr':
19 | scale = opt['scale']
20 |
21 | # datasets
22 | for phase, dataset in opt['datasets'].items():
23 | phase = phase.split('_')[0]
24 | dataset['phase'] = phase
25 | if opt['distortion'] == 'sr':
26 | dataset['scale'] = scale
27 | is_lmdb = False
28 | if dataset.get('dataroot_GT', None) is not None:
29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
30 | if dataset['dataroot_GT'].endswith('lmdb'):
31 | is_lmdb = True
32 | if dataset.get('dataroot_LQ', None) is not None:
33 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
34 | if dataset['dataroot_LQ'].endswith('lmdb'):
35 | is_lmdb = True
36 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
37 | if dataset['mode'].endswith('mc'): # for memcached
38 | dataset['data_type'] = 'mc'
39 | dataset['mode'] = dataset['mode'].replace('_mc', '')
40 |
41 | # path
42 | for key, path in opt['path'].items():
43 | if path and key in opt['path'] and key != 'strict_load':
44 | opt['path'][key] = osp.expanduser(path)
45 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
46 | if is_train:
47 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
48 | opt['path']['experiments_root'] = experiments_root
49 | opt['path']['models'] = osp.join(experiments_root, 'models')
50 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
51 | opt['path']['log'] = experiments_root
52 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
53 |
54 | # change some options for debug mode
55 | if 'debug' in opt['name']:
56 | opt['train']['val_freq'] = 8
57 | opt['logger']['print_freq'] = 1
58 | opt['logger']['save_checkpoint_freq'] = 8
59 | else: # test
60 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
61 | opt['path']['results_root'] = results_root
62 | opt['path']['log'] = results_root
63 |
64 | # network
65 | if opt['distortion'] == 'sr':
66 | opt['network_G']['scale'] = scale
67 |
68 | return opt
69 |
70 |
71 | def dict2str(opt, indent_l=1):
72 | '''dict to string for logger'''
73 | msg = ''
74 | for k, v in opt.items():
75 | if isinstance(v, dict):
76 | msg += ' ' * (indent_l * 2) + k + ':[\n'
77 | msg += dict2str(v, indent_l + 1)
78 | msg += ' ' * (indent_l * 2) + ']\n'
79 | else:
80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
81 | return msg
82 |
83 |
84 | class NoneDict(dict):
85 | def __missing__(self, key):
86 | return None
87 |
88 |
89 | # Convert to NoneDict, which returns None for missing keys.
90 | def dict_to_nonedict(opt):
91 | if isinstance(opt, dict):
92 | new_opt = dict()
93 | for key, sub_opt in opt.items():
94 | new_opt[key] = dict_to_nonedict(sub_opt)
95 | return NoneDict(**new_opt)
96 | elif isinstance(opt, list):
97 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
98 | else:
99 | return opt
100 |
101 |
102 | def check_resume(opt, resume_iter):
103 | '''Check resume states and pretrain_model paths'''
104 | logger = logging.getLogger('base')
105 | if opt['path']['resume_state']:
106 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
107 | 'pretrain_model_D', None) is not None:
108 | logger.warning('pretrain_model path will be ignored when resuming training.')
109 |
110 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
111 | '{}_G.pth'.format(resume_iter))
112 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
113 | if 'gan' in opt['model']:
114 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
115 | '{}_D.pth'.format(resume_iter))
116 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
117 |
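Typical usage, as in this repo's test/train entry points (run from `codes/`): parse a YAML config, then convert it so missing keys read as None instead of raising KeyError.

```python
import options.options as option

opt = option.parse('options/test/test_Enhance.yml', is_train=False)
opt = option.dict_to_nonedict(opt)
print(opt['network_G']['which_model_G'])  # CSRNet
print(opt['nonexistent_key'])             # None, thanks to NoneDict
```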
--------------------------------------------------------------------------------
/codes/options/test/test_ESRGAN.yml:
--------------------------------------------------------------------------------
1 | name: RRDB_ESRGAN_x4
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 4
6 | crop_border: ~ # border to crop during evaluation; if None (~), crop scale pixels
7 | gpu_ids: [0]
8 |
9 | datasets:
10 | test_1: # the 1st test dataset
11 | name: set5
12 | mode: LQGT
13 | dataroot_GT: ../datasets/val_set5/Set5
14 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
15 |   test_2: # the 2nd test dataset
16 | name: set14
17 | mode: LQGT
18 | dataroot_GT: ../datasets/val_set14/Set14
19 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
20 |
21 | #### network structures
22 | network_G:
23 | which_model_G: RRDBNet
24 | in_nc: 3
25 | out_nc: 3
26 | nf: 64
27 | nb: 23
28 | upscale: 4
29 |
30 | #### path
31 | path:
32 | pretrain_model_G: ../experiments/pretrained_models/RRDB_ESRGAN_x4.pth
33 |
--------------------------------------------------------------------------------
/codes/options/test/test_Enhance.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: test_csrnet
3 | suffix: ~ # add suffix to saved images
4 | model: sr
5 | distortion: sr
6 | scale: 1
7 | gpu_ids: [0]
8 |
9 | datasets:
10 | val:
11 | name: MIT_fivek_500
12 | mode: LQGT_enhance
13 | dataroot_GT: ../datasets/expert_C_test
14 | dataroot_LQ: ../datasets/raw_input_test
15 |
16 | #### network structures
17 | network_G:
18 | which_model_G: CSRNet
19 | in_nc: 3
20 | out_nc: 3
21 | base_nf: 64
22 | cond_nf: 32
23 |
24 | #### path
25 | path:
26 | root:
27 | pretrain_model_G: ../experiments/pretrain_models/csrnet.pth
--------------------------------------------------------------------------------
/codes/options/test/test_SRGAN.yml:
--------------------------------------------------------------------------------
1 | name: MSRGANx4
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 4
6 | crop_border: ~ # border to crop during evaluation; if None (~), crop scale pixels
7 | gpu_ids: [0]
8 |
9 | datasets:
10 | test_1: # the 1st test dataset
11 | name: set5
12 | mode: LQGT
13 | dataroot_GT: ../datasets/val_set5/Set5
14 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
15 |   test_2: # the 2nd test dataset
16 | name: set14
17 | mode: LQGT
18 | dataroot_GT: ../datasets/val_set14/Set14
19 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
20 |
21 | #### network structures
22 | network_G:
23 | which_model_G: MSRResNet
24 | in_nc: 3
25 | out_nc: 3
26 | nf: 64
27 | nb: 16
28 | upscale: 4
29 |
30 | #### path
31 | path:
32 | pretrain_model_G: ../experiments/pretrained_models/MSRGANx4.pth
33 |
--------------------------------------------------------------------------------
/codes/options/test/test_SRResNet.yml:
--------------------------------------------------------------------------------
1 | name: MSRResNetx4
2 | suffix: ~ # add suffix to saved images
3 | model: sr
4 | distortion: sr
5 | scale: 4
6 | crop_border: ~ # border to crop during evaluation; if None (~), crop scale pixels
7 | gpu_ids: [0]
8 |
9 | datasets:
10 | test_1: # the 1st test dataset
11 | name: set5
12 | mode: LQGT
13 | dataroot_GT: ../datasets/val_set5/Set5
14 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
15 |   test_2: # the 2nd test dataset
16 | name: set14
17 | mode: LQGT
18 | dataroot_GT: ../datasets/val_set14/Set14
19 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
20 | test_3:
21 | name: bsd100
22 | mode: LQGT
23 | dataroot_GT: ../datasets/BSD/BSDS100
24 | dataroot_LQ: ../datasets/BSD/BSDS100_bicLRx4
25 | test_4:
26 | name: urban100
27 | mode: LQGT
28 | dataroot_GT: ../datasets/urban100
29 | dataroot_LQ: ../datasets/urban100_bicLRx4
30 | test_5:
31 | name: div2k100
32 | mode: LQGT
33 | dataroot_GT: ../datasets/DIV2K100/DIV2K_valid_HR
34 | dataroot_LQ: ../datasets/DIV2K100/DIV2K_valid_bicLRx4
35 |
36 |
37 | #### network structures
38 | network_G:
39 | which_model_G: MSRResNet
40 | in_nc: 3
41 | out_nc: 3
42 | nf: 64
43 | nb: 16
44 | upscale: 4
45 |
46 | #### path
47 | path:
48 | pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
49 |
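codes/test.py iterates the datasets with `sorted(opt['datasets'].items())`, so the `test_1` ... `test_5` keys above run in the listed order. The sort is lexicographic, so a hypothetical `test_10` would run before `test_2`:

    # lexicographic ordering of dataset keys, as used by sorted() in codes/test.py
    keys = ['test_1', 'test_2', 'test_3', 'test_4', 'test_5']
    assert sorted(keys) == keys        # listed order is preserved up to test_9
    print(sorted(keys + ['test_10']))  # ['test_1', 'test_10', 'test_2', ...]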
--------------------------------------------------------------------------------
/codes/options/train/train_EDVR_M.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: 002_EDVR_EDVRwoTSAIni_lr4e-4_600k_REDS_LrCAR4S_fixTSA50k_new
3 | use_tb_logger: true
4 | model: video_base
5 | distortion: sr
6 | scale: 4
7 | gpu_ids: [0,1,2,3,4,5,6,7]
8 |
9 | #### datasets
10 | datasets:
11 | train:
12 | name: REDS
13 | mode: REDS
14 | interval_list: [1]
15 | random_reverse: false
16 | border_mode: false
17 | dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
18 | dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
19 | cache_keys: ~
20 |
21 | N_frames: 5
22 | use_shuffle: true
23 | n_workers: 3 # per GPU
24 | batch_size: 32
25 | GT_size: 256
26 | LQ_size: 64
27 | use_flip: true
28 | use_rot: true
29 | color: RGB
30 | val:
31 | name: REDS4
32 | mode: video_test
33 | dataroot_GT: ../datasets/REDS4/GT
34 | dataroot_LQ: ../datasets/REDS4/sharp_bicubic
35 | cache_data: True
36 | N_frames: 5
37 | padding: new_info
38 |
39 | #### network structures
40 | network_G:
41 | which_model_G: EDVR
42 | nf: 64
43 | nframes: 5
44 | groups: 8
45 | front_RBs: 5
46 | back_RBs: 10
47 | predeblur: false
48 | HR_in: false
49 | w_TSA: true
50 |
51 | #### path
52 | path:
53 | pretrain_model_G: ../experiments/pretrained_models/EDVR_REDS_SR_M_woTSA.pth
54 | strict_load: false
55 | resume_state: ~
56 |
57 | #### training settings: learning rate scheme, loss
58 | train:
59 | lr_G: !!float 4e-4
60 | lr_scheme: CosineAnnealingLR_Restart
61 | beta1: 0.9
62 | beta2: 0.99
63 | niter: 600000
64 | ft_tsa_only: 50000
65 | warmup_iter: -1 # -1: no warm up
66 | T_period: [50000, 100000, 150000, 150000, 150000]
67 | restarts: [50000, 150000, 300000, 450000]
68 | restart_weights: [1, 1, 1, 1]
69 | eta_min: !!float 1e-7
70 |
71 | pixel_criterion: cb
72 | pixel_weight: 1.0
73 | val_freq: !!float 5e3
74 |
75 | manual_seed: 0
76 |
77 | #### logger
78 | logger:
79 | print_freq: 100
80 | save_checkpoint_freq: !!float 5e3
81 |
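The cosine schedule above is internally consistent: each entry of `restarts` sits at the cumulative end of a `T_period` segment, the segments sum to `niter`, and `ft_tsa_only: 50000` coincides with the first restart, so the TSA-only finetuning phase presumably ends exactly at a warm restart. A quick check, assuming this is how `CosineAnnealingLR_Restart` in models/lr_scheduler.py interprets the lists:

    from itertools import accumulate

    T_period = [50000, 100000, 150000, 150000, 150000]
    restarts = [50000, 150000, 300000, 450000]
    niter = 600000

    assert list(accumulate(T_period))[:-1] == restarts  # restarts end each segment
    assert sum(T_period) == niter                       # segments cover all 600k iters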
--------------------------------------------------------------------------------
/codes/options/train/train_EDVR_woTSA_M.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: 001_EDVRwoTSA_scratch_lr4e-4_600k_REDS_LrCAR4S
3 | use_tb_logger: true
4 | model: video_base
5 | distortion: sr
6 | scale: 4
7 | gpu_ids: [0,1,2,3,4,5,6,7]
8 |
9 | #### datasets
10 | datasets:
11 | train:
12 | name: REDS
13 | mode: REDS
14 | interval_list: [1]
15 | random_reverse: false
16 | border_mode: false
17 | dataroot_GT: ../datasets/REDS/train_sharp_wval.lmdb
18 | dataroot_LQ: ../datasets/REDS/train_sharp_bicubic_wval.lmdb
19 | cache_keys: ~
20 |
21 | N_frames: 5
22 | use_shuffle: true
23 | n_workers: 3 # per GPU
24 | batch_size: 32
25 | GT_size: 256
26 | LQ_size: 64
27 | use_flip: true
28 | use_rot: true
29 | color: RGB
30 |
31 | #### network structures
32 | network_G:
33 | which_model_G: EDVR
34 | nf: 64
35 | nframes: 5
36 | groups: 8
37 | front_RBs: 5
38 | back_RBs: 10
39 | predeblur: false
40 | HR_in: false
41 | w_TSA: false
42 |
43 | #### path
44 | path:
45 | pretrain_model_G: ~
46 | strict_load: true
47 | resume_state: ~
48 |
49 | #### training settings: learning rate scheme, loss
50 | train:
51 | lr_G: !!float 4e-4
52 | lr_scheme: CosineAnnealingLR_Restart
53 | beta1: 0.9
54 | beta2: 0.99
55 | niter: 600000
56 | warmup_iter: -1 # -1: no warm up
57 | T_period: [150000, 150000, 150000, 150000]
58 | restarts: [150000, 300000, 450000]
59 | restart_weights: [1, 1, 1]
60 | eta_min: !!float 1e-7
61 |
62 | pixel_criterion: cb
63 | pixel_weight: 1.0
64 | val_freq: !!float 5e3
65 |
66 | manual_seed: 0
67 |
68 | #### logger
69 | logger:
70 | print_freq: 100
71 | save_checkpoint_freq: !!float 5e3
72 |
--------------------------------------------------------------------------------
/codes/options/train/train_ESRGAN.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: 003_RRDB_ESRGANx4_DIV2K
3 | use_tb_logger: true
4 | model: srgan
5 | distortion: sr
6 | scale: 4
7 | gpu_ids: [2]
8 |
9 | #### datasets
10 | datasets:
11 | train:
12 | name: DIV2K
13 | mode: LQGT
14 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
15 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
16 |
17 | use_shuffle: true
18 | n_workers: 6 # per GPU
19 | batch_size: 16
20 | GT_size: 128
21 | use_flip: true
22 | use_rot: true
23 | color: RGB
24 | val:
25 | name: val_set14
26 | mode: LQGT
27 | dataroot_GT: ../datasets/val_set14/Set14
28 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
29 |
30 | #### network structures
31 | network_G:
32 | which_model_G: RRDBNet
33 | in_nc: 3
34 | out_nc: 3
35 | nf: 64
36 | nb: 23
37 | network_D:
38 | which_model_D: discriminator_vgg_128
39 | in_nc: 3
40 | nf: 64
41 |
42 | #### path
43 | path:
44 | pretrain_model_G: ../experiments/pretrained_models/RRDB_PSNR_x4.pth
45 | strict_load: true
46 | resume_state: ~
47 |
48 | #### training settings: learning rate scheme, loss
49 | train:
50 | lr_G: !!float 1e-4
51 | weight_decay_G: 0
52 | beta1_G: 0.9
53 | beta2_G: 0.99
54 | lr_D: !!float 1e-4
55 | weight_decay_D: 0
56 | beta1_D: 0.9
57 | beta2_D: 0.99
58 | lr_scheme: MultiStepLR
59 |
60 | niter: 400000
61 | warmup_iter: -1 # no warm up
62 | lr_steps: [50000, 100000, 200000, 300000]
63 | lr_gamma: 0.5
64 |
65 | pixel_criterion: l1
66 | pixel_weight: !!float 1e-2
67 | feature_criterion: l1
68 | feature_weight: 1
69 | gan_type: ragan # gan | ragan
70 | gan_weight: !!float 5e-3
71 |
72 | D_update_ratio: 1
73 | D_init_iters: 0
74 |
75 | manual_seed: 10
76 | val_freq: !!float 5e3
77 |
78 | #### logger
79 | logger:
80 | print_freq: 100
81 | save_checkpoint_freq: !!float 5e3
82 |
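The generator starts from the PSNR-oriented RRDB_PSNR_x4.pth checkpoint, and its objective combines three weighted terms: L1 pixel loss (1e-2), VGG feature loss (1), and a relativistic average GAN (ragan) loss (5e-3). A minimal sketch of that combination with hypothetical tensors; the real wiring lives in models/SRGAN_model.py and models/loss.py:

    import torch
    import torch.nn.functional as F

    pixel_weight, feature_weight, gan_weight = 1e-2, 1.0, 5e-3

    def generator_loss(fake, real, fake_feat, real_feat, d_fake, d_real):
        l_pix = pixel_weight * F.l1_loss(fake, real)
        l_fea = feature_weight * F.l1_loss(fake_feat, real_feat)
        # ragan: each side is scored relative to the other side's mean logit
        l_gan = gan_weight * 0.5 * (
            F.binary_cross_entropy_with_logits(d_real - d_fake.mean(),
                                               torch.zeros_like(d_real)) +
            F.binary_cross_entropy_with_logits(d_fake - d_real.mean(),
                                               torch.ones_like(d_fake)))
        return l_pix + l_fea + l_gan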
--------------------------------------------------------------------------------
/codes/options/train/train_Enhance.yml:
--------------------------------------------------------------------------------
1 | #### general settings
2 | name: csrnet
3 | use_tb_logger: true
4 | model: sr
5 | distortion: sr
6 | scale: 1
7 | gpu_ids: [0]
8 |
9 | #### datasets
10 | datasets:
11 | train:
12 | name: MIT_fivek
13 | mode: LQGT_enhance
14 | dataroot_GT: ../datasets/expert_C_train
15 | dataroot_LQ: ../datasets/raw_input_train
16 |
17 | use_shuffle: true
18 | n_workers: 16
19 | batch_size: 1
20 | color: RGB
21 |
22 | val:
23 | name: MIT_fivek_500
24 | mode: LQGT_enhance
25 | dataroot_GT: ../datasets/expert_C_test
26 | dataroot_LQ: ../datasets/raw_input_test
27 |
28 | #### network structures
29 | network_G:
30 | which_model_G: CSRNet
31 | in_nc: 3
32 | out_nc: 3
33 | base_nf: 64
34 | cond_nf: 32
35 |
36 |
37 | #### path
38 | path:
39 | root:
40 | pretrain_model_G: ~
41 | strict_load: true
42 | resume_state: ~
43 |
44 | #### training settings: learning rate scheme, loss
45 | train:
46 | lr_G: !!float 1e-4
47 | lr_scheme: MultiStepLR # MultiStepLR | CosineAnnealingLR_Restart
48 | beta1: 0.9
49 | beta2: 0.99
50 | niter: 600000
51 | warmup_iter: -1 # no warm up
52 |
53 | lr_steps: [100000, 200000, 300000, 400000, 500000]
54 | lr_gamma: 0.5
55 |
56 | pixel_criterion: l1
57 | pixel_weight: 1.0
58 |
59 | manual_seed: 10
60 | val_freq: !!float 5e3
61 |
62 | #### logger
63 | logger:
64 | print_freq: 100
65 | save_checkpoint_freq: !!float 5e3
66 |
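Unlike the SR configs, there is no `GT_size` crop and `batch_size` is 1, which suggests CSRNet trains on full-resolution image pairs. With `lr_gamma: 0.5` at the five milestones, the learning rate decays from 1e-4 to about 3.1e-6 by the end of the 600k-iteration run; a small sketch of the implied schedule, assuming plain multi-step halving:

    # the decay implied by lr_steps / lr_gamma above
    lr_G, gamma = 1e-4, 0.5
    steps = [100000, 200000, 300000, 400000, 500000]

    def lr_at(it):
        return lr_G * gamma ** sum(it >= s for s in steps)

    print(lr_at(0), lr_at(250000), lr_at(599999))  # 0.0001 2.5e-05 3.125e-06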
--------------------------------------------------------------------------------
/codes/options/train/train_SRGAN.yml:
--------------------------------------------------------------------------------
1 | # Not exactly the same as SRGAN in the original paper
2 | # With 16 Residual blocks w/o BN
3 |
4 | #### general settings
5 | name: 002_SRGANx4_MSRResNetx4Ini_DIV2K
6 | use_tb_logger: true
7 | model: srgan
8 | distortion: sr
9 | scale: 4
10 | gpu_ids: [1]
11 |
12 | #### datasets
13 | datasets:
14 | train:
15 | name: DIV2K
16 | mode: LQGT
17 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
18 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
19 |
20 | use_shuffle: true
21 | n_workers: 6 # per GPU
22 | batch_size: 16
23 | GT_size: 128
24 | use_flip: true
25 | use_rot: true
26 | color: RGB
27 | val:
28 | name: val_set14
29 | mode: LQGT
30 | dataroot_GT: ../datasets/val_set14/Set14
31 | dataroot_LQ: ../datasets/val_set14/Set14_bicLRx4
32 |
33 | #### network structures
34 | network_G:
35 | which_model_G: MSRResNet
36 | in_nc: 3
37 | out_nc: 3
38 | nf: 64
39 | nb: 16
40 | upscale: 4
41 | network_D:
42 | which_model_D: discriminator_vgg_128
43 | in_nc: 3
44 | nf: 64
45 |
46 | #### path
47 | path:
48 | pretrain_model_G: ../experiments/pretrained_models/MSRResNetx4.pth
49 | strict_load: true
50 | resume_state: ~
51 |
52 | #### training settings: learning rate scheme, loss
53 | train:
54 | lr_G: !!float 1e-4
55 | weight_decay_G: 0
56 | beta1_G: 0.9
57 | beta2_G: 0.99
58 | lr_D: !!float 1e-4
59 | weight_decay_D: 0
60 | beta1_D: 0.9
61 | beta2_D: 0.99
62 | lr_scheme: MultiStepLR
63 |
64 | niter: 400000
65 | warmup_iter: -1 # no warm up
66 | lr_steps: [50000, 100000, 200000, 300000]
67 | lr_gamma: 0.5
68 |
69 | pixel_criterion: l1
70 | pixel_weight: !!float 1e-2
71 | feature_criterion: l1
72 | feature_weight: 1
73 | gan_type: gan # gan | ragan
74 | gan_weight: !!float 5e-3
75 |
76 | D_update_ratio: 1
77 | D_init_iters: 0
78 |
79 | manual_seed: 10
80 | val_freq: !!float 5e3
81 |
82 | #### logger
83 | logger:
84 | print_freq: 100
85 | save_checkpoint_freq: !!float 5e3
86 |
--------------------------------------------------------------------------------
/codes/options/train/train_SRResNet.yml:
--------------------------------------------------------------------------------
1 | # Not exactly the same as SRResNet in the original paper
2 | # With 16 Residual blocks w/o BN
3 |
4 | #### general settings
5 | name: 001_MSRResNetx4_scratch_DIV2K
6 | use_tb_logger: true
7 | model: sr
8 | distortion: sr
9 | scale: 4
10 | gpu_ids: [0]
11 |
12 | #### datasets
13 | datasets:
14 | train:
15 | name: DIV2K
16 | mode: LQGT
17 | dataroot_GT: ../datasets/DIV2K/DIV2K800_sub.lmdb
18 | dataroot_LQ: ../datasets/DIV2K/DIV2K800_sub_bicLRx4.lmdb
19 |
20 | use_shuffle: true
21 | n_workers: 6 # per GPU
22 | batch_size: 16
23 | GT_size: 128
24 | use_flip: true
25 | use_rot: true
26 | color: RGB
27 | val:
28 | name: val_set5
29 | mode: LQGT
30 | dataroot_GT: ../datasets/val_set5/Set5
31 | dataroot_LQ: ../datasets/val_set5/Set5_bicLRx4
32 |
33 | #### network structures
34 | network_G:
35 | which_model_G: MSRResNet
36 | in_nc: 3
37 | out_nc: 3
38 | nf: 64
39 | nb: 16
40 | upscale: 4
41 |
42 | #### path
43 | path:
44 | pretrain_model_G: ~
45 | strict_load: true
46 | resume_state: ~
47 |
48 | #### training settings: learning rate scheme, loss
49 | train:
50 | lr_G: !!float 2e-4
51 | lr_scheme: CosineAnnealingLR_Restart
52 | beta1: 0.9
53 | beta2: 0.99
54 | niter: 1000000
55 | warmup_iter: -1 # no warm up
56 | T_period: [250000, 250000, 250000, 250000]
57 | restarts: [250000, 500000, 750000]
58 | restart_weights: [1, 1, 1]
59 | eta_min: !!float 1e-7
60 |
61 | pixel_criterion: l1
62 | pixel_weight: 1.0
63 |
64 | manual_seed: 10
65 | val_freq: !!float 5e3
66 |
67 | #### logger
68 | logger:
69 | print_freq: 100
70 | save_checkpoint_freq: !!float 5e3
71 |
--------------------------------------------------------------------------------
/codes/run_scripts.sh:
--------------------------------------------------------------------------------
1 | # single GPU training (image SR)
2 | python train.py -opt options/train/train_SRResNet.yml
3 | python train.py -opt options/train/train_SRGAN.yml
4 | python train.py -opt options/train/train_ESRGAN.yml
5 |
6 |
7 | # distributed training (video SR)
8 | # 8 GPUs
9 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_woTSA_M.yml --launcher pytorch
10 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 train.py -opt options/train/train_EDVR_M.yml --launcher pytorch
--------------------------------------------------------------------------------
/codes/scripts/transfer_params_MSRResNet.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import sys
3 | import torch
4 | try:
5 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
6 | import models.archs.SRResNet_arch as SRResNet_arch
7 | except ImportError:
8 | pass
9 |
10 | pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth')
11 | crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3)
12 | crt_net = crt_model.state_dict()
13 |
14 | for k, v in crt_net.items():
15 | if k in pretrained_net and 'upconv1' not in k:
16 | crt_net[k] = pretrained_net[k]
17 | print('replace ... ', k)
18 |
19 | # x4 -> x3
20 | crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2
21 | crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2
22 | crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2
23 | crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2
24 | crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2
25 | crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2
26 |
27 | torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth')
28 |
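Why the slices work: the x4 model's `upconv1` feeds a x2 pixel-shuffle and thus has nf * 2^2 = 256 output channels, while the x3 model needs a single x3 pixel-shuffle with nf * 3^2 = 576. The script tiles the 256 pretrained filters twice plus the first 64 again to fill 576, halving the values, presumably so the duplicated filters do not inflate the output magnitude. A quick shape check:

    # shape bookkeeping for the upconv1 transfer above (nf = 64)
    nf = 64
    x4_out = nf * 2 ** 2  # 256: one x2 pixel-shuffle stage in the x4 net
    x3_out = nf * 3 ** 2  # 576: a single x3 pixel-shuffle in the x3 net
    assert x4_out + x4_out + nf == x3_out  # slices 0:256, 256:512, 512:576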
--------------------------------------------------------------------------------
/codes/test.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import logging
3 | import time
4 | import argparse
5 | from collections import OrderedDict
6 | import torch
7 |
8 | import options.options as option
9 | import utils.util as util
10 | from data.util import bgr2ycbcr
11 | from data import create_dataset, create_dataloader
12 | from models import create_model
13 |
14 | #### options
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
17 | opt = option.parse(parser.parse_args().opt, is_train=False)
18 | opt = option.dict_to_nonedict(opt)
19 |
20 | util.mkdirs(
21 | (path for key, path in opt['path'].items()
22 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
23 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
24 | screen=True, tofile=True)
25 | logger = logging.getLogger('base')
26 | logger.info(option.dict2str(opt))
27 |
28 | #### Create test dataset and dataloader
29 | test_loaders = []
30 | for phase, dataset_opt in sorted(opt['datasets'].items()):
31 | test_set = create_dataset(dataset_opt)
32 | test_loader = create_dataloader(test_set, dataset_opt)
33 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
34 | test_loaders.append(test_loader)
35 |
36 | model = create_model(opt)
37 | for test_loader in test_loaders:
38 | test_set_name = test_loader.dataset.opt['name']
39 | logger.info('\nTesting [{:s}]...'.format(test_set_name))
40 | test_start_time = time.time()
41 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
42 | util.mkdir(dataset_dir)
43 |
44 | test_results = OrderedDict()
45 | test_results['psnr'] = []
46 | test_results['ssim'] = []
47 | test_results['psnr_y'] = []
48 | test_results['ssim_y'] = []
49 |
50 | cond = test_loader.dataset.opt['cond']
51 | cond_norm = test_loader.dataset.opt['cond_norm']
52 | for data in test_loader:
53 | need_GT = test_loader.dataset.opt['dataroot_GT'] is not None
54 |
55 | if cond is not None:
56 | for i in range(len(cond)):
57 | cond[i] = cond[i] / cond_norm[i]
58 | data['cond'] = torch.Tensor(cond).view(1, -1)
59 | need_cond = True
60 | elif test_loader.dataset.opt['mode'] in ['LQGT_cond']:
61 | need_cond = True
62 | else:
63 | need_cond = False
64 |
65 | model.feed_data(data, need_GT=need_GT, need_cond=need_cond)
66 |
67 | img_path = data['LQ_path'][0]
68 | img_name = osp.splitext(osp.basename(img_path))[0]
69 |
70 | model.test()
71 | visuals = model.get_current_visuals(need_GT=need_GT)
72 |
73 | sr_img = util.tensor2img(visuals['rlt']) # uint8
74 |
75 | # save images
76 | suffix = opt['suffix']
77 | if suffix:
78 | save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
79 | else:
80 | save_img_path = osp.join(dataset_dir, img_name + '.png')
81 | util.save_img(sr_img, save_img_path)
82 |
83 | # calculate PSNR and SSIM
84 | if need_GT:
85 | gt_img = util.tensor2img(visuals['GT'])
86 | psnr = util.calculate_psnr(sr_img, gt_img)
87 | ssim = util.calculate_ssim(sr_img, gt_img)
88 | test_results['psnr'].append(psnr)
89 | test_results['ssim'].append(ssim)
90 |
91 | if gt_img.shape[2] == 3: # RGB image
92 | sr_img_y = bgr2ycbcr(sr_img / 255., only_y=True)
93 | gt_img_y = bgr2ycbcr(gt_img / 255., only_y=True)
94 |
95 | psnr_y = util.calculate_psnr(sr_img_y * 255, gt_img_y * 255)
96 | ssim_y = util.calculate_ssim(sr_img_y * 255, gt_img_y * 255)
97 | test_results['psnr_y'].append(psnr_y)
98 | test_results['ssim_y'].append(ssim_y)
99 | logger.info(
100 | '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'.
101 | format(img_name, psnr, ssim, psnr_y, ssim_y))
102 | else:
103 | logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
104 | else:
105 | logger.info(img_name)
106 |
107 | if need_GT: # metrics
108 | # Average PSNR/SSIM results
109 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
110 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
111 | logger.info(
112 | '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format(
113 | test_set_name, ave_psnr, ave_ssim))
114 | if test_results['psnr_y'] and test_results['ssim_y']:
115 | ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y'])
116 | ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y'])
117 | logger.info(
118 | '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'.
119 | format(ave_psnr_y, ave_ssim_y))
120 |
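The `cond` branch above lets a test config supply an explicit condition vector, normalized entry-wise by `cond_norm` before being fed to the model as a 1 x N tensor. A minimal sketch with hypothetical values:

    import torch

    cond, cond_norm = [0.5, 1.0, 30.0], [1.0, 1.0, 100.0]
    cond = [c / n for c, n in zip(cond, cond_norm)]
    print(torch.Tensor(cond).view(1, -1))  # tensor([[0.5000, 1.0000, 0.3000]])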
--------------------------------------------------------------------------------
/codes/test_CSRNet.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import logging
3 | import time
4 | import argparse
5 | from collections import OrderedDict
6 | import torch
7 |
8 | import options.options as option
9 | import utils.util as util
10 | from data.util import bgr2ycbcr
11 | from data import create_dataset, create_dataloader
12 | from models import create_model
13 |
14 | #### options
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.')
17 | opt = option.parse(parser.parse_args().opt, is_train=False)
18 | opt = option.dict_to_nonedict(opt)
19 |
20 | util.mkdirs(
21 | (path for key, path in opt['path'].items()
22 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key))
23 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO,
24 | screen=True, tofile=True)
25 | logger = logging.getLogger('base')
26 | logger.info(option.dict2str(opt))
27 |
28 | #### Create test dataset and dataloader
29 | test_loaders = []
30 | for phase, dataset_opt in sorted(opt['datasets'].items()):
31 | test_set = create_dataset(dataset_opt)
32 | test_loader = create_dataloader(test_set, dataset_opt)
33 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
34 | test_loaders.append(test_loader)
35 |
36 | model = create_model(opt)
37 | for test_loader in test_loaders:
38 | test_set_name = test_loader.dataset.opt['name']
39 | logger.info('\nTesting [{:s}]...'.format(test_set_name))
40 | test_start_time = time.time()
41 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name)
42 | util.mkdir(dataset_dir)
43 |
44 | test_results = OrderedDict()
45 | test_results['psnr'] = []
46 | test_results['ssim'] = []
47 |
48 | for data in test_loader:
49 | need_GT = True
50 | model.feed_data(data, need_GT=need_GT, need_cond=False)
51 |
52 | img_path = data['LQ_path'][0]
53 | img_name = osp.splitext(osp.basename(img_path))[0]
54 |
55 | model.test()
56 | visuals = model.get_current_visuals(need_GT=need_GT)
57 |
58 | sr_img = util.tensor2img(visuals['rlt']) # uint8
59 |
60 | # save images
61 | suffix = opt['suffix']
62 | if suffix:
63 | save_img_path = osp.join(dataset_dir, img_name + suffix + '.png')
64 | else:
65 | save_img_path = osp.join(dataset_dir, img_name + '.png')
66 | util.save_img(sr_img, save_img_path)
67 |
68 | # calculate PSNR and SSIM
69 | if need_GT:
70 | gt_img = util.tensor2img(visuals['GT'])
71 | psnr = util.calculate_psnr(sr_img, gt_img)
72 | ssim = util.calculate_ssim(sr_img, gt_img)
73 | test_results['psnr'].append(psnr)
74 | test_results['ssim'].append(ssim)
75 | logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim))
76 | else:
77 | logger.info(img_name)
78 |
79 | if need_GT: # metrics
80 | # Average PSNR/SSIM results
81 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
82 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
83 | logger.info(
84 | '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format(
85 | test_set_name, ave_psnr, ave_ssim))
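test_CSRNet.py is a trimmed copy of test.py that always assumes ground truth is available and never feeds a condition vector; run it as `python test_CSRNet.py -opt options/test/test_Enhance.yml`. The `suffix` option from the config only changes the saved filename, as in this sketch with hypothetical names:

    import os.path as osp

    dataset_dir, img_name, suffix = 'results/MIT_fivek_500', 'a0001', '_csrnet'
    save_img_path = osp.join(dataset_dir, img_name + (suffix or '') + '.png')
    print(save_img_path)  # results/MIT_fivek_500/a0001_csrnet.png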
--------------------------------------------------------------------------------
/codes/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/codes/utils/__init__.py
--------------------------------------------------------------------------------
/codes/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | import torch.nn.functional as F
6 | from datetime import datetime
7 | import random
8 | import logging
9 | from collections import OrderedDict
10 | import numpy as np
11 | import cv2
12 | import torch
13 | from torchvision.utils import make_grid
14 | from shutil import get_terminal_size
15 |
16 | import yaml
17 | try:
18 | from yaml import CLoader as Loader, CDumper as Dumper
19 | except ImportError:
20 | from yaml import Loader, Dumper
21 |
22 |
23 | def OrderedYaml():
24 | '''yaml orderedDict support'''
25 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
26 |
27 | def dict_representer(dumper, data):
28 | return dumper.represent_dict(data.items())
29 |
30 | def dict_constructor(loader, node):
31 | return OrderedDict(loader.construct_pairs(node))
32 |
33 | Dumper.add_representer(OrderedDict, dict_representer)
34 | Loader.add_constructor(_mapping_tag, dict_constructor)
35 | return Loader, Dumper
36 |
37 |
38 | ####################
39 | # miscellaneous
40 | ####################
41 |
42 |
43 | def get_timestamp():
44 | return datetime.now().strftime('%y%m%d-%H%M%S')
45 |
46 |
47 | def mkdir(path):
48 | if not os.path.exists(path):
49 | os.makedirs(path)
50 |
51 |
52 | def mkdirs(paths):
53 | if isinstance(paths, str):
54 | mkdir(paths)
55 | else:
56 | for path in paths:
57 | mkdir(path)
58 |
59 |
60 | def mkdir_and_rename(path):
61 | if os.path.exists(path):
62 | new_name = path + '_archived_' + get_timestamp()
63 | print('Path already exists. Renaming it to [{:s}]'.format(new_name))
64 | logger = logging.getLogger('base')
65 | logger.info('Path already exists. Renaming it to [{:s}]'.format(new_name))
66 | os.rename(path, new_name)
67 | os.makedirs(path)
68 |
69 |
70 | def set_random_seed(seed):
71 | random.seed(seed)
72 | np.random.seed(seed)
73 | torch.manual_seed(seed)
74 | torch.cuda.manual_seed_all(seed)
75 |
76 |
77 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
78 | '''set up logger'''
79 | lg = logging.getLogger(logger_name)
80 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
81 | datefmt='%y-%m-%d %H:%M:%S')
82 | lg.setLevel(level)
83 | if tofile:
84 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
85 | fh = logging.FileHandler(log_file, mode='w')
86 | fh.setFormatter(formatter)
87 | lg.addHandler(fh)
88 | if screen:
89 | sh = logging.StreamHandler()
90 | sh.setFormatter(formatter)
91 | lg.addHandler(sh)
92 |
93 |
94 | ####################
95 | # image convert
96 | ####################
97 | def crop_border(img_list, crop_border):
98 | """Crop borders of images
99 | Args:
100 | img_list (list [Numpy]): HWC
101 | crop_border (int): crop border for each end of height and width
102 |
103 | Returns:
104 | (list [Numpy]): cropped image list
105 | """
106 | if crop_border == 0:
107 | return img_list
108 | else:
109 | return [v[crop_border:-crop_border, crop_border:-crop_border] for v in img_list]
110 |
111 |
112 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
113 | '''
114 | Converts a torch Tensor into an image Numpy array
115 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
116 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
117 | '''
118 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
119 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
120 | n_dim = tensor.dim()
121 | if n_dim == 4:
122 | n_img = len(tensor)
123 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
124 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
125 | elif n_dim == 3:
126 | img_np = tensor.numpy()
127 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
128 | elif n_dim == 2:
129 | img_np = tensor.numpy()
130 | else:
131 | raise TypeError(
132 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
133 | if out_type == np.uint8:
134 | img_np = (img_np * 255.0).round()
135 | # Important. Unlike MATLAB, numpy.uint8() will NOT round by default.
136 | return img_np.astype(out_type)
137 |
138 |
139 | def save_img(img, img_path, mode='RGB'):
140 | cv2.imwrite(img_path, img)  # img is expected in BGR order, as produced by tensor2img
141 |
142 |
143 | def DUF_downsample(x, scale=4):
144 | """Downsamping with Gaussian kernel used in the DUF official code
145 |
146 | Args:
147 | x (Tensor, [B, T, C, H, W]): frames to be downsampled.
148 | scale (int): downsampling factor: 2 | 3 | 4.
149 | """
150 |
151 | assert scale in [2, 3, 4], 'Scale [{}] is not supported'.format(scale)
152 |
153 | def gkern(kernlen=13, nsig=1.6):
154 | import scipy.ndimage.filters as fi
155 | inp = np.zeros((kernlen, kernlen))
156 | # set element at the middle to one, a dirac delta
157 | inp[kernlen // 2, kernlen // 2] = 1
158 | # gaussian-smooth the dirac, resulting in a gaussian filter mask
159 | return fi.gaussian_filter(inp, nsig)
160 |
161 | B, T, C, H, W = x.size()
162 | x = x.view(-1, 1, H, W)
163 | pad_w, pad_h = 6 + scale * 2, 6 + scale * 2 # 6 is the pad of the gaussian filter
164 | r_h, r_w = 0, 0
165 | if scale == 3:
166 | r_h = 3 - (H % 3)
167 | r_w = 3 - (W % 3)
168 | x = F.pad(x, [pad_w, pad_w + r_w, pad_h, pad_h + r_h], 'reflect')
169 |
170 | gaussian_filter = torch.from_numpy(gkern(13, 0.4 * scale)).type_as(x).unsqueeze(0).unsqueeze(0)
171 | x = F.conv2d(x, gaussian_filter, stride=scale)
172 | x = x[:, :, 2:-2, 2:-2]
173 | x = x.view(B, T, C, x.size(2), x.size(3))
174 | return x
175 |
176 |
177 | def single_forward(model, inp):
178 | """PyTorch model forward (single test), it is just a simple warpper
179 | Args:
180 | model (PyTorch model)
181 | inp (Tensor): inputs defined by the model
182 |
183 | Returns:
184 | output (Tensor): outputs of the model. float, in CPU
185 | """
186 | with torch.no_grad():
187 | model_output = model(inp)
188 | if isinstance(model_output, (list, tuple)):
189 | output = model_output[0]
190 | else:
191 | output = model_output
192 | output = output.data.float().cpu()
193 | return output
194 |
195 |
196 | def flipx4_forward(model, inp):
197 | """Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W
198 | Args:
199 | model (PyTorch model)
200 | inp (Tensor): inputs defined by the model
201 |
202 | Returns:
203 | output (Tensor): outputs of the model. float, in CPU
204 | """
205 | # normal
206 | output_f = single_forward(model, inp)
207 |
208 | # flip W
209 | output = single_forward(model, torch.flip(inp, (-1, )))
210 | output_f = output_f + torch.flip(output, (-1, ))
211 | # flip H
212 | output = single_forward(model, torch.flip(inp, (-2, )))
213 | output_f = output_f + torch.flip(output, (-2, ))
214 | # flip both H and W
215 | output = single_forward(model, torch.flip(inp, (-2, -1)))
216 | output_f = output_f + torch.flip(output, (-2, -1))
217 |
218 | return output_f / 4
219 |
220 |
221 | ####################
222 | # metric
223 | ####################
224 |
225 |
226 | def calculate_psnr(img1, img2):
227 | # img1 and img2 have range [0, 255]
228 | img1 = img1.astype(np.float64)
229 | img2 = img2.astype(np.float64)
230 | mse = np.mean((img1 - img2)**2)
231 | if mse == 0:
232 | return float('inf')
233 | return 20 * math.log10(255.0 / math.sqrt(mse))
234 |
235 |
236 | def ssim(img1, img2):
237 | C1 = (0.01 * 255)**2
238 | C2 = (0.03 * 255)**2
239 |
240 | img1 = img1.astype(np.float64)
241 | img2 = img2.astype(np.float64)
242 | kernel = cv2.getGaussianKernel(11, 1.5)
243 | window = np.outer(kernel, kernel.transpose())
244 |
245 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
246 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
247 | mu1_sq = mu1**2
248 | mu2_sq = mu2**2
249 | mu1_mu2 = mu1 * mu2
250 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
251 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
252 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
253 |
254 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
255 | (sigma1_sq + sigma2_sq + C2))
256 | return ssim_map.mean()
257 |
258 |
259 | def calculate_ssim(img1, img2):
260 | '''calculate SSIM
261 | the same outputs as MATLAB's
262 | img1, img2: [0, 255]
263 | '''
264 | if not img1.shape == img2.shape:
265 | raise ValueError('Input images must have the same dimensions.')
266 | if img1.ndim == 2:
267 | return ssim(img1, img2)
268 | elif img1.ndim == 3:
269 | if img1.shape[2] == 3:
270 | ssims = []
271 | for i in range(3):
272 | ssims.append(ssim(img1[..., i], img2[..., i]))
273 | return np.array(ssims).mean()
274 | elif img1.shape[2] == 1:
275 | return ssim(np.squeeze(img1), np.squeeze(img2))
276 | else:
277 | raise ValueError('Wrong input image dimensions.')
278 |
279 |
280 | class ProgressBar(object):
281 | '''A progress bar which can print the progress
282 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
283 | '''
284 |
285 | def __init__(self, task_num=0, bar_width=50, start=True):
286 | self.task_num = task_num
287 | max_bar_width = self._get_max_bar_width()
288 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
289 | self.completed = 0
290 | if start:
291 | self.start()
292 |
293 | def _get_max_bar_width(self):
294 | terminal_width, _ = get_terminal_size()
295 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
296 | if max_bar_width < 10:
297 | print('terminal width is too small ({}), please consider widening the terminal '
298 | 'for a better progress bar visualization'.format(terminal_width))
299 | max_bar_width = 10
300 | return max_bar_width
301 |
302 | def start(self):
303 | if self.task_num > 0:
304 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
305 | ' ' * self.bar_width, self.task_num, 'Start...'))
306 | else:
307 | sys.stdout.write('completed: 0, elapsed: 0s')
308 | sys.stdout.flush()
309 | self.start_time = time.time()
310 |
311 | def update(self, msg='In progress...'):
312 | self.completed += 1
313 | elapsed = time.time() - self.start_time
314 | fps = self.completed / elapsed
315 | if self.task_num > 0:
316 | percentage = self.completed / float(self.task_num)
317 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
318 | mark_width = int(self.bar_width * percentage)
319 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
320 | sys.stdout.write('\033[2F') # cursor up 2 lines
321 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
322 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
323 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
324 | else:
325 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
326 | self.completed, int(elapsed + 0.5), fps))
327 | sys.stdout.flush()
328 |
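The metric helpers above operate on HWC arrays in the [0, 255] range, as produced by tensor2img. A short usage sketch (run from the codes/ directory):

    import numpy as np
    from utils.util import calculate_psnr, calculate_ssim

    a = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    b = a.copy()
    print(calculate_psnr(a, b))  # inf: identical images
    print(calculate_ssim(a, b))  # 1.0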
--------------------------------------------------------------------------------
/experiments/pretrain_models/csrnet.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/experiments/pretrain_models/csrnet.pth
--------------------------------------------------------------------------------
/figures/csrnet_fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/figures/csrnet_fig1.png
--------------------------------------------------------------------------------
/figures/csrnet_fig6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hejingwenhejingwen/CSRNet/cfc57bde0bd1e8fcf567a52de488122d7ee5bd6b/figures/csrnet_fig6.png
--------------------------------------------------------------------------------