├── Deblur ├── README.md ├── cal.py ├── config.py ├── data_RGB.py ├── dataset_RGB.py ├── eval.py ├── losses.py ├── test.py ├── train.py ├── trmash.yml └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── arch_utils.cpython-38.pyc │ ├── dataset_utils.cpython-38.pyc │ ├── dir_utils.cpython-38.pyc │ ├── dist_util.cpython-38.pyc │ ├── image_utils.cpython-38.pyc │ ├── logger.cpython-38.pyc │ └── model_utils.cpython-38.pyc │ ├── arch_utils.py │ ├── dataset_utils.py │ ├── dir_utils.py │ ├── dist_util.py │ ├── image_utils.py │ ├── logger.py │ └── model_utils.py ├── Derain ├── README.md ├── cal.py ├── config.py ├── data_RGB.py ├── dataset_RGB.py ├── eval.py ├── evaluate_PSNR_SSIM.py ├── losses.py ├── test.py ├── train.py ├── trmash.yml └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── arch_utils.cpython-38.pyc │ ├── dataset_utils.cpython-38.pyc │ ├── dir_utils.cpython-38.pyc │ ├── dist_util.cpython-38.pyc │ ├── image_utils.cpython-38.pyc │ ├── logger.cpython-38.pyc │ └── model_utils.cpython-38.pyc │ ├── arch_utils.py │ ├── dataset_utils.py │ ├── dir_utils.py │ ├── dist_util.py │ ├── image_utils.py │ ├── logger.py │ └── model_utils.py ├── LICENSE.md ├── MHNet.py ├── README.md ├── fig ├── blur.jpg ├── dau.png ├── deblur.png ├── derain.png ├── fir_h.jpg ├── muti-net.png ├── network.jpg ├── network.png ├── rain.jpg ├── sec_h.jpg └── three_con.png └── pytorch-gradual-warmup-lr ├── build └── lib │ └── warmup_scheduler │ ├── __init__.py │ ├── run.py │ └── scheduler.py ├── dist └── warmup_scheduler-0.3-py3.8.egg ├── setup.py ├── warmup_scheduler.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt └── top_level.txt └── warmup_scheduler ├── __init__.py ├── run.py └── scheduler.py /Deblur/README.md: -------------------------------------------------------------------------------- 1 | ## Training 2 | - Download datasets from the google drive links and place them in Dataset. 
Your directory tree should look like this: 3 | 4 | `GoPro`
5 |   `├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing)
6 |   `└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) 7 | 8 | `HIDE`
9 |   `└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) 10 | 11 | 12 | - Train the model with default arguments by running 13 | 14 | ``` 15 | python train.py 16 | ``` 17 | 18 | ## Evaluation 19 | 20 | ### Download the [model](https://drive.google.com/drive/folders/1qBC3mUoLoCuMyuiseYoZWzvyvImG98TW?usp=drive_link) and place it in ./pre-trained/ 21 | 22 | #### Testing on GoPro dataset 23 | - Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and place them in `./Datasets/GoPro/test/` 24 | - Run 25 | ``` 26 | python test.py --dataset GoPro 27 | ``` 28 | 29 | #### Testing on HIDE dataset 30 | - Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and place them in `./Datasets/HIDE/test/` 31 | - Run 32 | ``` 33 | python test.py --dataset HIDE 34 | ``` 35 | 36 | 37 | 38 | 39 | #### To reproduce PSNR,SSIM scores of the paper on GoPro and HIDE datasets, run 40 | 41 | ``` 42 | python eval.py 43 | ``` 44 | -------------------------------------------------------------------------------- /Deblur/cal.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | import skimage.metrics 6 | import torch 7 | import math 8 | 9 | def calculate_psnr(img1, img2, crop_border, test_y_channel=True): 10 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 11 | 12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 13 | Args: 14 | img1 (ndarray): Images with range [0, 255]. 15 | img2 (ndarray): Images with range [0, 255]. 16 | crop_border (int): Cropped pixels in each edge of an image. These 17 | pixels are not involved in the PSNR calculation. 18 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 19 | Returns: 20 | float: psnr result. 
21 | """ 22 | assert img1.shape == img2.shape, ( 23 | f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 24 | if type(img1) == torch.Tensor: 25 | if len(img1.shape) == 4: 26 | img1 = img1.squeeze(0) 27 | img1 = img1.detach().cpu().numpy().transpose(1,2,0) 28 | if type(img2) == torch.Tensor: 29 | if len(img2.shape) == 4: 30 | img2 = img2.squeeze(0) 31 | img2 = img2.detach().cpu().numpy().transpose(1,2,0) 32 | img1 = img1.astype(np.float64) 33 | img2 = img2.astype(np.float64) 34 | 35 | if crop_border != 0: 36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 38 | 39 | if test_y_channel: 40 | img1 = to_y_channel(img1) 41 | img2 = to_y_channel(img2) 42 | 43 | imdff = np.float32(img1) - np.float32(img2) 44 | rmse = np.sqrt(np.mean(imdff**2)) 45 | ps = 20*np.log10(255/rmse) 46 | return ps 47 | 48 | 49 | def _convert_input_type_range(img): 50 | """Convert the type and range of the input image. 51 | 52 | It converts the input image to np.float32 type and range of [0, 1]. 53 | It is mainly used for pre-processing the input image in colorspace 54 | convertion functions such as rgb2ycbcr and ycbcr2rgb. 55 | Args: 56 | img (ndarray): The input image. It accepts: 57 | 1. np.uint8 type with range [0, 255]; 58 | 2. np.float32 type with range [0, 1]. 59 | Returns: 60 | (ndarray): The converted image with type of np.float32 and range of 61 | [0, 1]. 62 | """ 63 | img_type = img.dtype 64 | img = img.astype(np.float32) 65 | if img_type == np.float32: 66 | pass 67 | elif img_type == np.uint8: 68 | img /= 255. 69 | else: 70 | raise TypeError('The img type should be np.float32 or np.uint8, ' 71 | f'but got {img_type}') 72 | return img 73 | 74 | 75 | def _convert_output_type_range(img, dst_type): 76 | """Convert the type and range of the image according to dst_type. 77 | 78 | It converts the image to desired type and range. 
If `dst_type` is np.uint8, 79 | images will be converted to np.uint8 type with range [0, 255]. If 80 | `dst_type` is np.float32, it converts the image to np.float32 type with 81 | range [0, 1]. 82 | It is mainly used for post-processing images in colorspace convertion 83 | functions such as rgb2ycbcr and ycbcr2rgb. 84 | Args: 85 | img (ndarray): The image to be converted with np.float32 type and 86 | range [0, 255]. 87 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it 88 | converts the image to np.uint8 type with range [0, 255]. If 89 | dst_type is np.float32, it converts the image to np.float32 type 90 | with range [0, 1]. 91 | Returns: 92 | (ndarray): The converted image with desired type and range. 93 | """ 94 | if dst_type not in (np.uint8, np.float32): 95 | raise TypeError('The dst_type should be np.float32 or np.uint8, ' 96 | f'but got {dst_type}') 97 | if dst_type == np.uint8: 98 | img = img.round() 99 | else: 100 | img /= 255. 101 | 102 | return img.astype(dst_type) 103 | 104 | 105 | def rgb2ycbcr(img, y_only=True): 106 | """Convert a RGB image to YCbCr image. 107 | 108 | This function produces the same results as Matlab's `rgb2ycbcr` function. 109 | It implements the ITU-R BT.601 conversion for standard-definition 110 | television. See more details in 111 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 112 | It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. 113 | In OpenCV, it implements a JPEG conversion. See more details in 114 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 115 | 116 | Args: 117 | img (ndarray): The input image. It accepts: 118 | 1. np.uint8 type with range [0, 255]; 119 | 2. np.float32 type with range [0, 1]. 120 | y_only (bool): Whether to only return Y channel. Default: False. 121 | Returns: 122 | ndarray: The converted YCbCr image. The output image has the same type 123 | and range as input image. 
124 | """ 125 | img_type = img.dtype 126 | img = _convert_input_type_range(img) 127 | if y_only: 128 | out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 129 | else: 130 | out_img = np.matmul(img, 131 | [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 132 | [24.966, 112.0, -18.214]]) + [16, 128, 128] 133 | out_img = _convert_output_type_range(out_img, img_type) 134 | return out_img 135 | 136 | 137 | def to_y_channel(img): 138 | """Change to Y channel of YCbCr. 139 | 140 | Args: 141 | img (ndarray): Images with range [0, 255]. 142 | Returns: 143 | (ndarray): Images with range [0, 255] (float type) without round. 144 | """ 145 | img = img.astype(np.float32) / 255. 146 | if img.ndim == 3 and img.shape[2] == 3: 147 | img = rgb2ycbcr(img, y_only=True) 148 | img = img[..., None] 149 | return img * 255. 150 | 151 | def _ssim(img1, img2): 152 | """Calculate SSIM (structural similarity) for one channel images. 153 | 154 | It is called by func:`calculate_ssim`. 155 | 156 | Args: 157 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 158 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 159 | 160 | Returns: 161 | float: ssim result. 
162 | """ 163 | 164 | C1 = (0.01 * 255)**2 165 | C2 = (0.03 * 255)**2 166 | 167 | img1 = img1.astype(np.float64) 168 | img2 = img2.astype(np.float64) 169 | kernel = cv2.getGaussianKernel(11, 1.5) 170 | window = np.outer(kernel, kernel.transpose()) 171 | 172 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 173 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 174 | mu1_sq = mu1**2 175 | mu2_sq = mu2**2 176 | mu1_mu2 = mu1 * mu2 177 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 178 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 179 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 180 | 181 | ssim_map = ((2 * mu1_mu2 + C1) * 182 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 183 | (sigma1_sq + sigma2_sq + C2)) 184 | return ssim_map.mean() 185 | 186 | def prepare_for_ssim(img, k): 187 | import torch 188 | with torch.no_grad(): 189 | img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() 190 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect') 191 | conv.weight.requires_grad = False 192 | conv.weight[:, :, :, :] = 1. / (k * k) 193 | 194 | img = conv(img) 195 | 196 | img = img.squeeze(0).squeeze(0) 197 | img = img[0::k, 0::k] 198 | return img.detach().cpu().numpy() 199 | 200 | def prepare_for_ssim_rgb(img, k): 201 | import torch 202 | with torch.no_grad(): 203 | img = torch.from_numpy(img).float() #HxWx3 204 | 205 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect') 206 | conv.weight.requires_grad = False 207 | conv.weight[:, :, :, :] = 1. 
/ (k * k) 208 | 209 | new_img = [] 210 | 211 | for i in range(3): 212 | new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k]) 213 | 214 | return torch.stack(new_img, dim=2).detach().cpu().numpy() 215 | 216 | def _3d_gaussian_calculator(img, conv3d): 217 | out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) 218 | return out 219 | 220 | def _generate_3d_gaussian_kernel(): 221 | kernel = cv2.getGaussianKernel(11, 1.5) 222 | window = np.outer(kernel, kernel.transpose()) 223 | kernel_3 = cv2.getGaussianKernel(11, 1.5) 224 | kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) 225 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate') 226 | conv3d.weight.requires_grad = False 227 | conv3d.weight[0, 0, :, :, :] = kernel 228 | return conv3d 229 | 230 | def _ssim_3d(img1, img2, max_value): 231 | assert len(img1.shape) == 3 and len(img2.shape) == 3 232 | """Calculate SSIM (structural similarity) for one channel images. 233 | 234 | It is called by func:`calculate_ssim`. 235 | 236 | Args: 237 | img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. 238 | img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. 239 | 240 | Returns: 241 | float: ssim result. 
242 | """ 243 | C1 = (0.01 * max_value) ** 2 244 | C2 = (0.03 * max_value) ** 2 245 | img1 = img1.astype(np.float64) 246 | img2 = img2.astype(np.float64) 247 | 248 | kernel = _generate_3d_gaussian_kernel().cuda() 249 | 250 | img1 = torch.tensor(img1).float().cuda() 251 | img2 = torch.tensor(img2).float().cuda() 252 | 253 | 254 | mu1 = _3d_gaussian_calculator(img1, kernel) 255 | mu2 = _3d_gaussian_calculator(img2, kernel) 256 | 257 | mu1_sq = mu1 ** 2 258 | mu2_sq = mu2 ** 2 259 | mu1_mu2 = mu1 * mu2 260 | sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq 261 | sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq 262 | sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2 263 | 264 | ssim_map = ((2 * mu1_mu2 + C1) * 265 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 266 | (sigma1_sq + sigma2_sq + C2)) 267 | return float(ssim_map.mean()) 268 | 269 | def _ssim_cly(img1, img2): 270 | assert len(img1.shape) == 2 and len(img2.shape) == 2 271 | """Calculate SSIM (structural similarity) for one channel images. 272 | 273 | It is called by func:`calculate_ssim`. 274 | 275 | Args: 276 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 277 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 278 | 279 | Returns: 280 | float: ssim result. 
281 | """ 282 | 283 | C1 = (0.01 * 255)**2 284 | C2 = (0.03 * 255)**2 285 | img1 = img1.astype(np.float64) 286 | img2 = img2.astype(np.float64) 287 | 288 | kernel = cv2.getGaussianKernel(11, 1.5) 289 | # print(kernel) 290 | window = np.outer(kernel, kernel.transpose()) 291 | 292 | bt = cv2.BORDER_REPLICATE 293 | 294 | mu1 = cv2.filter2D(img1, -1, window, borderType=bt) 295 | mu2 = cv2.filter2D(img2, -1, window,borderType=bt) 296 | 297 | mu1_sq = mu1**2 298 | mu2_sq = mu2**2 299 | mu1_mu2 = mu1 * mu2 300 | sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq 301 | sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq 302 | sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2 303 | 304 | ssim_map = ((2 * mu1_mu2 + C1) * 305 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 306 | (sigma1_sq + sigma2_sq + C2)) 307 | return ssim_map.mean() 308 | def reorder_image(img, input_order='HWC'): 309 | """Reorder images to 'HWC' order. 310 | 311 | If the input_order is (h, w), return (h, w, 1); 312 | If the input_order is (c, h, w), return (h, w, c); 313 | If the input_order is (h, w, c), return as it is. 314 | 315 | Args: 316 | img (ndarray): Input image. 317 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 318 | If the input image shape is (h, w), input_order will not have 319 | effects. Default: 'HWC'. 320 | 321 | Returns: 322 | ndarray: reordered image. 323 | """ 324 | 325 | if input_order not in ['HWC', 'CHW']: 326 | raise ValueError( 327 | f'Wrong input_order {input_order}. Supported input_orders are ' 328 | "'HWC' and 'CHW'") 329 | if len(img.shape) == 2: 330 | img = img[..., None] 331 | if input_order == 'CHW': 332 | img = img.transpose(1, 2, 0) 333 | return img 334 | 335 | 336 | def calculate_ssim(img1, 337 | img2, 338 | crop_border, 339 | input_order='HWC', 340 | test_y_channel=True): 341 | """Calculate SSIM (structural similarity). 
342 | 343 | Ref: 344 | Image quality assessment: From error visibility to structural similarity 345 | 346 | The results are the same as that of the official released MATLAB code in 347 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 348 | 349 | For three-channel images, SSIM is calculated for each channel and then 350 | averaged. 351 | 352 | Args: 353 | img1 (ndarray): Images with range [0, 255]. 354 | img2 (ndarray): Images with range [0, 255]. 355 | crop_border (int): Cropped pixels in each edge of an image. These 356 | pixels are not involved in the SSIM calculation. 357 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 358 | Default: 'HWC'. 359 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 360 | 361 | Returns: 362 | float: ssim result. 363 | """ 364 | 365 | assert img1.shape == img2.shape, ( 366 | f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 367 | if input_order not in ['HWC', 'CHW']: 368 | raise ValueError( 369 | f'Wrong input_order {input_order}. Supported input_orders are ' 370 | '"HWC" and "CHW"') 371 | 372 | if type(img1) == torch.Tensor: 373 | if len(img1.shape) == 4: 374 | img1 = img1.squeeze(0) 375 | img1 = img1.detach().cpu().numpy().transpose(1,2,0) 376 | if type(img2) == torch.Tensor: 377 | if len(img2.shape) == 4: 378 | img2 = img2.squeeze(0) 379 | img2 = img2.detach().cpu().numpy().transpose(1,2,0) 380 | 381 | img1 = reorder_image(img1, input_order=input_order) 382 | img2 = reorder_image(img2, input_order=input_order) 383 | 384 | img1 = img1.astype(np.float64) 385 | img2 = img2.astype(np.float64) 386 | 387 | if crop_border != 0: 388 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 389 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 
390 | 391 | if test_y_channel: 392 | img1 = to_y_channel(img1) 393 | img2 = to_y_channel(img2) 394 | return _ssim_cly(img1[..., 0], img2[..., 0]) 395 | 396 | 397 | ssims = [] 398 | # ssims_before = [] 399 | 400 | # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True) 401 | # print('.._skimage', 402 | # skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)) 403 | max_value = 1 if img1.max() <= 1 else 255 404 | with torch.no_grad(): 405 | final_ssim = _ssim_3d(img1, img2, max_value) 406 | ssims.append(final_ssim) 407 | 408 | # for i in range(img1.shape[2]): 409 | # ssims_before.append(_ssim(img1, img2)) 410 | 411 | # print('..ssim mean , new {:.4f} and before {:.4f} .... skimage before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean(), skimage_before)) 412 | # ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False)) 413 | 414 | return np.array(ssims).mean() -------------------------------------------------------------------------------- /Deblur/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | r"""This module provides package-wide configuration management.""" 6 | from typing import Any, List 7 | 8 | from yacs.config import CfgNode as CN 9 | 10 | 11 | class Config(object): 12 | r""" 13 | A collection of all the required configuration parameters. This class is a nested dict-like 14 | structure, with nested keys accessible as attributes. It contains sensible default values for 15 | all the parameters, which may be overriden by (first) through a YAML file and (second) through 16 | a list of attributes and values. 
17 | 18 | Extended Summary 19 | ---------------- 20 | This class definition contains default values corresponding to ``joint_training`` phase, as it 21 | is the final training phase and uses almost all the configuration parameters. Modification of 22 | any parameter after instantiating this class is not possible, so you must override required 23 | parameter values in either through ``config_yaml`` file or ``config_override`` list. 24 | 25 | Parameters 26 | ---------- 27 | config_yaml: str 28 | Path to a YAML file containing configuration parameters to override. 29 | config_override: List[Any], optional (default= []) 30 | A list of sequential attributes and values of parameters to override. This happens after 31 | overriding from YAML file. 32 | 33 | Examples 34 | -------- 35 | Let a YAML file named "config.yaml" specify these parameters to override:: 36 | 37 | ALPHA: 1000.0 38 | BETA: 0.5 39 | 40 | >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7]) 41 | >>> _C.ALPHA # default: 100.0 42 | 1000.0 43 | >>> _C.BATCH_SIZE # default: 256 44 | 2048 45 | >>> _C.BETA # default: 0.1 46 | 0.7 47 | 48 | Attributes 49 | ---------- 50 | """ 51 | 52 | def __init__(self, config_yaml: str, config_override: List[Any] = []): 53 | 54 | self._C = CN() 55 | self._C.GPU = [0] 56 | self._C.VERBOSE = False 57 | 58 | self._C.MODEL = CN() 59 | self._C.MODEL.MODE = 'global' 60 | self._C.MODEL.SESSION = 'ps128_bs1' 61 | 62 | self._C.OPTIM = CN() 63 | self._C.OPTIM.BATCH_SIZE = 1 64 | self._C.OPTIM.NUM_EPOCHS = 100 65 | self._C.OPTIM.NEPOCH_DECAY = [100] 66 | self._C.OPTIM.LR_INITIAL = 0.0002 67 | self._C.OPTIM.LR_MIN = 0.0002 68 | self._C.OPTIM.BETA1 = 0.5 69 | 70 | self._C.TRAINING = CN() 71 | self._C.TRAINING.VAL_AFTER_EVERY = 3 72 | self._C.TRAINING.RESUME = False 73 | self._C.TRAINING.SAVE_IMAGES = False 74 | self._C.TRAINING.TRAIN_DIR = 'images_dir/train' 75 | self._C.TRAINING.VAL_DIR = 'images_dir/val' 76 | self._C.TRAINING.SAVE_DIR = 'checkpoints' 77 | 
self._C.TRAINING.TRAIN_PS = 64 78 | self._C.TRAINING.VAL_PS = 64 79 | 80 | # Override parameter values from YAML file first, then from override list. 81 | self._C.merge_from_file(config_yaml) 82 | self._C.merge_from_list(config_override) 83 | 84 | # Make an instantiated object of this class immutable. 85 | self._C.freeze() 86 | 87 | def dump(self, file_path: str): 88 | r"""Save config at the specified file path. 89 | 90 | Parameters 91 | ---------- 92 | file_path: str 93 | (YAML) path to save config at. 94 | """ 95 | self._C.dump(stream=open(file_path, "w")) 96 | 97 | def __getattr__(self, attr: str): 98 | return self._C.__getattr__(attr) 99 | 100 | def __repr__(self): 101 | return self._C.__repr__() 102 | -------------------------------------------------------------------------------- /Deblur/data_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest, DataLoaderTest2 3 | 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options) 7 | 8 | def get_validation_data(rgb_dir, img_options): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, img_options) 11 | 12 | def get_test_data(rgb_dir, img_options): 13 | assert os.path.exists(rgb_dir) 14 | return DataLoaderTest(rgb_dir, img_options) 15 | 16 | def get_test_data2(rgb_dir, img_options): 17 | assert os.path.exists(rgb_dir) 18 | return DataLoaderTest2(rgb_dir, img_options) 19 | 20 | 21 | -------------------------------------------------------------------------------- /Deblur/dataset_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torch 5 | from PIL import Image 6 | import torchvision.transforms.functional as TF 7 | from pdb import set_trace as stx 8 | import random 9 | import utils 
10 | 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 14 | 15 | 16 | class DataLoaderTrain(Dataset): 17 | def __init__(self, rgb_dir, img_options=None): 18 | super(DataLoaderTrain, self).__init__() 19 | 20 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input'))) 21 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target'))) 22 | 23 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] 24 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)] 25 | 26 | self.img_options = img_options 27 | self.sizex = len(self.tar_filenames) # get the size of target 28 | 29 | self.ps = self.img_options['patch_size'] 30 | 31 | def __len__(self): 32 | return self.sizex 33 | 34 | def __getitem__(self, index): 35 | index_ = index % self.sizex 36 | ps = self.ps 37 | 38 | inp_path = self.inp_filenames[index_] 39 | tar_path = self.tar_filenames[index_] 40 | 41 | inp_img = Image.open(inp_path) 42 | tar_img = Image.open(tar_path) 43 | 44 | w, h = tar_img.size 45 | padw = ps - w if w < ps else 0 46 | padh = ps - h if h < ps else 0 47 | 48 | # Reflect Pad in case image is smaller than patch_size 49 | if padw != 0 or padh != 0: 50 | inp_img = TF.pad(inp_img, (0, 0, padw, padh), padding_mode='reflect') 51 | tar_img = TF.pad(tar_img, (0, 0, padw, padh), padding_mode='reflect') 52 | 53 | 54 | inp_img = TF.to_tensor(inp_img) 55 | tar_img = TF.to_tensor(tar_img) 56 | 57 | hh, ww = tar_img.shape[1], tar_img.shape[2] 58 | 59 | rr = random.randint(0, hh - ps) 60 | cc = random.randint(0, ww - ps) 61 | aug = random.randint(0, 8) 62 | 63 | # Crop patch 64 | inp_img = inp_img[:, rr:rr + ps, cc:cc + ps] 65 | tar_img = tar_img[:, rr:rr + ps, cc:cc + ps] 66 | 67 | # Data Augmentations 68 | if aug == 1: 69 | inp_img = inp_img.flip(1) 70 | tar_img = tar_img.flip(1) 71 | elif aug == 2: 72 | inp_img = 
inp_img.flip(2) 73 | tar_img = tar_img.flip(2) 74 | elif aug == 3: 75 | inp_img = torch.rot90(inp_img, dims=(1, 2)) 76 | tar_img = torch.rot90(tar_img, dims=(1, 2)) 77 | elif aug == 4: 78 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=2) 79 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=2) 80 | elif aug == 5: 81 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=3) 82 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=3) 83 | elif aug == 6: 84 | inp_img = torch.rot90(inp_img.flip(1), dims=(1, 2)) 85 | tar_img = torch.rot90(tar_img.flip(1), dims=(1, 2)) 86 | elif aug == 7: 87 | inp_img = torch.rot90(inp_img.flip(2), dims=(1, 2)) 88 | tar_img = torch.rot90(tar_img.flip(2), dims=(1, 2)) 89 | 90 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0] 91 | 92 | return tar_img, inp_img, filename 93 | 94 | 95 | class DataLoaderVal(Dataset): 96 | def __init__(self, rgb_dir, img_options=None, rgb_dir2=None): 97 | super(DataLoaderVal, self).__init__() 98 | 99 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input'))) 100 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target'))) 101 | 102 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] 103 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)] 104 | 105 | self.img_options = img_options 106 | self.sizex = len(self.tar_filenames) # get the size of target 107 | 108 | self.ps = self.img_options['patch_size'] 109 | 110 | def __len__(self): 111 | return self.sizex 112 | 113 | def __getitem__(self, index): 114 | index_ = index % self.sizex 115 | ps = self.ps 116 | 117 | inp_path = self.inp_filenames[index_] 118 | tar_path = self.tar_filenames[index_] 119 | 120 | inp_img = Image.open(inp_path) 121 | tar_img = Image.open(tar_path) 122 | 123 | # Validate on center crop 124 | if self.ps is not None: 125 | inp_img = TF.center_crop(inp_img, (ps, ps)) 126 | tar_img = TF.center_crop(tar_img, (ps, ps)) 127 | 128 | inp_img = 
TF.to_tensor(inp_img) 129 | tar_img = TF.to_tensor(tar_img) 130 | 131 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0] 132 | 133 | 134 | return tar_img, inp_img, filename 135 | 136 | 137 | class DataLoaderTest(Dataset): 138 | def __init__(self, rgb_dir, img_options): 139 | super(DataLoaderTest, self).__init__() 140 | 141 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input'))) 142 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target'))) 143 | 144 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] 145 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)] 146 | 147 | self.inp_size = len(self.inp_filenames) 148 | self.img_options = img_options 149 | 150 | def __len__(self): 151 | return self.inp_size 152 | 153 | def __getitem__(self, index): 154 | path_inp = self.inp_filenames[index] 155 | tar_path = self.tar_filenames[index] 156 | filename = os.path.splitext(os.path.split(path_inp)[-1])[0] 157 | inp = Image.open(path_inp) 158 | tar_img = Image.open(tar_path) 159 | 160 | inp = TF.to_tensor(inp) 161 | tar_img = TF.to_tensor(tar_img) 162 | return inp, tar_img 163 | 164 | -------------------------------------------------------------------------------- /Deblur/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from glob import glob 4 | import cv2 5 | from natsort import natsorted 6 | 7 | from skimage.metrics import structural_similarity,peak_signal_noise_ratio 8 | from cal import calculate_psnr,calculate_ssim 9 | 10 | def read_img(path): 11 | return cv2.imread(path) 12 | 13 | 14 | 15 | 16 | def main(): 17 | datasets = {'GoPr', 'HIDE'} 18 | file_path = os.path.join('resultsmash_g/Raindata/test', 'Rain100H') 19 | gt_path = os.path.join('Dataset/Raindata/test/Rain100H', 'target') 20 | print(file_path) 21 | print(gt_path) 22 | 23 | path_fake = 
natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg'))) 24 | path_real = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg'))) 25 | print(len(path_fake)) 26 | list_psnr = [] 27 | list_ssim = [] 28 | list_mse = [] 29 | 30 | for i in range(len(path_real)): 31 | t1 = read_img(path_real[i]) 32 | t2 = read_img(path_fake[i]) 33 | #result1 = np.zeros(t1.shape,dtype=np.float32) 34 | #result2 = np.zeros(t2.shape,dtype=np.float32) 35 | #cv2.normalize(t1,result1,alpha=0,beta=1,norm_type=cv2.NORM_MINMAX,dtype=cv2.CV_32F) 36 | #cv2.normalize(t2,result2,alpha=0,beta=1,norm_type=cv2.NORM_MINMAX,dtype=cv2.CV_32F) 37 | 38 | 39 | 40 | psnr_num = calculate_psnr(t1, t2,0) 41 | ssim_num = calculate_ssim(t1, t2,0) 42 | 43 | list_ssim.append(ssim_num) 44 | list_psnr.append(psnr_num) 45 | 46 | 47 | 48 | print("AverSSIM:", np.mean(list_ssim)) # ,list_ssim) 49 | print("AverPSNR:", np.mean(list_psnr)) # ,list_ssim) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /Deblur/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | class CharbonnierLoss(nn.Module): 7 | """Charbonnier Loss (L1)""" 8 | 9 | def __init__(self, eps=1e-3): 10 | super(CharbonnierLoss, self).__init__() 11 | self.eps = eps 12 | 13 | def forward(self, x, y): 14 | diff = x - y 15 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 16 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 17 | return loss 18 | 19 | class EdgeLoss(nn.Module): 20 | def __init__(self): 21 | super(EdgeLoss, self).__init__() 22 | k = torch.Tensor([[.05, .25, .4, .25, .05]]) 23 | self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1) 24 | if torch.cuda.is_available(): 25 | self.kernel = self.kernel.cuda() 26 | self.loss 
= CharbonnierLoss() 27 | 28 | def conv_gauss(self, img): 29 | n_channels, _, kw, kh = self.kernel.shape 30 | img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate') 31 | return F.conv2d(img, self.kernel, groups=n_channels) 32 | 33 | def laplacian_kernel(self, current): 34 | filtered = self.conv_gauss(current) # filter 35 | down = filtered[:,:,::2,::2] # downsample 36 | new_filter = torch.zeros_like(filtered) 37 | new_filter[:,:,::2,::2] = down*4 # upsample 38 | filtered = self.conv_gauss(new_filter) # filter 39 | diff = current - filtered 40 | return diff 41 | 42 | def forward(self, x, y): 43 | loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y)) 44 | return loss 45 | 46 | 47 | class PSNRLoss(nn.Module): 48 | 49 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False): 50 | super(PSNRLoss, self).__init__() 51 | assert reduction == 'mean' 52 | self.loss_weight = loss_weight 53 | self.scale = 10 / np.log(10) 54 | self.toY = toY 55 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) 56 | self.first = True 57 | 58 | def forward(self, pred, target): 59 | assert len(pred.size()) == 4 60 | if self.toY: 61 | if self.first: 62 | self.coef = self.coef.to(pred.device) 63 | self.first = False 64 | 65 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 66 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 67 | 68 | pred, target = pred / 255., target / 255. 
69 | pass 70 | assert len(pred.size()) == 4 71 | 72 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() 73 | -------------------------------------------------------------------------------- /Deblur/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import torch.nn as nn 7 | import torch 8 | from torch.utils.data import DataLoader 9 | import torch.nn.functional as F 10 | import utils 11 | 12 | from data_RGB import get_test_data 13 | from MHNet import MHNet 14 | from skimage import img_as_ubyte 15 | from pdb import set_trace as stx 16 | 17 | parser = argparse.ArgumentParser(description='Image Deraining using MPRNet') 18 | 19 | parser.add_argument('--input_dir', default='./Datasets/test/', type=str, help='Directory of validation images') 20 | parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results') 21 | parser.add_argument('--weights', default='./pre-trained/model_best.pth', type=str, help='Path to weights') 22 | parser.add_argument('--gpus', default='2', type=str, help='CUDA_VISIBLE_DEVICES') 23 | 24 | args = parser.parse_args() 25 | 26 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 27 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 28 | 29 | model_restoration = MHNet() 30 | 31 | utils.load_checkpoint(model_restoration,args.weights) 32 | print("===>Testing using weights: ",args.weights) 33 | model_restoration.cuda() 34 | model_restoration = nn.DataParallel(model_restoration) 35 | model_restoration.eval() 36 | 37 | datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200'] 38 | # datasets = ['Rain100L'] 39 | 40 | for dataset in datasets: 41 | rgb_dir_test = os.path.join(args.input_dir, dataset, 'input') 42 | test_dataset = get_test_data(rgb_dir_test, img_options={}) 43 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, 
shuffle=False, num_workers=4, drop_last=False, pin_memory=True) 44 | 45 | result_dir = os.path.join(args.result_dir, dataset) 46 | utils.mkdir(result_dir) 47 | 48 | with torch.no_grad(): 49 | for ii, data_test in enumerate(tqdm(test_loader), 0): 50 | torch.cuda.ipc_collect() 51 | torch.cuda.empty_cache() 52 | 53 | input_ = data_test[0].cuda() 54 | filenames = data_test[1] 55 | 56 | restored = model_restoration(input_) 57 | restored = torch.clamp(restored[0],0,1) 58 | 59 | restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy() 60 | 61 | for batch in range(len(restored)): 62 | restored_img = img_as_ubyte(restored[batch]) 63 | utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img) 64 | -------------------------------------------------------------------------------- /Deblur/train.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python 3 | # coding=utf-8 4 | 5 | import os 6 | from config import Config 7 | 8 | opt = Config('trmash.yml') 9 | 10 | gpus = ','.join([str(i) for i in opt.GPU]) 11 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 12 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus 13 | 14 | import torch 15 | 16 | torch.backends.cudnn.benchmark = True 17 | 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torch.optim as optim 21 | from torch.utils.data import DataLoader 22 | import wandb 23 | 24 | import random 25 | import time 26 | import numpy as np 27 | from pathlib import Path 28 | 29 | import utils 30 | from data_RGB import get_training_data, get_validation_data 31 | from MHNet import MHNet 32 | import losses 33 | from warmup_scheduler import GradualWarmupScheduler 34 | from tqdm import tqdm 35 | from pdb import set_trace as stx 36 | 37 | 38 | dir_checkpoint = Path('./mhnetmash/') 39 | 40 | def train(): 41 | 42 | ######### Set Seeds ########### 43 | random.seed(1234) 44 | np.random.seed(1234) 45 | torch.manual_seed(42) 46 | 
torch.cuda.manual_seed_all(42) 47 | 48 | start_epoch = 1 49 | mode = opt.MODEL.MODE 50 | session = opt.MODEL.SESSION 51 | 52 | result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session) 53 | model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session) 54 | 55 | utils.mkdir(result_dir) 56 | utils.mkdir(model_dir) 57 | 58 | train_dir = opt.TRAINING.TRAIN_DIR 59 | val_dir = opt.TRAINING.VAL_DIR 60 | 61 | ######### Model ########### 62 | model_restoration = MHNet() 63 | print("Total number of param is ", sum(x.numel() for x in model_restoration.parameters())) 64 | model_restoration.cuda() 65 | 66 | device_ids = [i for i in range(torch.cuda.device_count())] 67 | if torch.cuda.device_count() > 1: 68 | print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n") 69 | 70 | 71 | new_lr = opt.OPTIM.LR_INITIAL 72 | 73 | optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8) 74 | 75 | 76 | ######### Scheduler ########### 77 | warmup_epochs = 3 78 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN) 79 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 80 | scheduler.step() 81 | 82 | ######### Resume ########### 83 | if opt.TRAINING.RESUME: 84 | path_chk_rest = './mhnetmash/model_best.pth' 85 | utils.load_checkpoint(model_restoration,path_chk_rest) 86 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1 87 | utils.load_optim(optimizer, path_chk_rest) 88 | 89 | for i in range(1, start_epoch): 90 | scheduler.step() 91 | new_lr = scheduler.get_lr()[0] 92 | print('------------------------------------------------------------------------------') 93 | print("==> Resuming Training with learning rate:", new_lr) 94 | print('------------------------------------------------------------------------------') 95 | 96 | if len(device_ids)>1: 97 | model_restoration = 
nn.DataParallel(model_restoration, device_ids = device_ids) 98 | print("duoka") 99 | 100 | ######### Loss ########### 101 | criterion_mse = losses.PSNRLoss() 102 | ######### DataLoaders ########### 103 | train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS}) 104 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True) 105 | 106 | val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.TRAIN_PS}) 107 | val_loader = DataLoader(dataset=val_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=False, num_workers=8, drop_last=False, pin_memory=True) 108 | 109 | 110 | 111 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1)) 112 | print('===> Loading datasets') 113 | 114 | best_psnr = 0 115 | best_epoch = 0 116 | global_step = 0 117 | 118 | for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1): 119 | epoch_start_time = time.time() 120 | epoch_loss = 0 121 | psnr_train_rgb = [] 122 | psnr_train_rgb1 = [] 123 | psnr_tr = 0 124 | psnr_tr1 = 0 125 | model_restoration.train() 126 | for i, data in enumerate(tqdm(train_loader), 0): 127 | 128 | # zero_grad 129 | for param in model_restoration.parameters(): 130 | param.grad = None 131 | 132 | target = data[0].cuda() 133 | input_ = data[1].cuda() 134 | 135 | restored = model_restoration(input_) 136 | 137 | loss = criterion_mse(restored[0],target) 138 | loss.backward() 139 | optimizer.step() 140 | epoch_loss += loss.item() 141 | global_step = global_step+1 142 | 143 | psnr_te = 0 144 | psnr_te_1 = 0 145 | ssim_te_1 = 0 146 | #### Evaluation #### 147 | if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0: 148 | model_restoration.eval() 149 | psnr_val_rgb = [] 150 | psnr_val_rgb1 = [] 151 | for ii, data_val in enumerate((val_loader), 0): 152 | target = data_val[0].cuda() 153 | input_ = data_val[1].cuda() 154 | 155 | with torch.no_grad(): 156 | restored = 
model_restoration(input_) 157 | restore = restored[0] 158 | 159 | for res, tar in zip(restore, target): 160 | tssss = utils.torchPSNR(res, tar) 161 | psnr_te = psnr_te + tssss 162 | psnr_val_rgb.append(utils.torchPSNR(res, tar)) 163 | 164 | psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item() 165 | print("te", psnr_te) 166 | 167 | if psnr_val_rgb > best_psnr: 168 | best_psnr = psnr_val_rgb 169 | best_epoch = epoch 170 | Path(dir_checkpoint).mkdir(parents=True, exist_ok=True) 171 | torch.save({'epoch': epoch, 172 | 'state_dict': model_restoration.state_dict(), 173 | 'optimizer': optimizer.state_dict() 174 | }, str(dir_checkpoint / "model_best.pth")) 175 | 176 | 177 | print("[epoch %d PSNR: %.4f best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr)) 178 | Path(dir_checkpoint).mkdir(parents=True, exist_ok=True) 179 | torch.save({'epoch': epoch, 180 | 'state_dict': model_restoration.state_dict(), 181 | 'optimizer': optimizer.state_dict() 182 | }, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1))) 183 | 184 | scheduler.step() 185 | 186 | print("------------------------------------------------------------------") 187 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time, 188 | epoch_loss, scheduler.get_lr()[0])) 189 | print("------------------------------------------------------------------") 190 | 191 | 192 | if __name__=='__main__': 193 | train() 194 | 195 | -------------------------------------------------------------------------------- /Deblur/trmash.yml: -------------------------------------------------------------------------------- 1 | ############### 2 | ## 3 | #### 4 | 5 | 6 | GPU: [0,1,2,3] 7 | 8 | VERBOSE: True 9 | 10 | MODEL: 11 | MODE: 'Deblurring' 12 | SESSION: 'MHNet' 13 | 14 | # Optimization arguments. 
15 | OPTIM: 16 | BATCH_SIZE: 32 17 | NUM_EPOCHS: 45000000000 # NOTE(review): effectively unbounded - train.py loops over range(start_epoch, NUM_EPOCHS + 1) and passes NUM_EPOCHS - 3 as the cosine T_max, so the LR hardly decays after warmup; confirm this is intended 18 | # NEPOCH_DECAY: [10] 19 | LR_INITIAL: 2e-4 20 | LR_MIN: 1e-6 21 | # BETA1: 0.9 22 | 23 | TRAINING: 24 | VAL_AFTER_EVERY: 15 25 | RESUME: True # train.py then loads ./mhnetmash/model_best.pth unconditionally, so a fresh run presumably fails until that file exists 26 | TRAIN_PS: 256 27 | VAL_PS: 256 28 | TRAIN_DIR: './Datasets/GoPro/train' # path to training data 29 | VAL_DIR: './Datasets/GoPro/test' # path to validation data 30 | SAVE_DIR: './checkpoints' # path to save models and images 31 | # SAVE_IMAGES: False 32 | -------------------------------------------------------------------------------- /Deblur/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dir_utils import * 2 | from .image_utils import * 3 | from .model_utils import * 4 | from .dataset_utils import * 5 | from .logger import (MessageLogger, get_env_info, get_root_logger, 6 | init_tb_logger, init_wandb_logger) 7 | 8 | 9 | __all__ = [ # NOTE(review): looks copied from BasicSR - most listed names (e.g. 'FileClient', 'img2tensor', 'scandir') do not appear to be defined by the submodules imported above, so `from utils import *` would raise AttributeError; prune to the real exports 10 | # file_client.py 11 | 'FileClient', 12 | # img_util.py 13 | 'img2tensor', 14 | 'tensor2img', 15 | 'imfrombytes', 16 | 'imwrite', 17 | 'crop_border', 18 | # logger.py 19 | 'MessageLogger', 20 | 'init_tb_logger', 21 | 'init_wandb_logger', 22 | 'get_root_logger', 23 | 'get_env_info', 24 | # misc.py 25 | 'set_random_seed', 26 | 'get_time_str', 27 | 'mkdir_and_rename', 28 | 'make_exp_dirs', 29 | 'scandir', 30 | 'scandir_SIDD', 31 | 'check_resume', 32 | 'sizeof_fmt', 33 | 'padding', 34 | 'create_lmdb_for_reds', 35 | 'create_lmdb_for_gopro', 36 | 'create_lmdb_for_rain13k', 37 | ] 38 | 39 | 40 | -------------------------------------------------------------------------------- /Deblur/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Deblur/utils/__pycache__/arch_utils.cpython-38.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/arch_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Deblur/utils/__pycache__/dataset_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/dataset_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Deblur/utils/__pycache__/dir_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/dir_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Deblur/utils/__pycache__/dist_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/dist_util.cpython-38.pyc -------------------------------------------------------------------------------- /Deblur/utils/__pycache__/image_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/image_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Deblur/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/logger.cpython-38.pyc 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Deblur/utils/__pycache__/model_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Deblur/utils/arch_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | from torch.nn import init as init 6 | from torch.nn.modules.batchnorm import _BatchNorm 7 | 8 | from utils import get_root_logger # only referenced by the commented-out DCNv2Pack further down in this file 9 | 10 | # try: 11 | # from basicsr.models.ops.dcn import (ModulatedDeformConvPack, 12 | # modulated_deform_conv) 13 | # except ImportError: 14 | # # print('Cannot import dcn. Ignore this warning if dcn is not used. ' 15 | # # 'Otherwise install BasicSR with compiling dcn.') 16 | # 17 | 18 | @torch.no_grad() 19 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 20 | """Initialize network weights. 21 | Args: 22 | module_list (list[nn.Module] | nn.Module): Modules to be initialized. 23 | scale (float): Scale initialized weights, especially for residual 24 | blocks. Default: 1. 25 | bias_fill (float): The value to fill bias. Default: 0 26 | kwargs (dict): Other arguments for initialization function.
27 | """ 28 | if not isinstance(module_list, list): 29 | module_list = [module_list] 30 | for module in module_list: 31 | for m in module.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | init.kaiming_normal_(m.weight, **kwargs) 34 | m.weight.data *= scale 35 | if m.bias is not None: 36 | m.bias.data.fill_(bias_fill) 37 | elif isinstance(m, nn.Linear): 38 | init.kaiming_normal_(m.weight, **kwargs) 39 | m.weight.data *= scale 40 | if m.bias is not None: 41 | m.bias.data.fill_(bias_fill) 42 | elif isinstance(m, _BatchNorm): 43 | init.constant_(m.weight, 1) 44 | if m.bias is not None: 45 | m.bias.data.fill_(bias_fill) 46 | 47 | 48 | def make_layer(basic_block, num_basic_block, **kwarg): 49 | """Make layers by stacking the same blocks. 50 | Args: 51 | basic_block (nn.module): nn.module class for basic block. 52 | num_basic_block (int): number of blocks. 53 | Returns: 54 | nn.Sequential: Stacked blocks in nn.Sequential. 55 | """ 56 | layers = [] 57 | for _ in range(num_basic_block): 58 | layers.append(basic_block(**kwarg)) 59 | return nn.Sequential(*layers) 60 | 61 | 62 | class ResidualBlockNoBN(nn.Module): 63 | """Residual block without BN. 64 | It has a style of: 65 | ---Conv-ReLU-Conv-+- 66 | |________________| 67 | Args: 68 | num_feat (int): Channel number of intermediate features. 69 | Default: 64. 70 | res_scale (float): Residual scale. Default: 1. 71 | pytorch_init (bool): If set to True, use pytorch default init, 72 | otherwise, use default_init_weights. Default: False.
73 | """ 74 | 75 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): 76 | super(ResidualBlockNoBN, self).__init__() 77 | self.res_scale = res_scale 78 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 79 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 80 | self.relu = nn.ReLU(inplace=True) 81 | 82 | if not pytorch_init: 83 | default_init_weights([self.conv1, self.conv2], 0.1) 84 | 85 | def forward(self, x): 86 | identity = x 87 | out = self.conv2(self.relu(self.conv1(x))) 88 | return identity + out * self.res_scale 89 | 90 | 91 | class Upsample(nn.Sequential): 92 | """Upsample module. 93 | Args: 94 | scale (int): Scale factor. Supported scales: 2^n and 3. 95 | num_feat (int): Channel number of intermediate features. 96 | """ 97 | 98 | def __init__(self, scale, num_feat): 99 | m = [] 100 | if (scale & (scale - 1)) == 0: # scale = 2^n 101 | for _ in range(int(math.log(scale, 2))): 102 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 103 | m.append(nn.PixelShuffle(2)) 104 | elif scale == 3: 105 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 106 | m.append(nn.PixelShuffle(3)) 107 | else: 108 | raise ValueError(f'scale {scale} is not supported. ' 109 | 'Supported scales: 2^n and 3.') 110 | super(Upsample, self).__init__(*m) 111 | 112 | 113 | def flow_warp(x, 114 | flow, 115 | interp_mode='bilinear', 116 | padding_mode='zeros', 117 | align_corners=True): 118 | """Warp an image or feature map with optical flow. 119 | Args: 120 | x (Tensor): Tensor with size (n, c, h, w). 121 | flow (Tensor): Tensor with size (n, h, w, 2), normal value. 122 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. 123 | padding_mode (str): 'zeros' or 'border' or 'reflection'. 124 | Default: 'zeros'. 125 | align_corners (bool): Before pytorch 1.3, the default value is 126 | align_corners=True. After pytorch 1.3, the default value is 127 | align_corners=False. Here, we use the True as default.
128 | Returns: 129 | Tensor: Warped image or feature map. 130 | """ 131 | assert x.size()[-2:] == flow.size()[1:3] 132 | _, _, h, w = x.size() 133 | # create mesh grid 134 | grid_y, grid_x = torch.meshgrid( # NOTE(review): newer PyTorch warns unless indexing= is passed; the implicit default used here is 'ij' 135 | torch.arange(0, h).type_as(x), 136 | torch.arange(0, w).type_as(x)) 137 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 138 | grid.requires_grad = False 139 | 140 | vgrid = grid + flow 141 | # scale grid to [-1,1] 142 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 143 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 144 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 145 | output = F.grid_sample( 146 | x, 147 | vgrid_scaled, 148 | mode=interp_mode, 149 | padding_mode=padding_mode, 150 | align_corners=align_corners) 151 | 152 | # TODO, what if align_corners=False (the [-1,1] scaling above divides by w-1 / h-1, which matches align_corners=True only) 153 | return output 154 | 155 | 156 | def resize_flow(flow, 157 | size_type, 158 | sizes, 159 | interp_mode='bilinear', 160 | align_corners=False): 161 | """Resize a flow according to ratio or shape. 162 | Args: 163 | flow (Tensor): Precomputed flow. shape [N, 2, H, W]. 164 | size_type (str): 'ratio' or 'shape'. 165 | sizes (list[int | float]): the ratio for resizing or the final output 166 | shape. 167 | 1) The order of ratio should be [ratio_h, ratio_w]. For 168 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio 169 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., 170 | ratio > 1.0). 171 | 2) The order of output_size should be [out_h, out_w]. 172 | interp_mode (str): The mode of interpolation for resizing. 173 | Default: 'bilinear'. 174 | align_corners (bool): Whether align corners. Default: False. 175 | Returns: 176 | Tensor: Resized flow.
177 | """ 178 | _, _, flow_h, flow_w = flow.size() 179 | if size_type == 'ratio': 180 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) 181 | elif size_type == 'shape': 182 | output_h, output_w = sizes[0], sizes[1] 183 | else: 184 | raise ValueError( 185 | f'Size type should be ratio or shape, but got type {size_type}.') 186 | 187 | input_flow = flow.clone() 188 | ratio_h = output_h / flow_h 189 | ratio_w = output_w / flow_w 190 | input_flow[:, 0, :, :] *= ratio_w 191 | input_flow[:, 1, :, :] *= ratio_h 192 | resized_flow = F.interpolate( 193 | input=input_flow, 194 | size=(output_h, output_w), 195 | mode=interp_mode, 196 | align_corners=align_corners) 197 | return resized_flow 198 | 199 | 200 | # TODO: may write a cpp file (NOTE: produces the same channel order as torch.nn.functional.pixel_unshuffle, available since PyTorch 1.8) 201 | def pixel_unshuffle(x, scale): 202 | """ Pixel unshuffle. 203 | Args: 204 | x (Tensor): Input feature with shape (b, c, hh, hw). 205 | scale (int): Downsample ratio. 206 | Returns: 207 | Tensor: the pixel unshuffled feature. 208 | """ 209 | b, c, hh, hw = x.size() 210 | out_channel = c * (scale**2) 211 | assert hh % scale == 0 and hw % scale == 0 212 | h = hh // scale 213 | w = hw // scale 214 | x_view = x.view(b, c, h, scale, w, scale) 215 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) 216 | 217 | 218 | # class DCNv2Pack(ModulatedDeformConvPack): 219 | # """Modulated deformable conv for deformable alignment. 220 | # 221 | # Different from the official DCNv2Pack, which generates offsets and masks 222 | # from the preceding features, this DCNv2Pack takes another different 223 | # features to generate offsets and masks. 224 | # 225 | # Ref: 226 | # Delving Deep into Deformable Alignment in Video Super-Resolution.
227 | # """ 228 | # 229 | # def forward(self, x, feat): 230 | # out = self.conv_offset(feat) 231 | # o1, o2, mask = torch.chunk(out, 3, dim=1) 232 | # offset = torch.cat((o1, o2), dim=1) 233 | # mask = torch.sigmoid(mask) 234 | # 235 | # offset_absmean = torch.mean(torch.abs(offset)) 236 | # if offset_absmean > 50: 237 | # logger = get_root_logger() 238 | # logger.warning( 239 | # f'Offset abs mean is {offset_absmean}, larger than 50.') 240 | # 241 | # return modulated_deform_conv(x, offset, mask, self.weight, self.bias, 242 | # self.stride, self.padding, self.dilation, 243 | # self.groups, self.deformable_groups) 244 | 245 | 246 | class LayerNormFunction(torch.autograd.Function): 247 | 248 | @staticmethod 249 | def forward(ctx, x, weight, bias, eps): 250 | ctx.eps = eps 251 | N, C, H, W = x.size() 252 | mu = x.mean(1, keepdim=True) 253 | var = (x - mu).pow(2).mean(1, keepdim=True) 254 | y = (x - mu) / (var + eps).sqrt() 255 | ctx.save_for_backward(y, var, weight) 256 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) 257 | return y 258 | 259 | @staticmethod 260 | def backward(ctx, grad_output): 261 | eps = ctx.eps 262 | 263 | N, C, H, W = grad_output.size() 264 | y, var, weight = ctx.saved_variables 265 | g = grad_output * weight.view(1, C, 1, 1) 266 | mean_g = g.mean(dim=1, keepdim=True) 267 | 268 | mean_gy = (g * y).mean(dim=1, keepdim=True) 269 | gx = 1. 
/ torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) 270 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( 271 | dim=0), None 272 | 273 | class LayerNorm2d(nn.Module): 274 | 275 | def __init__(self, channels, eps=1e-6): 276 | super(LayerNorm2d, self).__init__() 277 | self.register_parameter('weight', nn.Parameter(torch.ones(channels))) 278 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) 279 | self.eps = eps 280 | 281 | def forward(self, x): 282 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) 283 | 284 | # handle multiple input 285 | class MySequential(nn.Sequential): 286 | def forward(self, *inputs): 287 | for module in self._modules.values(): 288 | if type(inputs) == tuple: 289 | inputs = module(*inputs) 290 | else: 291 | inputs = module(inputs) 292 | return inputs 293 | 294 | import time 295 | def measure_inference_speed(model, data, max_iter=200, log_interval=50): 296 | model.eval() 297 | 298 | # the first several iterations may be very slow so skip them 299 | num_warmup = 5 300 | pure_inf_time = 0 301 | fps = 0 302 | 303 | # benchmark with 2000 image and take the average 304 | for i in range(max_iter): 305 | 306 | torch.cuda.synchronize() 307 | start_time = time.perf_counter() 308 | 309 | with torch.no_grad(): 310 | model(*data) 311 | 312 | torch.cuda.synchronize() 313 | elapsed = time.perf_counter() - start_time 314 | 315 | if i >= num_warmup: 316 | pure_inf_time += elapsed 317 | if (i + 1) % log_interval == 0: 318 | fps = (i + 1 - num_warmup) / pure_inf_time 319 | print( 320 | f'Done image [{i + 1:<3}/ {max_iter}], ' 321 | f'fps: {fps:.1f} img / s, ' 322 | f'times per image: {1000 / fps:.1f} ms / img', 323 | flush=True) 324 | 325 | if (i + 1) == max_iter: 326 | fps = (i + 1 - num_warmup) / pure_inf_time 327 | print( 328 | f'Overall fps: {fps:.1f} img / s, ' 329 | f'times per image: {1000 / fps:.1f} ms / img', 330 | flush=True) 331 | break 332 | return 
-------------------------------------------------------------------------------- /Deblur/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class MixUp_AUG: 4 | def __init__(self): 5 | self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6])) 6 | 7 | def aug(self, rgb_gt, rgb_noisy): 8 | bs = rgb_gt.size(0) 9 | indices = torch.randperm(bs) 10 | rgb_gt2 = rgb_gt[indices] 11 | rgb_noisy2 = rgb_noisy[indices] 12 | 13 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda() 14 | 15 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2 16 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2 17 | 18 | return rgb_gt, rgb_noisy -------------------------------------------------------------------------------- /Deblur/utils/dir_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from natsort import natsorted 3 | from glob import glob 4 | 5 | def mkdirs(paths): 6 | if isinstance(paths, list) and not isinstance(paths, str): 7 | for path in paths: 8 | mkdir(path) 9 | else: 10 | mkdir(paths) 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | def get_last_path(path, session): 17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1] 18 | return x -------------------------------------------------------------------------------- /Deblur/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import subprocess 4 | import torch 5 | import torch.distributed as dist 6 | import torch.multiprocessing as mp 7 | 8 | 9 | def init_dist(launcher, backend='nccl', **kwargs): 10 | if mp.get_start_method(allow_none=True) is None: 11 | mp.set_start_method('spawn') 12 | if launcher == 'pytorch': 13 | _init_dist_pytorch(backend, **kwargs) 14 | elif launcher == 'slurm': 15 | _init_dist_slurm(backend, **kwargs) 16 | else: 17 
| raise ValueError(f'Invalid launcher type: {launcher}') 18 | 19 | 20 | def _init_dist_pytorch(backend, **kwargs): 21 | rank = int(os.environ['RANK']) 22 | num_gpus = torch.cuda.device_count() 23 | torch.cuda.set_device(rank % num_gpus) 24 | dist.init_process_group(backend=backend, **kwargs) 25 | 26 | 27 | def _init_dist_slurm(backend, port=None): 28 | """Initialize slurm distributed training environment. 29 | If argument ``port`` is not specified, then the master port will be system 30 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 31 | environment variable, then a default port ``29500`` will be used. 32 | Args: 33 | backend (str): Backend of torch.distributed. 34 | port (int, optional): Master port. Defaults to None. 35 | """ 36 | proc_id = int(os.environ['SLURM_PROCID']) 37 | ntasks = int(os.environ['SLURM_NTASKS']) 38 | node_list = os.environ['SLURM_NODELIST'] 39 | num_gpus = torch.cuda.device_count() 40 | torch.cuda.set_device(proc_id % num_gpus) 41 | addr = subprocess.getoutput( 42 | f'scontrol show hostname {node_list} | head -n1') 43 | # specify master port 44 | if port is not None: 45 | os.environ['MASTER_PORT'] = str(port) 46 | elif 'MASTER_PORT' in os.environ: 47 | pass # use MASTER_PORT in the environment variable 48 | else: 49 | # 29500 is torch.distributed default port 50 | os.environ['MASTER_PORT'] = '29500' 51 | os.environ['MASTER_ADDR'] = addr 52 | os.environ['WORLD_SIZE'] = str(ntasks) 53 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 54 | os.environ['RANK'] = str(proc_id) 55 | dist.init_process_group(backend=backend) 56 | 57 | 58 | def get_dist_info(): 59 | if dist.is_available(): 60 | initialized = dist.is_initialized() 61 | else: 62 | initialized = False 63 | if initialized: 64 | rank = dist.get_rank() 65 | world_size = dist.get_world_size() 66 | else: 67 | rank = 0 68 | world_size = 1 69 | return rank, world_size 70 | 71 | 72 | def master_only(func): 73 | 74 | @functools.wraps(func) 75 | def wrapper(*args, 
**kwargs): 76 | rank, _ = get_dist_info() 77 | if rank == 0: 78 | return func(*args, **kwargs) 79 | 80 | return -------------------------------------------------------------------------------- /Deblur/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 5 | def torchPSNR(tar_img, prd_img): 6 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1) 7 | rmse = (imdff**2).mean().sqrt() 8 | ps = 20*torch.log10(1/rmse) 9 | return ps 10 | 11 | def save_img(filepath, img): 12 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 13 | 14 | def numpyPSNR(tar_img, prd_img): 15 | imdff = np.float32(prd_img) - np.float32(tar_img) 16 | rmse = np.sqrt(np.mean(imdff**2)) 17 | ps = 20*np.log10(255/rmse) 18 | return ps 19 | -------------------------------------------------------------------------------- /Deblur/utils/logger.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import datetime 8 | import logging 9 | import time 10 | 11 | from .dist_util import get_dist_info, master_only 12 | 13 | 14 | class MessageLogger(): 15 | """Message logger for printing. 16 | Args: 17 | opt (dict): Config. It contains the following keys: 18 | name (str): Exp name. 19 | logger (dict): Contains 'print_freq' (str) for logger interval. 20 | train (dict): Contains 'total_iter' (int) for total iters. 21 | use_tb_logger (bool): Use tensorboard logger. 22 | start_iter (int): Start iter. Default: 1. 23 | tb_logger (obj:`tb_logger`): Tensorboard logger. 
Default: None. 24 | """ 25 | 26 | def __init__(self, opt, start_iter=1, tb_logger=None): 27 | self.exp_name = opt['name'] 28 | self.interval = opt['logger']['print_freq'] 29 | self.start_iter = start_iter 30 | self.max_iters = opt['train']['total_iter'] 31 | self.use_tb_logger = opt['logger']['use_tb_logger'] 32 | self.tb_logger = tb_logger 33 | self.start_time = time.time() 34 | self.logger = get_root_logger() 35 | 36 | @master_only 37 | def __call__(self, log_vars): 38 | """Format logging message. 39 | Args: 40 | log_vars (dict): It contains the following keys: 41 | epoch (int): Epoch number. 42 | iter (int): Current iter. 43 | lrs (list): List for learning rates. 44 | time (float): Iter time. 45 | data_time (float): Data time for each iter. 46 | """ 47 | # epoch, iter, learning rates 48 | epoch = log_vars.pop('epoch') 49 | current_iter = log_vars.pop('iter') 50 | total_iter = log_vars.pop('total_iter') 51 | lrs = log_vars.pop('lrs') 52 | 53 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' 54 | f'iter:{current_iter:8,d}, lr:(') 55 | for v in lrs: 56 | message += f'{v:.3e},' 57 | message += ')] ' 58 | 59 | # time and estimated time 60 | if 'time' in log_vars.keys(): 61 | iter_time = log_vars.pop('time') 62 | data_time = log_vars.pop('data_time') 63 | 64 | total_time = time.time() - self.start_time 65 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 66 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 67 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 68 | message += f'[eta: {eta_str}, ' 69 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 70 | 71 | # other items, especially losses 72 | for k, v in log_vars.items(): 73 | message += f'{k}: {v:.4e} ' 74 | # tensorboard logger 75 | if self.use_tb_logger and 'debug' not in self.exp_name: 76 | normed_step = 10000 * (current_iter / total_iter) 77 | normed_step = int(normed_step) 78 | 79 | if k.startswith('l_'): 80 | self.tb_logger.add_scalar(f'losses/{k}', 
v, normed_step) 81 | elif k.startswith('m_'): 82 | self.tb_logger.add_scalar(f'metrics/{k}', v, normed_step) 83 | else: 84 | assert 1 == 0 85 | # else: 86 | # self.tb_logger.add_scalar(k, v, current_iter) 87 | self.logger.info(message) 88 | 89 | 90 | @master_only 91 | def init_tb_logger(log_dir): 92 | from torch.utils.tensorboard import SummaryWriter 93 | tb_logger = SummaryWriter(log_dir=log_dir) 94 | return tb_logger 95 | 96 | 97 | @master_only 98 | def init_wandb_logger(opt): 99 | """We now only use wandb to sync tensorboard log.""" 100 | import wandb 101 | logger = logging.getLogger('basicsr') 102 | 103 | project = opt['logger']['wandb']['project'] 104 | resume_id = opt['logger']['wandb'].get('resume_id') 105 | if resume_id: 106 | wandb_id = resume_id 107 | resume = 'allow' 108 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 109 | else: 110 | wandb_id = wandb.util.generate_id() 111 | resume = 'never' 112 | 113 | wandb.init( 114 | id=wandb_id, 115 | resume=resume, 116 | name=opt['name'], 117 | config=opt, 118 | project=project, 119 | sync_tensorboard=True) 120 | 121 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 122 | 123 | 124 | def get_root_logger(logger_name='basicsr', 125 | log_level=logging.INFO, 126 | log_file=None): 127 | """Get the root logger. 128 | The logger will be initialized if it has not been initialized. By default a 129 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 130 | also be added. 131 | Args: 132 | logger_name (str): root logger name. Default: 'basicsr'. 133 | log_file (str | None): The log filename. If specified, a FileHandler 134 | will be added to the root logger. 135 | log_level (int): The root logger level. Note that only the process of 136 | rank 0 is affected, while other processes will set the level to 137 | "Error" and be silent most of the time. 138 | Returns: 139 | logging.Logger: The root logger. 
140 | """ 141 | logger = logging.getLogger(logger_name) 142 | # if the logger has been initialized, just return it 143 | if logger.hasHandlers(): 144 | return logger 145 | 146 | format_str = '%(asctime)s %(levelname)s: %(message)s' 147 | logging.basicConfig(format=format_str, level=log_level) 148 | rank, _ = get_dist_info() 149 | if rank != 0: 150 | logger.setLevel('ERROR') 151 | elif log_file is not None: 152 | file_handler = logging.FileHandler(log_file, 'w') 153 | file_handler.setFormatter(logging.Formatter(format_str)) 154 | file_handler.setLevel(log_level) 155 | logger.addHandler(file_handler) 156 | 157 | return logger 158 | 159 | 160 | def get_env_info(): 161 | """Get environment information. 162 | Currently, only log the software version. 163 | """ 164 | import torch 165 | import torchvision 166 | 167 | from basicsr.version import __version__ 168 | msg = r""" 169 | ____ _ _____ ____ 170 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 171 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 172 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 173 | /_____/ \__,_//____//_/ \___//____//_/ |_| 174 | ______ __ __ __ __ 175 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 176 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 177 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 178 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 179 | """ 180 | msg += ('\nVersion Information: ' 181 | f'\n\tBasicSR: {__version__}' 182 | f'\n\tPyTorch: {torch.__version__}' 183 | f'\n\tTorchVision: {torchvision.__version__}') 184 | return -------------------------------------------------------------------------------- /Deblur/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | 5 | def freeze(model): 6 | for p in model.parameters(): 7 | p.requires_grad=False 8 | 9 | def unfreeze(model): 10 | for p in model.parameters(): 11 | p.requires_grad=True 
12 | 13 | def is_frozen(model): 14 | x = [p.requires_grad for p in model.parameters()] 15 | return not all(x) 16 | 17 | def save_checkpoint(model_dir, state, session): 18 | epoch = state['epoch'] 19 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session)) 20 | torch.save(state, model_out_path) 21 | 22 | def load_checkpoint(model, weights): 23 | checkpoint = torch.load(weights) 24 | try: 25 | model.load_state_dict(checkpoint["state_dict"]) 26 | except: 27 | state_dict = checkpoint["state_dict"] 28 | new_state_dict = OrderedDict() 29 | for k, v in state_dict.items(): 30 | name = k[7:] # remove `module.` 31 | new_state_dict[name] = v 32 | model.load_state_dict(new_state_dict) 33 | 34 | 35 | def load_checkpoint_multigpu(model, weights): 36 | checkpoint = torch.load(weights) 37 | state_dict = checkpoint["state_dict"] 38 | new_state_dict = OrderedDict() 39 | for k, v in state_dict.items(): 40 | name = k[7:] # remove `module.` 41 | new_state_dict[name] = v 42 | model.load_state_dict(new_state_dict) 43 | 44 | def load_start_epoch(weights): 45 | checkpoint = torch.load(weights) 46 | epoch = checkpoint["epoch"] 47 | return epoch 48 | 49 | def load_optim(optimizer, weights): 50 | checkpoint = torch.load(weights) 51 | optimizer.load_state_dict(checkpoint['optimizer']) 52 | # for p in optimizer.param_groups: lr = p['lr'] 53 | # return lr 54 | -------------------------------------------------------------------------------- /Derain/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Training 3 | - Download datasets from the google drive links and place them in this directory. Your directory structure should look something like this 4 | 5 | `Synthetic_Rain_Datasets`
6 |   `├──`[train](https://drive.google.com/drive/folders/1Hnnlc5kI0v9_BtfMytC2LR5VpLAFZtVe?usp=sharing)
7 |   `└──`[test](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs?usp=sharing)
8 |       `├──Test100`
9 |       `├──Rain100H`
10 |       `├──Rain100L`
11 |       `└──Test1200`
12 |        13 | 14 | 15 | - Train the model with default arguments by running 16 | 17 | ``` 18 | python train.py 19 | ``` 20 | 21 | 22 | ## Evaluation 23 | 24 | 1. Download the [model](https://drive.google.com/drive/folders/1qBC3mUoLoCuMyuiseYoZWzvyvImG98TW?usp=drive_link) and place it in `./pretrained_models/` 25 | 26 | 2. Download test datasets (Test100, Rain100H, Rain100L, Test1200) from [here](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs?usp=sharing) and place them in `./Datasets/Synthetic_Rain_Datasets/test/` 27 | 28 | 3. Run 29 | ``` 30 | python test.py 31 | ``` 32 | 33 | #### To reproduce PSNR/SSIM scores of the paper, run 34 | ``` 35 | python eval.py 36 | ``` 37 | -------------------------------------------------------------------------------- /Derain/cal.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | import skimage.metrics 6 | import torch 7 | import math 8 | 9 | def calculate_psnr(img1, img2, crop_border, test_y_channel=True): 10 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 11 | 12 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 13 | Args: 14 | img1 (ndarray): Images with range [0, 255]. 15 | img2 (ndarray): Images with range [0, 255]. 16 | crop_border (int): Cropped pixels in each edge of an image. These 17 | pixels are not involved in the PSNR calculation. 18 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 19 | Returns: 20 | float: psnr result. 
21 | """ 22 | assert img1.shape == img2.shape, ( 23 | f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 24 | if type(img1) == torch.Tensor: 25 | if len(img1.shape) == 4: 26 | img1 = img1.squeeze(0) 27 | img1 = img1.detach().cpu().numpy().transpose(1,2,0) 28 | if type(img2) == torch.Tensor: 29 | if len(img2.shape) == 4: 30 | img2 = img2.squeeze(0) 31 | img2 = img2.detach().cpu().numpy().transpose(1,2,0) 32 | img1 = img1.astype(np.float64) 33 | img2 = img2.astype(np.float64) 34 | 35 | if crop_border != 0: 36 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 37 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 38 | 39 | if test_y_channel: 40 | img1 = to_y_channel(img1) 41 | img2 = to_y_channel(img2) 42 | 43 | imdff = np.float32(img1) - np.float32(img2) 44 | rmse = np.sqrt(np.mean(imdff**2)) 45 | ps = 20*np.log10(255/rmse) 46 | return ps 47 | 48 | 49 | def _convert_input_type_range(img): 50 | """Convert the type and range of the input image. 51 | 52 | It converts the input image to np.float32 type and range of [0, 1]. 53 | It is mainly used for pre-processing the input image in colorspace 54 | convertion functions such as rgb2ycbcr and ycbcr2rgb. 55 | Args: 56 | img (ndarray): The input image. It accepts: 57 | 1. np.uint8 type with range [0, 255]; 58 | 2. np.float32 type with range [0, 1]. 59 | Returns: 60 | (ndarray): The converted image with type of np.float32 and range of 61 | [0, 1]. 62 | """ 63 | img_type = img.dtype 64 | img = img.astype(np.float32) 65 | if img_type == np.float32: 66 | pass 67 | elif img_type == np.uint8: 68 | img /= 255. 69 | else: 70 | raise TypeError('The img type should be np.float32 or np.uint8, ' 71 | f'but got {img_type}') 72 | return img 73 | 74 | 75 | def _convert_output_type_range(img, dst_type): 76 | """Convert the type and range of the image according to dst_type. 77 | 78 | It converts the image to desired type and range. 
If `dst_type` is np.uint8, 79 | images will be converted to np.uint8 type with range [0, 255]. If 80 | `dst_type` is np.float32, it converts the image to np.float32 type with 81 | range [0, 1]. 82 | It is mainly used for post-processing images in colorspace convertion 83 | functions such as rgb2ycbcr and ycbcr2rgb. 84 | Args: 85 | img (ndarray): The image to be converted with np.float32 type and 86 | range [0, 255]. 87 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it 88 | converts the image to np.uint8 type with range [0, 255]. If 89 | dst_type is np.float32, it converts the image to np.float32 type 90 | with range [0, 1]. 91 | Returns: 92 | (ndarray): The converted image with desired type and range. 93 | """ 94 | if dst_type not in (np.uint8, np.float32): 95 | raise TypeError('The dst_type should be np.float32 or np.uint8, ' 96 | f'but got {dst_type}') 97 | if dst_type == np.uint8: 98 | img = img.round() 99 | else: 100 | img /= 255. 101 | 102 | return img.astype(dst_type) 103 | 104 | 105 | def rgb2ycbcr(img, y_only=True): 106 | """Convert a RGB image to YCbCr image. 107 | 108 | This function produces the same results as Matlab's `rgb2ycbcr` function. 109 | It implements the ITU-R BT.601 conversion for standard-definition 110 | television. See more details in 111 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 112 | It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. 113 | In OpenCV, it implements a JPEG conversion. See more details in 114 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 115 | 116 | Args: 117 | img (ndarray): The input image. It accepts: 118 | 1. np.uint8 type with range [0, 255]; 119 | 2. np.float32 type with range [0, 1]. 120 | y_only (bool): Whether to only return Y channel. Default: False. 121 | Returns: 122 | ndarray: The converted YCbCr image. The output image has the same type 123 | and range as input image. 
124 | """ 125 | img_type = img.dtype 126 | img = _convert_input_type_range(img) 127 | if y_only: 128 | out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 129 | else: 130 | out_img = np.matmul(img, 131 | [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 132 | [24.966, 112.0, -18.214]]) + [16, 128, 128] 133 | out_img = _convert_output_type_range(out_img, img_type) 134 | return out_img 135 | 136 | 137 | def to_y_channel(img): 138 | """Change to Y channel of YCbCr. 139 | 140 | Args: 141 | img (ndarray): Images with range [0, 255]. 142 | Returns: 143 | (ndarray): Images with range [0, 255] (float type) without round. 144 | """ 145 | img = img.astype(np.float32) / 255. 146 | if img.ndim == 3 and img.shape[2] == 3: 147 | img = rgb2ycbcr(img, y_only=True) 148 | img = img[..., None] 149 | return img * 255. 150 | 151 | def _ssim(img1, img2): 152 | """Calculate SSIM (structural similarity) for one channel images. 153 | 154 | It is called by func:`calculate_ssim`. 155 | 156 | Args: 157 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 158 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 159 | 160 | Returns: 161 | float: ssim result. 
162 | """ 163 | 164 | C1 = (0.01 * 255)**2 165 | C2 = (0.03 * 255)**2 166 | 167 | img1 = img1.astype(np.float64) 168 | img2 = img2.astype(np.float64) 169 | kernel = cv2.getGaussianKernel(11, 1.5) 170 | window = np.outer(kernel, kernel.transpose()) 171 | 172 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 173 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 174 | mu1_sq = mu1**2 175 | mu2_sq = mu2**2 176 | mu1_mu2 = mu1 * mu2 177 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 178 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 179 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 180 | 181 | ssim_map = ((2 * mu1_mu2 + C1) * 182 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 183 | (sigma1_sq + sigma2_sq + C2)) 184 | return ssim_map.mean() 185 | 186 | def prepare_for_ssim(img, k): 187 | import torch 188 | with torch.no_grad(): 189 | img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() 190 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect') 191 | conv.weight.requires_grad = False 192 | conv.weight[:, :, :, :] = 1. / (k * k) 193 | 194 | img = conv(img) 195 | 196 | img = img.squeeze(0).squeeze(0) 197 | img = img[0::k, 0::k] 198 | return img.detach().cpu().numpy() 199 | 200 | def prepare_for_ssim_rgb(img, k): 201 | import torch 202 | with torch.no_grad(): 203 | img = torch.from_numpy(img).float() #HxWx3 204 | 205 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect') 206 | conv.weight.requires_grad = False 207 | conv.weight[:, :, :, :] = 1. 
/ (k * k) 208 | 209 | new_img = [] 210 | 211 | for i in range(3): 212 | new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k]) 213 | 214 | return torch.stack(new_img, dim=2).detach().cpu().numpy() 215 | 216 | def _3d_gaussian_calculator(img, conv3d): 217 | out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) 218 | return out 219 | 220 | def _generate_3d_gaussian_kernel(): 221 | kernel = cv2.getGaussianKernel(11, 1.5) 222 | window = np.outer(kernel, kernel.transpose()) 223 | kernel_3 = cv2.getGaussianKernel(11, 1.5) 224 | kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) 225 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate') 226 | conv3d.weight.requires_grad = False 227 | conv3d.weight[0, 0, :, :, :] = kernel 228 | return conv3d 229 | 230 | def _ssim_3d(img1, img2, max_value): 231 | assert len(img1.shape) == 3 and len(img2.shape) == 3 232 | """Calculate SSIM (structural similarity) for one channel images. 233 | 234 | It is called by func:`calculate_ssim`. 235 | 236 | Args: 237 | img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. 238 | img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. 239 | 240 | Returns: 241 | float: ssim result. 
242 | """ 243 | C1 = (0.01 * max_value) ** 2 244 | C2 = (0.03 * max_value) ** 2 245 | img1 = img1.astype(np.float64) 246 | img2 = img2.astype(np.float64) 247 | 248 | kernel = _generate_3d_gaussian_kernel().cuda() 249 | 250 | img1 = torch.tensor(img1).float().cuda() 251 | img2 = torch.tensor(img2).float().cuda() 252 | 253 | 254 | mu1 = _3d_gaussian_calculator(img1, kernel) 255 | mu2 = _3d_gaussian_calculator(img2, kernel) 256 | 257 | mu1_sq = mu1 ** 2 258 | mu2_sq = mu2 ** 2 259 | mu1_mu2 = mu1 * mu2 260 | sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq 261 | sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq 262 | sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2 263 | 264 | ssim_map = ((2 * mu1_mu2 + C1) * 265 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 266 | (sigma1_sq + sigma2_sq + C2)) 267 | return float(ssim_map.mean()) 268 | 269 | def _ssim_cly(img1, img2): 270 | assert len(img1.shape) == 2 and len(img2.shape) == 2 271 | """Calculate SSIM (structural similarity) for one channel images. 272 | 273 | It is called by func:`calculate_ssim`. 274 | 275 | Args: 276 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 277 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 278 | 279 | Returns: 280 | float: ssim result. 
281 | """ 282 | 283 | C1 = (0.01 * 255)**2 284 | C2 = (0.03 * 255)**2 285 | img1 = img1.astype(np.float64) 286 | img2 = img2.astype(np.float64) 287 | 288 | kernel = cv2.getGaussianKernel(11, 1.5) 289 | # print(kernel) 290 | window = np.outer(kernel, kernel.transpose()) 291 | 292 | bt = cv2.BORDER_REPLICATE 293 | 294 | mu1 = cv2.filter2D(img1, -1, window, borderType=bt) 295 | mu2 = cv2.filter2D(img2, -1, window,borderType=bt) 296 | 297 | mu1_sq = mu1**2 298 | mu2_sq = mu2**2 299 | mu1_mu2 = mu1 * mu2 300 | sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq 301 | sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq 302 | sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2 303 | 304 | ssim_map = ((2 * mu1_mu2 + C1) * 305 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 306 | (sigma1_sq + sigma2_sq + C2)) 307 | return ssim_map.mean() 308 | def reorder_image(img, input_order='HWC'): 309 | """Reorder images to 'HWC' order. 310 | 311 | If the input_order is (h, w), return (h, w, 1); 312 | If the input_order is (c, h, w), return (h, w, c); 313 | If the input_order is (h, w, c), return as it is. 314 | 315 | Args: 316 | img (ndarray): Input image. 317 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 318 | If the input image shape is (h, w), input_order will not have 319 | effects. Default: 'HWC'. 320 | 321 | Returns: 322 | ndarray: reordered image. 323 | """ 324 | 325 | if input_order not in ['HWC', 'CHW']: 326 | raise ValueError( 327 | f'Wrong input_order {input_order}. Supported input_orders are ' 328 | "'HWC' and 'CHW'") 329 | if len(img.shape) == 2: 330 | img = img[..., None] 331 | if input_order == 'CHW': 332 | img = img.transpose(1, 2, 0) 333 | return img 334 | 335 | 336 | def calculate_ssim(img1, 337 | img2, 338 | crop_border, 339 | input_order='HWC', 340 | test_y_channel=True): 341 | """Calculate SSIM (structural similarity). 
342 | 343 | Ref: 344 | Image quality assessment: From error visibility to structural similarity 345 | 346 | The results are the same as that of the official released MATLAB code in 347 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 348 | 349 | For three-channel images, SSIM is calculated for each channel and then 350 | averaged. 351 | 352 | Args: 353 | img1 (ndarray): Images with range [0, 255]. 354 | img2 (ndarray): Images with range [0, 255]. 355 | crop_border (int): Cropped pixels in each edge of an image. These 356 | pixels are not involved in the SSIM calculation. 357 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 358 | Default: 'HWC'. 359 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 360 | 361 | Returns: 362 | float: ssim result. 363 | """ 364 | 365 | assert img1.shape == img2.shape, ( 366 | f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 367 | if input_order not in ['HWC', 'CHW']: 368 | raise ValueError( 369 | f'Wrong input_order {input_order}. Supported input_orders are ' 370 | '"HWC" and "CHW"') 371 | 372 | if type(img1) == torch.Tensor: 373 | if len(img1.shape) == 4: 374 | img1 = img1.squeeze(0) 375 | img1 = img1.detach().cpu().numpy().transpose(1,2,0) 376 | if type(img2) == torch.Tensor: 377 | if len(img2.shape) == 4: 378 | img2 = img2.squeeze(0) 379 | img2 = img2.detach().cpu().numpy().transpose(1,2,0) 380 | 381 | img1 = reorder_image(img1, input_order=input_order) 382 | img2 = reorder_image(img2, input_order=input_order) 383 | 384 | img1 = img1.astype(np.float64) 385 | img2 = img2.astype(np.float64) 386 | 387 | if crop_border != 0: 388 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 389 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 
390 | 391 | if test_y_channel: 392 | img1 = to_y_channel(img1) 393 | img2 = to_y_channel(img2) 394 | return _ssim_cly(img1[..., 0], img2[..., 0]) 395 | 396 | 397 | ssims = [] 398 | # ssims_before = [] 399 | 400 | # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True) 401 | # print('.._skimage', 402 | # skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)) 403 | max_value = 1 if img1.max() <= 1 else 255 404 | with torch.no_grad(): 405 | final_ssim = _ssim_3d(img1, img2, max_value) 406 | ssims.append(final_ssim) 407 | 408 | # for i in range(img1.shape[2]): 409 | # ssims_before.append(_ssim(img1, img2)) 410 | 411 | # print('..ssim mean , new {:.4f} and before {:.4f} .... skimage before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean(), skimage_before)) 412 | # ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False)) 413 | 414 | return np.array(ssims).mean() -------------------------------------------------------------------------------- /Derain/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | 4 | 5 | r"""This module provides package-wide configuration management.""" 6 | from typing import Any, List 7 | 8 | from yacs.config import CfgNode as CN 9 | 10 | 11 | class Config(object): 12 | r""" 13 | A collection of all the required configuration parameters. This class is a nested dict-like 14 | structure, with nested keys accessible as attributes. It contains sensible default values for 15 | all the parameters, which may be overriden by (first) through a YAML file and (second) through 16 | a list of attributes and values. 
17 | 18 | Extended Summary 19 | ---------------- 20 | This class definition contains default values corresponding to ``joint_training`` phase, as it 21 | is the final training phase and uses almost all the configuration parameters. Modification of 22 | any parameter after instantiating this class is not possible, so you must override required 23 | parameter values in either through ``config_yaml`` file or ``config_override`` list. 24 | 25 | Parameters 26 | ---------- 27 | config_yaml: str 28 | Path to a YAML file containing configuration parameters to override. 29 | config_override: List[Any], optional (default= []) 30 | A list of sequential attributes and values of parameters to override. This happens after 31 | overriding from YAML file. 32 | 33 | Examples 34 | -------- 35 | Let a YAML file named "config.yaml" specify these parameters to override:: 36 | 37 | ALPHA: 1000.0 38 | BETA: 0.5 39 | 40 | >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7]) 41 | >>> _C.ALPHA # default: 100.0 42 | 1000.0 43 | >>> _C.BATCH_SIZE # default: 256 44 | 2048 45 | >>> _C.BETA # default: 0.1 46 | 0.7 47 | 48 | Attributes 49 | ---------- 50 | """ 51 | 52 | def __init__(self, config_yaml: str, config_override: List[Any] = []): 53 | 54 | self._C = CN() 55 | self._C.GPU = [0] 56 | self._C.VERBOSE = False 57 | 58 | self._C.MODEL = CN() 59 | self._C.MODEL.MODE = 'global' 60 | self._C.MODEL.SESSION = 'ps128_bs1' 61 | 62 | self._C.OPTIM = CN() 63 | self._C.OPTIM.BATCH_SIZE = 1 64 | self._C.OPTIM.NUM_EPOCHS = 100 65 | self._C.OPTIM.NEPOCH_DECAY = [100] 66 | self._C.OPTIM.LR_INITIAL = 0.0002 67 | self._C.OPTIM.LR_MIN = 0.0002 68 | self._C.OPTIM.BETA1 = 0.5 69 | 70 | self._C.TRAINING = CN() 71 | self._C.TRAINING.VAL_AFTER_EVERY = 3 72 | self._C.TRAINING.RESUME = False 73 | self._C.TRAINING.SAVE_IMAGES = False 74 | self._C.TRAINING.TRAIN_DIR = 'images_dir/train' 75 | self._C.TRAINING.VAL_DIR = 'images_dir/val' 76 | self._C.TRAINING.SAVE_DIR = 'checkpoints' 77 | 
self._C.TRAINING.TRAIN_PS = 64 78 | self._C.TRAINING.VAL_PS = 64 79 | 80 | # Override parameter values from YAML file first, then from override list. 81 | self._C.merge_from_file(config_yaml) 82 | self._C.merge_from_list(config_override) 83 | 84 | # Make an instantiated object of this class immutable. 85 | self._C.freeze() 86 | 87 | def dump(self, file_path: str): 88 | r"""Save config at the specified file path. 89 | 90 | Parameters 91 | ---------- 92 | file_path: str 93 | (YAML) path to save config at. 94 | """ 95 | self._C.dump(stream=open(file_path, "w")) 96 | 97 | def __getattr__(self, attr: str): 98 | return self._C.__getattr__(attr) 99 | 100 | def __repr__(self): 101 | return self._C.__repr__() 102 | -------------------------------------------------------------------------------- /Derain/data_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest, DataLoaderTest2 3 | 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options) 7 | 8 | def get_validation_data(rgb_dir, img_options): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, img_options) 11 | 12 | def get_test_data(rgb_dir, img_options): 13 | assert os.path.exists(rgb_dir) 14 | return DataLoaderTest(rgb_dir, img_options) 15 | 16 | def get_test_data2(rgb_dir, img_options): 17 | assert os.path.exists(rgb_dir) 18 | return DataLoaderTest2(rgb_dir, img_options) 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /Derain/dataset_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torch 5 | from PIL import Image 6 | import torchvision.transforms.functional as TF 7 | from pdb import set_trace as stx 8 | import random 9 | import 
utils 10 | 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 14 | 15 | 16 | class DataLoaderTrain(Dataset): 17 | def __init__(self, rgb_dir, img_options=None): 18 | super(DataLoaderTrain, self).__init__() 19 | 20 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input'))) 21 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target'))) 22 | 23 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] 24 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)] 25 | 26 | self.img_options = img_options 27 | self.sizex = len(self.tar_filenames) # get the size of target 28 | 29 | self.ps = self.img_options['patch_size'] 30 | 31 | def __len__(self): 32 | return self.sizex 33 | 34 | def __getitem__(self, index): 35 | index_ = index % self.sizex 36 | ps = self.ps 37 | 38 | inp_path = self.inp_filenames[index_] 39 | tar_path = self.tar_filenames[index_] 40 | 41 | inp_img = Image.open(inp_path) 42 | tar_img = Image.open(tar_path) 43 | 44 | w, h = tar_img.size 45 | padw = ps - w if w < ps else 0 46 | padh = ps - h if h < ps else 0 47 | 48 | # Reflect Pad in case image is smaller than patch_size 49 | if padw != 0 or padh != 0: 50 | inp_img = TF.pad(inp_img, (0, 0, padw, padh), padding_mode='reflect') 51 | tar_img = TF.pad(tar_img, (0, 0, padw, padh), padding_mode='reflect') 52 | 53 | 54 | inp_img = TF.to_tensor(inp_img) 55 | tar_img = TF.to_tensor(tar_img) 56 | 57 | hh, ww = tar_img.shape[1], tar_img.shape[2] 58 | 59 | rr = random.randint(0, hh - ps) 60 | cc = random.randint(0, ww - ps) 61 | aug = random.randint(0, 8) 62 | 63 | # Crop patch 64 | inp_img = inp_img[:, rr:rr + ps, cc:cc + ps] 65 | tar_img = tar_img[:, rr:rr + ps, cc:cc + ps] 66 | 67 | # Data Augmentations 68 | if aug == 1: 69 | inp_img = inp_img.flip(1) 70 | tar_img = tar_img.flip(1) 71 | elif aug == 2: 72 | inp_img = 
inp_img.flip(2) 73 | tar_img = tar_img.flip(2) 74 | elif aug == 3: 75 | inp_img = torch.rot90(inp_img, dims=(1, 2)) 76 | tar_img = torch.rot90(tar_img, dims=(1, 2)) 77 | elif aug == 4: 78 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=2) 79 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=2) 80 | elif aug == 5: 81 | inp_img = torch.rot90(inp_img, dims=(1, 2), k=3) 82 | tar_img = torch.rot90(tar_img, dims=(1, 2), k=3) 83 | elif aug == 6: 84 | inp_img = torch.rot90(inp_img.flip(1), dims=(1, 2)) 85 | tar_img = torch.rot90(tar_img.flip(1), dims=(1, 2)) 86 | elif aug == 7: 87 | inp_img = torch.rot90(inp_img.flip(2), dims=(1, 2)) 88 | tar_img = torch.rot90(tar_img.flip(2), dims=(1, 2)) 89 | 90 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0] 91 | 92 | return tar_img, inp_img, filename 93 | 94 | 95 | class DataLoaderVal(Dataset): 96 | def __init__(self, rgb_dir, img_options=None, rgb_dir2=None): 97 | super(DataLoaderVal, self).__init__() 98 | 99 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input'))) 100 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target'))) 101 | 102 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] 103 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)] 104 | 105 | self.img_options = img_options 106 | self.sizex = len(self.tar_filenames) # get the size of target 107 | 108 | self.ps = self.img_options['patch_size'] 109 | 110 | def __len__(self): 111 | return self.sizex 112 | 113 | def __getitem__(self, index): 114 | index_ = index % self.sizex 115 | ps = self.ps 116 | 117 | inp_path = self.inp_filenames[index_] 118 | tar_path = self.tar_filenames[index_] 119 | 120 | inp_img = Image.open(inp_path) 121 | tar_img = Image.open(tar_path) 122 | 123 | # Validate on center crop 124 | if self.ps is not None: 125 | inp_img = TF.center_crop(inp_img, (ps, ps)) 126 | tar_img = TF.center_crop(tar_img, (ps, ps)) 127 | 128 | inp_img = 
TF.to_tensor(inp_img) 129 | tar_img = TF.to_tensor(tar_img) 130 | 131 | filename = os.path.splitext(os.path.split(tar_path)[-1])[0] 132 | 133 | 134 | return tar_img, inp_img, filename 135 | 136 | 137 | class DataLoaderTest(Dataset): 138 | def __init__(self, rgb_dir, img_options): 139 | super(DataLoaderTest, self).__init__() 140 | 141 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input'))) 142 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target'))) 143 | 144 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] 145 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)] 146 | 147 | self.inp_size = len(self.inp_filenames) 148 | self.img_options = img_options 149 | 150 | def __len__(self): 151 | return self.inp_size 152 | 153 | def __getitem__(self, index): 154 | path_inp = self.inp_filenames[index] 155 | tar_path = self.tar_filenames[index] 156 | filename = os.path.splitext(os.path.split(path_inp)[-1])[0] 157 | inp = Image.open(path_inp) 158 | tar_img = Image.open(tar_path) 159 | 160 | inp = TF.to_tensor(inp) 161 | tar_img = TF.to_tensor(tar_img) 162 | return inp, tar_img 163 | 164 | -------------------------------------------------------------------------------- /Derain/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from glob import glob 4 | import cv2 5 | from natsort import natsorted 6 | 7 | from skimage.metrics import structural_similarity,peak_signal_noise_ratio 8 | from cal import calculate_psnr,calculate_ssim 9 | 10 | def read_img(path): 11 | return cv2.imread(path) 12 | 13 | 14 | 15 | 16 | def main(): 17 | datasets = {'GoPr', 'HIDE'} 18 | file_path = os.path.join('resultsmash_g/Raindata/test', 'Rain100H') 19 | gt_path = os.path.join('Dataset/Raindata/test/Rain100H', 'target') 20 | print(file_path) 21 | print(gt_path) 22 | 23 | path_fake = 
natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg'))) 24 | path_real = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg'))) 25 | print(len(path_fake)) 26 | list_psnr = [] 27 | list_ssim = [] 28 | list_mse = [] 29 | 30 | for i in range(len(path_real)): 31 | t1 = read_img(path_real[i]) 32 | t2 = read_img(path_fake[i]) 33 | #result1 = np.zeros(t1.shape,dtype=np.float32) 34 | #result2 = np.zeros(t2.shape,dtype=np.float32) 35 | #cv2.normalize(t1,result1,alpha=0,beta=1,norm_type=cv2.NORM_MINMAX,dtype=cv2.CV_32F) 36 | #cv2.normalize(t2,result2,alpha=0,beta=1,norm_type=cv2.NORM_MINMAX,dtype=cv2.CV_32F) 37 | 38 | 39 | 40 | psnr_num = calculate_psnr(t1, t2,0) 41 | ssim_num = calculate_ssim(t1, t2,0) 42 | 43 | list_ssim.append(ssim_num) 44 | list_psnr.append(psnr_num) 45 | 46 | 47 | 48 | print("AverSSIM:", np.mean(list_ssim)) # ,list_ssim) 49 | print("AverPSNR:", np.mean(list_psnr)) # ,list_ssim) 50 | 51 | 52 | if __name__ == '__main__': 53 | main() -------------------------------------------------------------------------------- /Derain/evaluate_PSNR_SSIM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from glob import glob 4 | from natsort import natsorted 5 | from skimage import io 6 | import cv2 7 | from skimage.metrics import structural_similarity 8 | from tqdm import tqdm 9 | import concurrent.futures 10 | 11 | 12 | def image_align(deblurred, gt): 13 | # this function is based on kohler evaluation code 14 | z = deblurred 15 | c = np.ones_like(z) 16 | x = gt 17 | 18 | zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching 19 | 20 | warp_mode = cv2.MOTION_HOMOGRAPHY 21 | warp_matrix = np.eye(3, 3, dtype=np.float32) 22 | 23 | # Specify the number of iterations. 
24 | number_of_iterations = 100 25 | 26 | termination_eps = 0 27 | 28 | criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 29 | number_of_iterations, termination_eps) 30 | 31 | # Run the ECC algorithm. The results are stored in warp_matrix. 32 | (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY), 33 | warp_matrix, warp_mode, criteria, inputMask=None) 34 | 35 | target_shape = x.shape 36 | shift = warp_matrix 37 | 38 | zr = cv2.warpPerspective( 39 | zs, 40 | warp_matrix, 41 | (target_shape[1], target_shape[0]), 42 | flags=cv2.INTER_CUBIC + cv2.WARP_INVERSE_MAP, 43 | borderMode=cv2.BORDER_REFLECT) 44 | 45 | cr = cv2.warpPerspective( 46 | np.ones_like(zs, dtype='float32'), 47 | warp_matrix, 48 | (target_shape[1], target_shape[0]), 49 | flags=cv2.INTER_NEAREST + cv2.WARP_INVERSE_MAP, 50 | borderMode=cv2.BORDER_CONSTANT, 51 | borderValue=0) 52 | 53 | zr = zr * cr 54 | xr = x * cr 55 | 56 | return zr, xr, cr, shift 57 | 58 | 59 | def compute_psnr(image_true, image_test, image_mask, data_range=None): 60 | # this function is based on skimage.metrics.peak_signal_noise_ratio 61 | err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask) 62 | return 10 * np.log10((data_range ** 2) / err) 63 | 64 | 65 | def compute_ssim(tar_img, prd_img, cr1): 66 | ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True, 67 | use_sample_covariance=False, data_range=1.0, full=True) 68 | ssim_map = ssim_map * cr1 69 | r = int(3.5 * 1.5 + 0.5) # radius as in ndimage 70 | win_size = 2 * r + 1 71 | pad = (win_size - 1) // 2 72 | ssim = ssim_map[pad:-pad, pad:-pad, :] 73 | crop_cr1 = cr1[pad:-pad, pad:-pad, :] 74 | ssim = ssim.sum(axis=0).sum(axis=0) / crop_cr1.sum(axis=0).sum(axis=0) 75 | ssim = np.mean(ssim) 76 | return ssim 77 | 78 | 79 | def proc(filename): 80 | tar, prd = filename 81 | tar_img = io.imread(tar) 82 | prd_img = io.imread(prd) 83 | 84 | 
tar_img = tar_img.astype(np.float32) / 255.0 85 | prd_img = prd_img.astype(np.float32) / 255.0 86 | 87 | prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img) 88 | 89 | PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1) 90 | SSIM = compute_ssim(tar_img, prd_img, cr1) 91 | return (PSNR, SSIM) 92 | 93 | 94 | def te(): 95 | datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200'] 96 | 97 | for dataset in datasets: 98 | 99 | file_path = os.path.join('mashresults' , dataset) 100 | gt_path = os.path.join('Datasets','test', dataset, 'target') 101 | 102 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg'))) 103 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg'))) 104 | 105 | assert len(path_list) != 0, "Predicted files not found" 106 | assert len(gt_list) != 0, "Target files not found" 107 | index = 0 108 | psnr, ssim = [], [] 109 | 110 | 111 | img_files =[(i, j) for i,j in zip(gt_list,path_list)] 112 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 113 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 114 | psnr.append(PSNR_SSIM[0]) 115 | ssim.append(PSNR_SSIM[1]) 116 | 117 | 118 | 119 | #img_files = [(i, j) for i, j in zip(gt_list, path_list)] 120 | #for i in range(len(img_files)): 121 | # res = proc(img_files[i]) 122 | # psnr.append(res[0]) 123 | # ssim.append(res[1]) 124 | # with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 125 | # for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 126 | # index = index + 1 127 | # print(index) 128 | # psnr.append(PSNR_SSIM[0]) 129 | # ssim.append(PSNR_SSIM[1]) 130 | 131 | avg_psnr = sum(psnr) / len(psnr) 132 | avg_ssim = sum(ssim) / len(ssim) 133 | 134 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim)) 135 | 136 | if __name__=='__main__': 137 | te() 
# ---- /Derain/losses.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1): ``mean(sqrt((x - y)^2 + eps^2))``."""

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
        return loss


class EdgeLoss(nn.Module):
    """Charbonnier distance between the Laplacian residuals of two images.

    ``laplacian_kernel`` subtracts a blurred, 2x down/up-sampled copy from the
    input, leaving mostly edges and fine detail.
    """

    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        # 5x5 outer-product kernel, repeated once per channel (3) and applied
        # depthwise via groups in conv_gauss.
        self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        # The kernel is a plain attribute (not a registered buffer), so
        # module.to()/cuda() never moves it. Follow the input instead, which
        # also covers CPU inputs on a CUDA machine and half-precision inputs.
        kernel = self.kernel.to(device=img.device, dtype=img.dtype)
        # conv2d weight layout is (out_ch, in_ch / groups, kH, kW).
        n_channels, _, kh, kw = kernel.shape
        # F.pad takes (left, right, top, bottom): width pads use kW, height pads kH.
        img = F.pad(img, (kw//2, kw//2, kh//2, kh//2), mode='replicate')
        return F.conv2d(img, kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        filtered = self.conv_gauss(current)  # filter
        down = filtered[:,:,::2,::2]  # downsample
        new_filter = torch.zeros_like(filtered)
        new_filter[:,:,::2,::2] = down*4  # upsample
        filtered = self.conv_gauss(new_filter)  # filter
        diff = current - filtered
        return diff

    def forward(self, x, y):
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss


class PSNRLoss(nn.Module):
    """``loss_weight * 10/ln(10) * mean_batch(log(MSE + 1e-8))``.

    For a unit data range this equals ``-loss_weight * PSNR``, up to the 1e-8
    stabiliser. With ``toY=True`` both inputs are first mapped to luma with
    the BT.601 coefficients below.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        self.scale = 10 / np.log(10)
        self.toY = toY
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        self.first = True

    def forward(self, pred, target):
        assert len(pred.size()) == 4
        if self.toY:
            # Follow the input device on every call. The old one-shot `first`
            # flag pinned the coefficients to the first device ever seen.
            self.coef = self.coef.to(pred.device)
            self.first = False

            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.

            pred, target = pred / 255., target / 255.
        assert len(pred.size()) == 4

        return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()

# ---- /Derain/test.py ----
import numpy as np
import os
import argparse
from tqdm import tqdm

import torch.nn as nn
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils

from data_RGB import get_test_data
from MHNet import MHNet
from skimage import img_as_ubyte
from pdb import set_trace as stx

# Inference script: restore every image of each Derain test set with MHNet
# and save the clamped first output as <result_dir>/<dataset>/<name>.png.
parser = argparse.ArgumentParser(description='Image Deraining using MHNet')

parser.add_argument('--input_dir', default='./Datasets/test/', type=str, help='Directory of validation images')
parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results')
parser.add_argument('--weights', default='./pre-trained/model_best.pth', type=str, help='Path to weights')
parser.add_argument('--gpus', default='2', type=str, help='CUDA_VISIBLE_DEVICES')

args = parser.parse_args()

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

model_restoration = MHNet()

utils.load_checkpoint(model_restoration,args.weights)
print("===>Testing using weights: ",args.weights)
model_restoration.cuda()
model_restoration = nn.DataParallel(model_restoration)
model_restoration.eval()

datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200']
# datasets = ['Rain100L']

for dataset in datasets:
    rgb_dir_test = os.path.join(args.input_dir, dataset, 'input')
    test_dataset = get_test_data(rgb_dir_test, img_options={})
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)

    result_dir = os.path.join(args.result_dir, dataset)
    utils.mkdir(result_dir)

    with torch.no_grad():
        for ii, data_test in enumerate(tqdm(test_loader), 0):
            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

            input_ = data_test[0].cuda()
            # NOTE(review): assumes the test dataset yields (input, filename).
            filenames = data_test[1]

            restored = model_restoration(input_)
            restored = torch.clamp(restored[0],0,1)

            restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()

            for batch in range(len(restored)):
                restored_img = img_as_ubyte(restored[batch])
                utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img)

# ---- /Derain/train.py ----
#!/usr/bin/env python
# coding=utf-8

import os
from config import Config

opt = Config('trmash.yml')

# GPU selection must happen before torch initialises CUDA.
gpus = ','.join([str(i) for i in opt.GPU])
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpus

import torch

torch.backends.cudnn.benchmark = True

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb

import random
import time
import numpy as np
from pathlib import Path

import utils
from data_RGB import get_training_data, get_validation_data
from MHNet import MHNet
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from pdb import set_trace as stx


dir_checkpoint = Path('./mhnetmash/')


def train():
    """Train MHNet for deraining.

    Uses PSNR loss, warmup followed by cosine LR decay, and periodic
    validation, all configured by the module-level ``opt`` (trmash.yml).
    The best model by validation PSNR is saved to ``./mhnetmash/model_best.pth``.

    NOTE(review): this file was recovered from a dump that lost indentation.
    The "[epoch ... PSNR ...]" print reads psnr_val_rgb, which only exists in
    the validation branch, so that print and the per-epoch checkpoint save
    that follows it are placed inside that branch. Confirm.
    """
    ######### Set Seeds ###########
    random.seed(1234)
    np.random.seed(1234)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

    start_epoch = 1
    mode = opt.MODEL.MODE
    session = opt.MODEL.SESSION

    result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
    model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)

    utils.mkdir(result_dir)
    utils.mkdir(model_dir)

    train_dir = opt.TRAINING.TRAIN_DIR
    val_dir = opt.TRAINING.VAL_DIR

    ######### Model ###########
    model_restoration = MHNet()
    print("Total number of param is ", sum(x.numel() for x in model_restoration.parameters()))
    model_restoration.cuda()

    device_ids = [i for i in range(torch.cuda.device_count())]
    if torch.cuda.device_count() > 1:
        print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")

    new_lr = opt.OPTIM.LR_INITIAL

    optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8)

    ######### Scheduler ###########
    warmup_epochs = 3
    scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
    scheduler.step()

    ######### Resume ###########
    if opt.TRAINING.RESUME:
        path_chk_rest = './mhnetmash/model_best.pth'
        if os.path.isfile(path_chk_rest):
            utils.load_checkpoint(model_restoration,path_chk_rest)
            start_epoch = utils.load_start_epoch(path_chk_rest) + 1
            utils.load_optim(optimizer, path_chk_rest)

            # Replay the schedule so the LR matches the resumed epoch.
            for i in range(1, start_epoch):
                scheduler.step()
            new_lr = optimizer.param_groups[0]['lr']
            print('------------------------------------------------------------------------------')
            print("==> Resuming Training with learning rate:", new_lr)
            print('------------------------------------------------------------------------------')
        else:
            # trmash.yml ships with RESUME: True, so a first run used to die
            # here before training even started.
            print("==> No checkpoint at", path_chk_rest, "- starting from scratch")

    if len(device_ids)>1:
        model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)
        print("duoka")  # pinyin for "multi-GPU"

    ######### Loss ###########
    criterion_mse = losses.PSNRLoss()

    ######### DataLoaders ###########
    train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
    train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)

    # NOTE(review): validation crops with TRAIN_PS. trmash.yml also defines
    # VAL_PS, which nothing here reads. Confirm which crop is intended.
    val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
    val_loader = DataLoader(dataset=val_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)

    print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1))
    print('===> Loading datasets')

    # best_psnr is not stored in checkpoints. After a resume, the first
    # validation therefore always overwrites model_best.pth.
    best_psnr = 0
    best_epoch = 0
    global_step = 0

    for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
        epoch_start_time = time.time()
        epoch_loss = 0
        model_restoration.train()
        for i, data in enumerate(tqdm(train_loader), 0):

            # zero_grad
            for param in model_restoration.parameters():
                param.grad = None

            target = data[0].cuda()
            input_ = data[1].cuda()

            restored = model_restoration(input_)

            loss = criterion_mse(restored[0],target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            global_step = global_step+1

        #### Evaluation ####
        if epoch % opt.TRAINING.VAL_AFTER_EVERY == 0:
            model_restoration.eval()
            psnr_val_rgb = []
            for ii, data_val in enumerate((val_loader), 0):
                target = data_val[0].cuda()
                input_ = data_val[1].cuda()

                with torch.no_grad():
                    restored = model_restoration(input_)
                restore = restored[0]

                # One PSNR per image. The old loop computed it twice and
                # summed a debug total that was only printed.
                for res, tar in zip(restore, target):
                    psnr_val_rgb.append(utils.torchPSNR(res, tar))

            psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()

            if psnr_val_rgb > best_psnr:
                best_psnr = psnr_val_rgb
                best_epoch = epoch
                Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
                torch.save({'epoch': epoch,
                            'state_dict': model_restoration.state_dict(),
                            'optimizer': optimizer.state_dict()
                            }, str(dir_checkpoint / "model_best.pth"))

            print("[epoch %d PSNR: %.4f best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            # Named after the epoch it contains. It was epoch + 1, which
            # disagreed with the 'epoch' field stored inside the file.
            torch.save({'epoch': epoch,
                        'state_dict': model_restoration.state_dict(),
                        'optimizer': optimizer.state_dict()
                        }, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))

        scheduler.step()

        print("------------------------------------------------------------------")
        print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time,
                                                                                  epoch_loss, optimizer.param_groups[0]['lr']))
        print("------------------------------------------------------------------")


if __name__=='__main__':
    train()

# ---- /Derain/trmash.yml ----
-------------------------------------------------------------------------------- 1 | ############### 2 | ## 3 | #### 4 | 5 | GPU: [2,3] 6 | 7 | VERBOSE: True 8 | 9 | MODEL: 10 | MODE: 'Deraining' 11 | SESSION: 'MHNet' 12 | # Optimization arguments. 13 | OPTIM: 14 | BATCH_SIZE: 16 15 | NUM_EPOCHS: 10000 16 | # NEPOCH_DECAY: [10] 17 | LR_INITIAL: 2e-4 18 | LR_MIN: 1e-6 19 | # BETA1: 0.9 20 | 21 | TRAINING: 22 | 23 | VAL_AFTER_EVERY: 10 24 | RESUME: True 25 | TRAIN_PS: 256 26 | VAL_PS: 128 27 | TRAIN_DIR: './Datasets/train' # path to training data 28 | VAL_DIR: './Datasets/test/Rain100L' # path to validation data 29 | SAVE_DIR: './checkpoints' # path to save models and images 30 | # SAVE_IMAGES: False 31 | -------------------------------------------------------------------------------- /Derain/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dir_utils import * 2 | from .image_utils import * 3 | from .model_utils import * 4 | from .dataset_utils import * 5 | from .logger import (MessageLogger, get_env_info, get_root_logger, 6 | init_tb_logger, init_wandb_logger) 7 | 8 | 9 | __all__ = [ 10 | # file_client.py 11 | 'FileClient', 12 | # img_util.py 13 | 'img2tensor', 14 | 'tensor2img', 15 | 'imfrombytes', 16 | 'imwrite', 17 | 'crop_border', 18 | # logger.py 19 | 'MessageLogger', 20 | 'init_tb_logger', 21 | 'init_wandb_logger', 22 | 'get_root_logger', 23 | 'get_env_info', 24 | # misc.py 25 | 'set_random_seed', 26 | 'get_time_str', 27 | 'mkdir_and_rename', 28 | 'make_exp_dirs', 29 | 'scandir', 30 | 'scandir_SIDD', 31 | 'check_resume', 32 | 'sizeof_fmt', 33 | 'padding', 34 | 'create_lmdb_for_reds', 35 | 'create_lmdb_for_gopro', 36 | 'create_lmdb_for_rain13k', 37 | ] 38 | 39 | 40 | -------------------------------------------------------------------------------- /Derain/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /Derain/utils/__pycache__/arch_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/arch_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Derain/utils/__pycache__/dataset_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/dataset_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Derain/utils/__pycache__/dir_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/dir_utils.cpython-38.pyc -------------------------------------------------------------------------------- /Derain/utils/__pycache__/dist_util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/dist_util.cpython-38.pyc -------------------------------------------------------------------------------- /Derain/utils/__pycache__/image_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/image_utils.cpython-38.pyc -------------------------------------------------------------------------------- 
# ---- /Derain/utils/__pycache__/logger.cpython-38.pyc ----
# https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/logger.cpython-38.pyc
# ---- /Derain/utils/__pycache__/model_utils.cpython-38.pyc ----
# https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/Derain/utils/__pycache__/model_utils.cpython-38.pyc
# ---- /Derain/utils/arch_utils.py ----
import math
import torch
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
from torch.nn.modules.batchnorm import _BatchNorm

from utils import get_root_logger

# BasicSR's DCNv2Pack (modulated deformable conv) and its optional
# `basicsr.models.ops.dcn` import were commented out in this file, so this
# module does not provide them.


@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Args:
        module_list (list[nn.Module] | nn.Module): Modules to be initialized.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for initialization function.
    """
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)


def make_layer(basic_block, num_basic_block, **kwarg):
    """Make layers by stacking the same blocks.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    layers = []
    for _ in range(num_basic_block):
        layers.append(basic_block(**kwarg))
    return nn.Sequential(*layers)


class ResidualBlockNoBN(nn.Module):
    """Residual block without BN.

    It has a style of:
        ---Conv-ReLU-Conv-+-
         |________________|

    Args:
        num_feat (int): Channel number of intermediate features.
            Default: 64.
        res_scale (float): Residual scale. Default: 1.
        pytorch_init (bool): If set to True, use pytorch default init,
            otherwise, use default_init_weights. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super(ResidualBlockNoBN, self).__init__()
        self.res_scale = res_scale
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if not pytorch_init:
            default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        identity = x
        out = self.conv2(self.relu(self.conv1(x)))
        return identity + out * self.res_scale


class Upsample(nn.Sequential):
    """Upsample module.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        m = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                m.append(nn.PixelShuffle(2))
        elif scale == 3:
            m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            m.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*m)


def flow_warp(x,
              flow,
              interp_mode='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2), normal value.
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(
        torch.arange(0, h).type_as(x),
        torch.arange(0, w).type_as(x))
    grid = torch.stack((grid_x, grid_y), 2).float()  # W(x), H(y), 2
    grid.requires_grad = False

    vgrid = grid + flow
    # scale grid to [-1,1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(
        x,
        vgrid_scaled,
        mode=interp_mode,
        padding_mode=padding_mode,
        align_corners=align_corners)

    # TODO, what if align_corners=False
    return output


def resize_flow(flow,
                size_type,
                sizes,
                interp_mode='bilinear',
                align_corners=False):
    """Resize a flow according to ratio or shape.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Default: False.

    Returns:
        Tensor: Resized flow.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(
            f'Size type should be ratio or shape, but got type {size_type}.')

    input_flow = flow.clone()
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow,
        size=(output_h, output_w),
        mode=interp_mode,
        align_corners=align_corners)
    return resized_flow


# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
    """ Pixel unshuffle.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw).
        scale (int): Downsample ratio.

    Returns:
        Tensor: the pixel unshuffled feature.
    """
    b, c, hh, hw = x.size()
    out_channel = c * (scale**2)
    assert hh % scale == 0 and hw % scale == 0
    h = hh // scale
    w = hw // scale
    x_view = x.view(b, c, h, scale, w, scale)
    return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)


class LayerNormFunction(torch.autograd.Function):
    """LayerNorm over the channel dimension of an NCHW tensor, with a manual backward.

    Each spatial position is normalised across dim 1, then a per-channel
    affine ``weight * y + bias`` is applied.
    """

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        # `saved_variables` is a deprecated alias that recent PyTorch releases
        # no longer provide (AttributeError). `saved_tensors` is the supported
        # accessor for tensors stored with save_for_backward.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
            dim=0), None


class LayerNorm2d(nn.Module):
    """Module wrapper around LayerNormFunction with learnable per-channel weight/bias."""

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)


# handle multiple input
class MySequential(nn.Sequential):
    """nn.Sequential that unpacks tuple outputs as positional inputs to the next module."""

    def forward(self, *inputs):
        for module in self._modules.values():
            if type(inputs) == tuple:
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs


import time


def measure_inference_speed(model, data, max_iter=200, log_interval=50):
    """Benchmark ``model(*data)`` on CUDA and report images per second.

    The first ``num_warmup`` iterations are excluded from timing. Progress is
    printed every ``log_interval`` iterations and once after ``max_iter``.

    Returns:
        float: the last computed fps. It is 0 when no timed iteration ran,
        e.g. ``max_iter <= num_warmup``; that case used to divide by zero.
    """
    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0
    fps = 0

    # benchmark over max_iter iterations and take the average
    for i in range(max_iter):

        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(*data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(
                    f'Done image [{i + 1:<3}/ {max_iter}], '
                    f'fps: {fps:.1f} img / s, '
                    f'times per image: {1000 / fps:.1f} ms / img',
                    flush=True)

        if (i + 1) == max_iter:
            if pure_inf_time > 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(
                    f'Overall fps: {fps:.1f} img / s, '
                    f'times per image: {1000 / fps:.1f} ms / img',
                    flush=True)
            break
    return fps
# ------------------------------------------------------------------------------
# /Derain/utils/dataset_utils.py
# ------------------------------------------------------------------------------
import torch


class MixUp_AUG:
    """MixUp applied identically to a clean batch and its degraded counterpart."""

    def __init__(self):
        # Mixing coefficients are drawn from Beta(0.6, 0.6).
        self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))

    def aug(self, rgb_gt, rgb_noisy):
        """Blend every sample with a randomly permuted partner from the same batch.

        The same permutation and coefficients are used for ``rgb_gt`` and
        ``rgb_noisy`` so the (clean, degraded) pairs stay aligned.

        Returns:
            tuple: the mixed ``(rgb_gt, rgb_noisy)``.
        """
        bs = rgb_gt.size(0)
        indices = torch.randperm(bs)
        rgb_gt2 = rgb_gt[indices]
        rgb_noisy2 = rgb_noisy[indices]

        # One coefficient per sample, shaped (bs, 1, 1, 1) to broadcast over
        # NCHW batches; placed on the batch's own device rather than a
        # hard-coded ``.cuda()`` so CPU / non-default-GPU batches also work.
        lam = self.dist.rsample((bs, 1)).view(-1, 1, 1, 1).to(rgb_gt.device)

        rgb_gt = lam * rgb_gt + (1 - lam) * rgb_gt2
        rgb_noisy = lam * rgb_noisy + (1 - lam) * rgb_noisy2

        return rgb_gt, rgb_noisy


# ------------------------------------------------------------------------------
# /Derain/utils/dir_utils.py
# ------------------------------------------------------------------------------
import os
from natsort import natsorted
from glob import glob


def mkdirs(paths):
    """Create each directory in ``paths`` (a list/tuple of paths, or a single path)."""
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """Create ``path`` and any missing parents; no-op if the directory exists."""
    # exist_ok avoids the exists()/makedirs() race when several processes
    # (e.g. distributed ranks) create the same directory.
    os.makedirs(path, exist_ok=True)


def get_last_path(path, session):
    """Return the naturally-sorted last entry in ``path`` whose name ends with ``session``.

    Raises:
        IndexError: if no entry matches.
    """
    x = natsorted(glob(os.path.join(path, '*%s' % session)))[-1]
    return x


# ------------------------------------------------------------------------------
# /Derain/utils/dist_util.py
# ------------------------------------------------------------------------------
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize torch.distributed for the given launcher.

    Args:
        launcher (str): 'pytorch' or 'slurm'.
        backend (str): Process-group backend. Default: 'nccl'.
        **kwargs: Forwarded to the launcher-specific initializer.

    Raises:
        ValueError: For any other launcher.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    """Bind to GPU ``RANK % num_gpus`` (``RANK`` read from the environment) and init the group."""
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    # The first host of the node list becomes MASTER_ADDR.
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def get_dist_info():
    """Return ``(rank, world_size)``; ``(0, 1)`` when torch.distributed is unavailable or uninitialized."""
    if dist.is_available():
        initialized = dist.is_initialized()
    else:
        initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func):
    """Decorator: run ``func`` only on rank 0; every other rank gets ``None``."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)
        return None

    # Returning the wrapper (not None) keeps decorated functions callable.
    return wrapper


# ------------------------------------------------------------------------------
# /Derain/utils/image_utils.py
# ------------------------------------------------------------------------------
import torch
import numpy as np
import cv2


def torchPSNR(tar_img, prd_img):
    """PSNR in dB between two tensors, both clamped to [0, 1] (peak value 1).

    Identical inputs give ``inf`` (zero RMSE).
    """
    imdff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1)
    rmse = (imdff ** 2).mean().sqrt()
    ps = 20 * torch.log10(1 / rmse)
    return ps


def save_img(filepath, img):
    """Write an RGB image array to ``filepath``; OpenCV expects BGR, hence the conversion."""
    cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


def numpyPSNR(tar_img, prd_img):
    """PSNR in dB between two arrays, assuming a peak value of 255."""
    imdff = np.float32(prd_img) - np.float32(tar_img)
    rmse = np.sqrt(np.mean(imdff ** 2))
    ps = 20 * np.log10(255 / rmse)
    return ps


# ------------------------------------------------------------------------------
# /Derain/utils/logger.py
# ------------------------------------------------------------------------------
import datetime
import logging
import time

from .dist_util import get_dist_info, master_only


class MessageLogger():
    """Message logger for printing.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' for logger interval and
                'use_tb_logger' (bool) to enable the tensorboard logger.
            train (dict): Contains 'total_iter' (int) for total iters.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """

    def __init__(self, opt, start_iter=1, tb_logger=None):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.use_tb_logger = opt['logger']['use_tb_logger']
        self.tb_logger = tb_logger
        self.start_time = time.time()
        self.logger = get_root_logger()

    @master_only
    def __call__(self, log_vars):
        """Format and log one message (rank 0 only).

        Args:
            log_vars (dict): Consumed in place. Required keys: 'epoch',
                'iter', 'total_iter', 'lrs' (list of learning rates).
                Optional: 'time' and 'data_time' (then both are required).
                Every remaining key is a scalar logged with 4 significant
                digits; when tensorboard is enabled its name must start with
                'l_' (loss) or 'm_' (metric).

        Raises:
            ValueError: a tensorboard-logged key has neither prefix.
        """
        # epoch, iter, learning rates
        epoch = log_vars.pop('epoch')
        current_iter = log_vars.pop('iter')
        total_iter = log_vars.pop('total_iter')
        lrs = log_vars.pop('lrs')

        message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
                   f'iter:{current_iter:8,d}, lr:(')
        message += ''.join(f'{v:.3e},' for v in lrs)
        message += ')] '

        # time and estimated time
        if 'time' in log_vars:
            iter_time = log_vars.pop('time')
            data_time = log_vars.pop('data_time')

            total_time = time.time() - self.start_time
            time_sec_avg = total_time / (current_iter - self.start_iter + 1)
            eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
            eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
            message += f'[eta: {eta_str}, '
            message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '

        # other items, especially losses
        for k, v in log_vars.items():
            message += f'{k}: {v:.4e} '
            # tensorboard logger
            if self.use_tb_logger and 'debug' not in self.exp_name:
                normed_step = int(10000 * (current_iter / total_iter))

                if k.startswith('l_'):
                    self.tb_logger.add_scalar(f'losses/{k}', v, normed_step)
                elif k.startswith('m_'):
                    self.tb_logger.add_scalar(f'metrics/{k}', v, normed_step)
                else:
                    # An explicit raise, unlike ``assert``, survives ``python -O``.
                    raise ValueError(
                        f"Unsupported tensorboard key {k!r}: expected prefix 'l_' or 'm_'.")
        self.logger.info(message)


@master_only
def init_tb_logger(log_dir):
    """Create a tensorboard ``SummaryWriter`` writing to ``log_dir`` (rank 0 only)."""
    from torch.utils.tensorboard import SummaryWriter
    tb_logger = SummaryWriter(log_dir=log_dir)
    return tb_logger


@master_only
def init_wandb_logger(opt):
    """We now only use wandb to sync tensorboard log."""
    import wandb
    logger = logging.getLogger('basicsr')

    project = opt['logger']['wandb']['project']
    resume_id = opt['logger']['wandb'].get('resume_id')
    if resume_id:
        wandb_id = resume_id
        resume = 'allow'
        logger.warning(f'Resume wandb logger with id={wandb_id}.')
    else:
        wandb_id = wandb.util.generate_id()
        resume = 'never'

    wandb.init(
        id=wandb_id,
        resume=resume,
        name=opt['name'],
        config=opt,
        project=project,
        sync_tensorboard=True)

    logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')


def get_root_logger(logger_name='basicsr',
                    log_level=logging.INFO,
                    log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int): The root logger level. Note that only the process of
            rank 0 is affected, while other processes will set the level to
            "Error" and be silent most of the time.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger.hasHandlers():
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    logging.basicConfig(format=format_str, level=log_level)
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    elif log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    return logger


def get_env_info():
    """Get environment information.

    Currently, only log the software version.

    Returns:
        str: a banner followed by the BasicSR, PyTorch and TorchVision versions
        (BasicSR is reported as 'unknown' if it is not importable).
    """
    import torch
    import torchvision

    try:
        from basicsr.version import __version__
    except ImportError:
        # basicsr is not part of this repository.
        __version__ = 'unknown'
    msg = r"""
____ _ _____ ____
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
/ __ |/ __ `// ___// // ___/\__ \ / /_/ /
/ /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
/_____/ \__,_//____//_/ \___//____//_/ |_|
______ __ __ __ __
/ ____/____ ____ ____/ / / / __ __ _____ / /__ / /
/ / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
"""
    msg += ('\nVersion Information: '
            f'\n\tBasicSR: {__version__}'
            f'\n\tPyTorch: {torch.__version__}'
            f'\n\tTorchVision: {torchvision.__version__}')
    return msg


# ------------------------------------------------------------------------------
# /Derain/utils/model_utils.py
# ------------------------------------------------------------------------------
import torch
import os
from collections import OrderedDict


def freeze(model):
    """Disable gradients for every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad = False


def unfreeze(model):
    """Enable gradients for every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad = True


def is_frozen(model):
    """Return True if at least one parameter of ``model`` has gradients disabled."""
    x = [p.requires_grad for p in model.parameters()]
    return not all(x)


def save_checkpoint(model_dir, state, session):
    """Save ``state`` (must contain 'epoch') as ``model_epoch_<epoch>_<session>.pth`` in ``model_dir``."""
    epoch = state['epoch']
    model_out_path = os.path.join(model_dir, "model_epoch_{}_{}.pth".format(epoch, session))
    torch.save(state, model_out_path)


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from each key.

    Keys without the prefix (e.g. from an unwrapped model) are kept unchanged.
    """
    prefix = 'module.'
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in state_dict.items()
    )


def load_checkpoint(model, weights):
    """Load ``checkpoint['state_dict']`` from ``weights`` into ``model``.

    If the keys do not match as saved, retry with any ``module.`` prefix
    removed (checkpoints saved from a DataParallel-wrapped model).

    Raises:
        KeyError: the checkpoint has no 'state_dict'.
        RuntimeError: the keys do not match even after stripping the prefix.
    """
    checkpoint = torch.load(weights)
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        model.load_state_dict(_strip_module_prefix(state_dict))


def load_checkpoint_multigpu(model, weights):
    """Load a checkpoint saved from a ``module.``-prefixed (multi-GPU) model into ``model``."""
    checkpoint = torch.load(weights)
    model.load_state_dict(_strip_module_prefix(checkpoint["state_dict"]))


def load_start_epoch(weights):
    """Return the 'epoch' stored in the checkpoint at ``weights``."""
    checkpoint = torch.load(weights)
    epoch = checkpoint["epoch"]
    return epoch


def load_optim(optimizer, weights):
    """Restore ``optimizer`` from ``checkpoint['optimizer']``."""
    checkpoint = torch.load(weights)
    optimizer.load_state_dict(checkpoint['optimizer'])
    # for p in optimizer.param_groups: lr = p['lr']
    # return lr


# ------------------------------------------------------------------------------
# /LICENSE.md
# ------------------------------------------------------------------------------
# ## ACADEMIC PUBLIC LICENSE
#
# ### Permissions
# :heavy_check_mark: Non-Commercial use
# :heavy_check_mark: Modification
# :heavy_check_mark: Distribution
# :heavy_check_mark: Private use
#
# ### Limitations
# :x: Commercial Use
# :x: Liability
# :x: Warranty
#
# ### Conditions
# :information_source: License and copyright notice
# :information_source: Same License
#
# MHNet is free for use in noncommercial settings: at academic institutions for teaching and research use, and at non-profit research organizations.
# You can use MHNet in your research, academic work, non-commercial work, projects and personal work. We only ask you to credit us appropriately.
#
# You have the right to use the software, to distribute copies, to receive source code, to change the software and distribute your modifications or the modified software.
# If you distribute verbatim or modified copies of this software, they must be distributed under this license.
# This license guarantees that you're safe when using MHNet in your work, for teaching or research.
# This license guarantees that MHNet will remain available free of charge for nonprofit use.
# You can modify MHNet to your purposes, and you can also share your modifications.
#
# If you would like to use MHNet in commercial settings, contact us so we can discuss options. Send an email to two_bits@163.com
#
#
#
#

# ------------------------------------------------------------------------------
# /MHNet.py
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.arch_utils import LayerNorm2d
from einops import rearrange


class UpDSample(nn.Module):
    """2x spatial upsampling that halves the channel count.

    A bias-free 1x1 conv maps C -> 2C channels and ``PixelShuffle(2)`` turns
    that into C/2 channels at twice the height and width.
    """

    def __init__(self, in_channels):
        super(UpDSample, self).__init__()
        self.up = nn.Sequential(
            nn.Conv2d(in_channels, in_channels * 2, 1, bias=False),
            nn.PixelShuffle(2)
        )

    def forward(self, x):
        return self.up(x)


class SimpleGate(nn.Module):
    """Split dim 1 into two halves and return their element-wise product."""

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


class Attention(nn.Module):
    """Multi-head channel attention with top-k sparsified score maps.

    Per head, a C x C map (channels vs. channels, over all spatial positions)
    is computed from L2-normalized q/k and a learnable temperature. The map is
    sparsified four times, keeping the top 1/2, 2/3, 3/4 and 4/5 of each row;
    the four attended outputs are summed with learnable weights ``attn1..4``.
    """

    # (numerator, denominator) of the fraction of scores kept by each branch.
    _TOPK_FRACTIONS = ((1, 2), (2, 3), (3, 4), (4, 5))

    def __init__(self, dim, num_heads, bias):
        super(Attention, self).__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.bias = bias

        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.attn_drop = nn.Dropout(0.)

        self.attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
        self.attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        _, _, C, _ = q.shape  # channels per head

        attn = (q @ k.transpose(-2, -1)) * self.temperature

        branch_weights = (self.attn1, self.attn2, self.attn3, self.attn4)
        out = None
        for (num, den), weight in zip(self._TOPK_FRACTIONS, branch_weights):
            # Keep the int(C * num / den) largest scores of every row and set
            # the rest to -inf so softmax gives them zero probability.
            mask = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
            index = torch.topk(attn, k=int(C * num / den), dim=-1, largest=True)[1]
            mask.scatter_(-1, index, 1.)
            sparse_attn = torch.where(mask > 0, attn, torch.full_like(attn, float('-inf')))
            branch = (sparse_attn.softmax(dim=-1) @ v) * weight
            out = branch if out is None else out + branch

        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

        out = self.project_out(out)
        return out


class BotBlock(nn.Module):
    """Activation-free residual base block.

    Branch 1: LayerNorm -> 1x1 conv -> depthwise 3x3 conv -> SimpleGate ->
    simplified channel attention -> 1x1 conv, added back with scale ``beta``.
    Branch 2: LayerNorm -> 1x1 conv -> SimpleGate -> 1x1 conv, added back with
    scale ``gamma``. Both scales start at zero, so the block starts as identity.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1,
                               groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
                               groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
                               groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


class Bottneck(nn.Module):
    """U-shaped encoder-decoder (the lower hierarchy of MHNet).

    Returns the per-level encoder outputs (``encs``, highest resolution first)
    and decoder outputs (``decs``, lowest resolution first) rather than an
    image. Inputs are zero-padded on the right/bottom to a multiple of
    ``padder_size`` (2 ** number of encoder levels).
    """

    def __init__(self, img_channel=3, width=32, middle_blk_num=1,
                 enc_blk_nums=(1, 1, 1, 28), dec_blk_nums=(1, 1, 1, 1)):
        super().__init__()

        self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1,
                               groups=1,
                               bias=True)
        # Not applied in forward (only features are returned); kept so that
        # existing checkpoints keep loading.
        self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1,
                                groups=1,
                                bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[BotBlock(chan) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2 * chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = \
            nn.Sequential(
                *[Attention(chan, 8, False) for _ in range(middle_blk_num)]
            )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[BotBlock(chan) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp):
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        decs = []
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)
            decs.append(x)

        return encs, decs

    def check_image_size(self, x):
        """Zero-pad ``x`` on the right/bottom so H and W are multiples of ``padder_size``."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


class CABG(BotBlock):
    """Alias of ``BotBlock`` (same parameters and forward pass)."""


class AFFM(nn.Module):
    """Adaptive fusion of three same-shaped feature maps.

    A channel descriptor of ``f + f_e + f_d`` (global average pool -> 1x1
    conv -> SimpleGate) produces one attention vector per input via
    ``fcs[0..2]``. Each vector is softmax-normalized over channels
    independently, then the inputs are re-weighted and summed. ``forward``
    indexes ``fcs[0..2]``, so ``height`` must be at least 3.
    """

    def __init__(self, in_channels, height=3, reduction=8, bias=False):
        super(AFFM, self).__init__()

        self.height = height
        d = max(int(in_channels / reduction), 4)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # SimpleGate halves the channels, hence the d // 2 inputs below.
        self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), SimpleGate())

        self.fcs = nn.ModuleList([])
        for i in range(self.height):
            self.fcs.append(nn.Conv2d(d // 2, in_channels, kernel_size=1, stride=1, bias=bias))

        self.softmax = nn.Softmax(dim=1)

    def forward(self, f, f_e, f_d):
        feats_U = f + f_e + f_d
        feats_S = self.avg_pool(feats_U)
        feats_Z = self.conv_du(feats_S)

        a = self.softmax(self.fcs[0](feats_Z))
        a_e = self.softmax(self.fcs[1](feats_Z))
        a_d = self.softmax(self.fcs[2](feats_Z))

        return a * f + f_e * a_e + a_d * f_d


##########################################################################
class FRSNet(nn.Module):
    """Full-resolution subnetwork (the higher hierarchy of MHNet).

    Four residual ``BotBlock`` groups at full resolution. After each group the
    features are fused (``AFFM``) with an encoder and a decoder feature map
    from ``Bottneck``, upsampled to full resolution where needed.
    """

    def __init__(self, width, bias, num):
        super(FRSNet, self).__init__()

        self.CABG1 = nn.Sequential(
            *[BotBlock(width) for _ in range(num)]
        )
        self.CABG2 = nn.Sequential(
            *[BotBlock(width) for _ in range(num)]
        )
        self.CABG3 = nn.Sequential(
            *[BotBlock(width) for _ in range(num)]
        )
        self.CABG4 = nn.Sequential(
            *[BotBlock(width) for _ in range(num)]
        )

        self.up_enc1 = UpDSample(width * 2)
        self.up_dec1 = UpDSample(width * 2)

        self.up_enc2 = nn.Sequential(UpDSample(width * 4), UpDSample(width * 2))
        self.up_dec2 = nn.Sequential(UpDSample(width * 4), UpDSample(width * 2))

        self.up_enc3 = nn.Sequential(UpDSample(width * 8), UpDSample(width * 4), UpDSample(width * 2))
        self.up_dec3 = nn.Sequential(UpDSample(width * 8), UpDSample(width * 4), UpDSample(width * 2))

        self.norm1 = LayerNorm2d(width)
        self.norm2 = LayerNorm2d(width)
        self.norm3 = LayerNorm2d(width)
        self.norm4 = LayerNorm2d(width)

        self.conv_enc1 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
        self.conv_enc2 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
        self.conv_enc3 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
        self.conv_enc4 = nn.Conv2d(width, width, kernel_size=1, bias=bias)

        self.conv_dnc1 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
        self.conv_dnc2 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
        self.conv_dnc3 = nn.Conv2d(width, width, kernel_size=1, bias=bias)
        self.conv_dnc4 = nn.Conv2d(width, width, kernel_size=1, bias=bias)

        self.skff1 = AFFM(width)
        self.skff2 = AFFM(width)
        self.skff3 = AFFM(width)
        self.skff4 = AFFM(width)

    def forward(self, x, encoder_outs, decoder_outs):
        # NOTE(review): every stage projects decoder features with conv_dnc1;
        # conv_dnc2..4 are created but never used (compare conv_enc1..4).
        # Looks unintended, but switching would change the outputs of weights
        # trained with this code -- confirm before changing.
        x = self.norm1(x)
        x = self.CABG1(x) + x
        x = self.skff1(x, self.conv_enc1(encoder_outs[0]), self.conv_dnc1(decoder_outs[3]))

        x = self.norm2(x)
        x = self.CABG2(x) + x
        x = self.skff2(x, self.conv_enc2(self.up_enc1(encoder_outs[1])), self.conv_dnc1(self.up_dec1(decoder_outs[2])))

        x = self.norm3(x)
        x = self.CABG3(x) + x
        x = self.skff3(x, self.conv_enc3(self.up_enc2(encoder_outs[2])), self.conv_dnc1(self.up_dec2(decoder_outs[1])))

        x = self.norm4(x)
        x = self.CABG4(x) + x
        x = self.skff4(x, self.conv_enc4(self.up_enc3(encoder_outs[3])), self.conv_dnc1(self.up_dec3(decoder_outs[0])))

        return x


class MHNet(nn.Module):
    """Mixed Hierarchy Network: encoder-decoder features feed a full-resolution branch.

    The output is a predicted residual added to the input image, so ``out_c``
    must equal ``in_c``. Any input size is accepted: the image is padded to a
    multiple of 16 internally and the result is cropped back.
    """

    def __init__(self, in_c=3, out_c=3, width=64, num_cab=8,
                 bias=False):
        super().__init__()
        # Not applied in forward; kept so existing checkpoints keep loading.
        self.intro = nn.Conv2d(in_channels=in_c, out_channels=width, kernel_size=3, padding=1, stride=1,
                               groups=1,
                               bias=bias)
        self.intro2 = nn.Conv2d(in_channels=in_c, out_channels=width, kernel_size=3, padding=1, stride=1,
                                groups=1,
                                bias=bias)
        self.stage1 = Bottneck(in_c, width)
        self.stage2 = FRSNet(width, bias=bias, num=num_cab)
        self.concat12 = nn.Conv2d(width * 2, width, kernel_size=1, stride=1, padding=0, bias=bias)
        self.ending = nn.Conv2d(in_channels=width, out_channels=out_c, kernel_size=3, padding=1, stride=1,
                                groups=1,
                                bias=bias)

    def forward(self, x3_img):
        _, _, h, w = x3_img.shape
        # Bottneck pads its input to a multiple of its down-sampling factor;
        # pad here too so the full-resolution features match the decoder
        # outputs, then crop the padding off the result.
        x = self.stage1.check_image_size(x3_img)

        x1_en, x1_dn = self.stage1(x)

        x2 = self.intro2(x)
        t = torch.cat([x2, x1_dn[3]], 1)

        x2_cat = self.concat12(t)
        x2_cat = self.stage2(x2_cat, x1_en, x1_dn)

        stage2_img = self.ending(x2_cat)

        return (stage2_img + x)[:, :, :h, :w]


class AvgPool2d(nn.Module):
    """Local average pooling used in place of ``nn.AdaptiveAvgPool2d(1)``.

    When ``kernel_size`` is unset it is derived lazily on the first forward
    from ``base_size`` scaled by input size / ``train_size``. Window averages
    are computed from cumulative sums. If the window covers the whole input,
    this falls back to a global average.
    """

    def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
        super().__init__()
        self.kernel_size = kernel_size
        self.base_size = base_size
        self.auto_pad = auto_pad

        # only used for fast implementation
        self.fast_imp = fast_imp
        self.rs = [5, 4, 3, 2, 1]
        self.max_r1 = self.rs[0]
        self.max_r2 = self.rs[0]
        self.train_size = train_size

    def extra_repr(self) -> str:
        return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format(
            self.kernel_size, self.base_size, self.kernel_size, self.fast_imp
        )

    def forward(self, x):
        if self.kernel_size is None and self.base_size:
            train_size = self.train_size
            if isinstance(self.base_size, int):
                self.base_size = (self.base_size, self.base_size)
            self.kernel_size = list(self.base_size)
            self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
            self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]

            # only used for fast implementation
            self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
            self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])

        if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
            return F.adaptive_avg_pool2d(x, 1)

        if self.fast_imp:  # Non-equivalent implementation but faster
            h, w = x.shape[2:]
            if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
                out = F.adaptive_avg_pool2d(x, 1)
            else:
                r1 = [r for r in self.rs if h % r == 0][0]
                r2 = [r for r in self.rs if w % r == 0][0]
                # reduction_constraint
                r1 = min(self.max_r1, r1)
                r2 = min(self.max_r2, r2)
                s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
                n, c, h, w = s.shape
                k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
                out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
                out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
        else:
            n, c, h, w = x.shape
            s = x.cumsum(dim=-1).cumsum_(dim=-2)
            s = torch.nn.functional.pad(s, (1, 0, 1, 0))  # pad 0 for convenience
            k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
            s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
            out = s4 + s1 - s2 - s3
            out = out / (k1 * k2)

        if self.auto_pad:
            n, c, h, w = x.shape
            _h, _w = out.shape[2:]
            # Replicate-pad back to the input's spatial size.
            pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
            out = torch.nn.functional.pad(out, pad2d, mode='replicate')

        return out


def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
    """Recursively replace every ``nn.AdaptiveAvgPool2d`` in ``model`` with ``AvgPool2d``."""
    for n, m in model.named_children():
        if len(list(m.children())) > 0:
            ## compound module, go inside it
            replace_layers(m, base_size, train_size, fast_imp, **kwargs)

        if isinstance(m, nn.AdaptiveAvgPool2d):
            pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
            assert m.output_size == 1
            setattr(model, n, pool)


class Local_Base():
    """Mixin that swaps global pooling for ``AvgPool2d`` and initializes its kernel sizes."""

    def convert(self, *args, train_size, **kwargs):
        replace_layers(self, *args, train_size=train_size, **kwargs)
        # One dummy forward at ``train_size`` lets each AvgPool2d derive its kernel size.
        imgs = torch.rand(train_size)
        with torch.no_grad():
            self.forward(imgs)


class MHNetLocal(Local_Base, MHNet):
    """``MHNet`` with its global average pools replaced by local ``AvgPool2d``."""

    def __init__(self, *args, train_size=(1, 3, 256, 256), base_size=None, fast_imp=False, **kwargs):
        Local_Base.__init__(self)
        MHNet.__init__(self, *args, **kwargs)
        N, C, H, W = train_size
        if base_size is None:
            base_size = (int(H * 1.5), int(W * 1.5))

        self.eval()
        with torch.no_grad():
            self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)


if __name__ == '__main__':
    x = torch.randn([2, 3, 256, 256])
    model = MHNet()
    print("Total number of param is ", sum(i.numel() for i in model.parameters()))
    t = model(x)
    print(t.shape)

    from thop import profile
    x3 = torch.randn((1, 3, 256, 256))
    flops, params = profile(model, inputs=(x3, ))
    print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')
    from ptflops import get_model_complexity_info
    FLOPS = 0
    inp_shape = (3, 256, 256)
    macs, params = get_model_complexity_info(model, inp_shape, verbose=False, print_per_layer_stat=True)
    # print(params)
    macs = float(macs[:-4]) + FLOPS / 10 ** 9

    print('mac', macs, params)

# ------------------------------------------------------------------------------
# /README.md
# ------------------------------------------------------------------------------
#
#
# # Mixed Hierarchy Network for Image Restoration
#
# [![paper](https://img.shields.io/badge/arXiv-Paper-brightgreen)](http://arxiv.org/abs/2302.09554)
#
#
#
#
#
#
#
#
14 | 15 | 16 | > **Abstract:** Image restoration is a long-standing low-level vision problem, e.g., deblurring and deraining. In the process of image restoration, it is necessary to consider not only the spatial details and contextual information of restoration to ensure quality but also the system complexity. Although many methods have been able to guarantee the quality of image restoration, the system complexity of the state-of-the-art (SOTA) methods is increasing as well. Motivated by this, we present a mixed hierarchy network that can balance these competing goals. Our main proposal is a mixed hierarchy architecture that progressively recovers contextual information and spatial details from degraded images while using simple blocks to reduce system complexity. 17 | Specifically, our model first learns the contextual information at the lower hierarchy using encoder-decoder architectures, and then operates at full resolution at the higher hierarchy to retain spatial detail information. 18 | Incorporating information exchange between different hierarchies is a crucial aspect of our mixed hierarchy architecture. To achieve this, we design an adaptive feature fusion mechanism that selectively aggregates spatially precise details and rich contextual information. In addition, we propose a selective multi-head attention mechanism with linear time complexity as the middle block of the encoder-decoder to adaptively retain the most crucial attention scores. 19 | Moreover, we use the nonlinear activation free block as our base block to reduce the system complexity. 20 | The resulting tightly interlinked hierarchy architecture, named MHNet, delivers strong performance gains on several image restoration tasks, including image deraining and deblurring. 21 | ## Network Architecture 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 |

Overall Framework of MHNet

(a) Encoder-decoder subnetwork. (b) Selective multi-head attention mechanism (SMAM) (c) The architecture of nonlinear activation free block (NAFBlock). (d) Simplified Channel Attention (SCA).

(a) The architecture of nonlinear activation free block groups (NAFG). Each NAFG further contains multiple nonlinear activation free blocks (NAFBlocks). (b) Adaptive feature fusion mechanism (AFFM) between an encoder-decoder subnetwork and FRSNet.

43 | 44 | 45 | ## Our code will be released after the paper is published 46 | 47 | ## Installation 48 | The model is built in PyTorch 1.1.0 and tested in an Ubuntu 16.04 environment (Python 3.7, CUDA 9.0, cuDNN 7.5). 49 | 50 | To install, follow these instructions: 51 | ``` 52 | conda create -n pytorch1 python=3.7 53 | conda activate pytorch1 54 | conda install pytorch=1.1 torchvision=0.3 cudatoolkit=9.0 -c pytorch 55 | pip install matplotlib scikit-image opencv-python yacs joblib natsort h5py tqdm 56 | ``` 57 | 58 | Install warmup scheduler 59 | 60 | ``` 61 | cd pytorch-gradual-warmup-lr; python setup.py install; cd .. 62 | ``` 63 | 64 | 65 | 66 | 67 | ## Training and Evaluation 68 | 69 | Training and testing code for deblurring and deraining is provided in the respective directories. 70 | 71 | 89 | 90 | 95 | 96 | ## Citations 97 | If our code helps your research or work, please consider citing our paper. 98 | The following is a BibTeX reference: 99 | 100 | ``` 101 | @article{gao2023mixed, 102 | title={Mixed Hierarchy Network for Image Restoration}, 103 | author={Gao, Hu and Dang, Depeng}, 104 | journal={arXiv preprint arXiv:2302.09554}, 105 | year={2023} 106 | } 107 | 108 | 109 | ``` 110 | 111 | 112 | 113 | 114 | ## Contact 115 | Should you have any questions, please contact two_bits@163.com 116 | 117 | -------------------------------------------------------------------------------- /fig/blur.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/blur.jpg -------------------------------------------------------------------------------- /fig/dau.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/dau.png -------------------------------------------------------------------------------- /fig/deblur.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/deblur.png -------------------------------------------------------------------------------- /fig/derain.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/derain.png -------------------------------------------------------------------------------- /fig/fir_h.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/fir_h.jpg -------------------------------------------------------------------------------- /fig/muti-net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/muti-net.png -------------------------------------------------------------------------------- /fig/network.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/network.jpg -------------------------------------------------------------------------------- /fig/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/network.png -------------------------------------------------------------------------------- /fig/rain.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/rain.jpg -------------------------------------------------------------------------------- /fig/sec_h.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/sec_h.jpg -------------------------------------------------------------------------------- /fig/three_con.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/fig/three_con.png -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/build/lib/warmup_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from warmup_scheduler.scheduler import GradualWarmupScheduler 3 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/build/lib/warmup_scheduler/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR 3 | from torch.optim.sgd import SGD 4 | 5 | from warmup_scheduler import GradualWarmupScheduler 6 | 7 | 8 | if __name__ == '__main__': 9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))] 10 | optim = SGD(model, 0.1) 11 | 12 | # scheduler_warmup is chained with schduler_steplr 13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1) 14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr) 15 | 16 | # this zero gradient update is needed to avoid a warning message, issue #8. 
17 | optim.zero_grad() 18 | optim.step() 19 | 20 | for epoch in range(1, 20): 21 | scheduler_warmup.step(epoch) 22 | print(epoch, optim.param_groups[0]['lr']) 23 | 24 | optim.step() # backward pass (update network) 25 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/build/lib/warmup_scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """ Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 12 | total_epoch: target learning rate is reached at total_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 17 | self.multiplier = multiplier 18 | if self.multiplier < 1.: 19 | raise ValueError('multiplier should be greater thant or equal to 1.') 20 | self.total_epoch = total_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super(GradualWarmupScheduler, self).__init__(optimizer) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.total_epoch: 27 | if self.after_scheduler: 28 | if not self.finished: 29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 30 | self.finished = True 31 | return self.after_scheduler.get_lr() 32 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 33 | 34 | if self.multiplier == 1.0: 35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 36 | else: 37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 38 | 39 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 40 | if epoch is None: 41 | epoch = self.last_epoch + 1 42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 43 | if self.last_epoch <= self.total_epoch: 44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs] 45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 46 | param_group['lr'] = lr 47 | else: 48 | if epoch is None: 49 | self.after_scheduler.step(metrics, None) 50 | else: 51 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 52 | 53 | def step(self, epoch=None, metrics=None): 54 | if type(self.after_scheduler) != ReduceLROnPlateau: 55 | if self.finished and self.after_scheduler: 56 | if epoch is None: 57 | self.after_scheduler.step(None) 58 | else: 59 | self.after_scheduler.step(epoch - self.total_epoch) 60 | else: 61 | return super(GradualWarmupScheduler, self).step(epoch) 62 | else: 63 | self.step_ReduceLROnPlateau(metrics, epoch) 64 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/dist/warmup_scheduler-0.3-py3.8.egg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Tombs98/MHNet/9ce8e40a9f5f0f0a6bfd7aaa10ce28df66bf32eb/pytorch-gradual-warmup-lr/dist/warmup_scheduler-0.3-py3.8.egg -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import setuptools 6 | 7 | _VERSION = '0.3' 8 | 9 | REQUIRED_PACKAGES = [ 10 | ] 11 | 12 | DEPENDENCY_LINKS = [ 13 | ] 14 | 15 | setuptools.setup( 16 | name='warmup_scheduler', 17 | version=_VERSION, 18 | description='Gradually Warm-up LR Scheduler for Pytorch', 19 | install_requires=REQUIRED_PACKAGES, 20 | dependency_links=DEPENDENCY_LINKS, 21 | url='https://github.com/ildoonet/pytorch-gradual-warmup-lr', 22 | license='MIT License', 23 | package_dir={}, 24 | packages=setuptools.find_packages(exclude=['tests']), 25 | ) 26 | 
-------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: warmup-scheduler 3 | Version: 0.3 4 | Summary: Gradually Warm-up LR Scheduler for Pytorch 5 | Home-page: https://github.com/ildoonet/pytorch-gradual-warmup-lr 6 | License: MIT License 7 | Platform: UNKNOWN 8 | 9 | UNKNOWN 10 | 11 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | warmup_scheduler/__init__.py 3 | warmup_scheduler/run.py 4 | warmup_scheduler/scheduler.py 5 | warmup_scheduler.egg-info/PKG-INFO 6 | warmup_scheduler.egg-info/SOURCES.txt 7 | warmup_scheduler.egg-info/dependency_links.txt 8 | warmup_scheduler.egg-info/top_level.txt -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | warmup_scheduler 2 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from warmup_scheduler.scheduler import GradualWarmupScheduler 3 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler/run.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR 3 | from torch.optim.sgd import SGD 4 | 5 | from warmup_scheduler import GradualWarmupScheduler 6 | 7 | 8 | if __name__ == '__main__': 9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))] 10 | optim = SGD(model, 0.1) 11 | 12 | # scheduler_warmup is chained with schduler_steplr 13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1) 14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr) 15 | 16 | # this zero gradient update is needed to avoid a warning message, issue #8. 17 | optim.zero_grad() 18 | optim.step() 19 | 20 | for epoch in range(1, 20): 21 | scheduler_warmup.step(epoch) 22 | print(epoch, optim.param_groups[0]['lr']) 23 | 24 | optim.step() # backward pass (update network) 25 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler/scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """ Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Args: 10 | optimizer (Optimizer): Wrapped optimizer. 11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 12 | total_epoch: target learning rate is reached at total_epoch, gradually 13 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 14 | """ 15 | 16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 17 | self.multiplier = multiplier 18 | if self.multiplier < 1.: 19 | raise ValueError('multiplier should be greater thant or equal to 1.') 20 | self.total_epoch = total_epoch 21 | self.after_scheduler = after_scheduler 22 | self.finished = False 23 | super(GradualWarmupScheduler, self).__init__(optimizer) 24 | 25 | def get_lr(self): 26 | if self.last_epoch > self.total_epoch: 27 | if self.after_scheduler: 28 | if not self.finished: 29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 30 | self.finished = True 31 | return self.after_scheduler.get_lr() 32 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 33 | 34 | if self.multiplier == 1.0: 35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 36 | else: 37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 38 | 39 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 40 | if epoch is None: 41 | epoch = self.last_epoch + 1 42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 43 | if self.last_epoch <= self.total_epoch: 44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs] 45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 46 | param_group['lr'] = lr 47 | else: 48 | if epoch is None: 49 | self.after_scheduler.step(metrics, None) 50 | else: 51 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 52 | 53 | def step(self, epoch=None, metrics=None): 54 | if type(self.after_scheduler) != ReduceLROnPlateau: 55 | if self.finished and self.after_scheduler: 56 | if epoch is None: 57 | self.after_scheduler.step(None) 58 | else: 59 | self.after_scheduler.step(epoch - self.total_epoch) 60 | else: 61 | return super(GradualWarmupScheduler, self).step(epoch) 62 | else: 63 | self.step_ReduceLROnPlateau(metrics, epoch) 64 | --------------------------------------------------------------------------------