├── .idea ├── .gitignore ├── MSDT.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── README.md ├── __pycache__ ├── data_RGB.cpython-36.pyc ├── dataset_RGB.cpython-36.pyc ├── doconv_pytorch.cpython-36.pyc ├── get_parameter_number.cpython-36.pyc ├── layers.cpython-36.pyc ├── losses.cpython-36.pyc ├── mlp.cpython-36.pyc └── model.cpython-36.pyc ├── data_RGB.py ├── dataset_RGB.py ├── doconv_pytorch.py ├── evaluations ├── Evaluation_DID-Data_DDN-Data │ ├── psnr.m │ ├── ssim.m │ └── statistic.m └── Evalution_Rain200L_Rain200H_SPA-Data │ └── evaluate_PSNR_SSIM.m ├── get_parameter_number.py ├── layers.py ├── losses.py ├── model.py ├── pytorch-gradual-warmup-lr ├── build │ └── lib │ │ └── warmup_scheduler │ │ ├── __init__.py │ │ ├── run.py │ │ └── scheduler.py ├── dist │ └── warmup_scheduler-0.3-py3.8.egg ├── setup.py ├── warmup_scheduler.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt └── warmup_scheduler │ ├── __init__.py │ ├── run.py │ └── scheduler.py ├── test.py ├── train.py ├── train.sh └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-38.pyc ├── dataset_utils.cpython-36.pyc ├── dataset_utils.cpython-38.pyc ├── dir_utils.cpython-36.pyc ├── dir_utils.cpython-38.pyc ├── image_utils.cpython-36.pyc ├── image_utils.cpython-38.pyc ├── model_utils.cpython-36.pyc └── model_utils.cpython-38.pyc ├── dataset_utils.py ├── dir_utils.py ├── image_utils.py └── model_utils.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /.idea/MSDT.iml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 21 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Rethinking Multi-Scale Representations in Deep Deraining Transformer 4 | 5 |
6 | 7 | 8 | 9 | 10 | ## 🛠️ Training and Testing 11 | 1. Please put the datasets in the folder `Datasets/`. 12 | 2. Follow the instructions below to begin training our model. 13 | ``` 14 | bash train.sh 15 | ``` 16 | After running the script, you can find the generated experiment logs in the folder `checkpoints`. 17 | 18 | 3. Follow the instructions below to begin testing our model. 19 | ``` 20 | python test.py 21 | ``` 22 | After running the script, you can find the output visual results in the folder `results/`. 23 | 24 | 25 | ## 🤖 Pre-trained Models 26 | | Models | MSDT | 27 | |:-----: |:-----: | 28 | | Rain200L | [Google Drive](https://drive.google.com/file/d/1qk8pUq7oM4Z4v2X-qmWJpE2LmUuweL4_/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1jikJhCuv51bvkl9vF2AkKw?pwd=8ajd) (8ajd) | 29 | | Rain200H | [Google Drive](https://drive.google.com/file/d/1y8gjAvnt0kkf1dSEyauVFu2weLi53LmF/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1jr01T_hzl8K_h2VksrmlFQ?pwd=97lm) (97lm) | 30 | | DID-Data | [Google Drive](https://drive.google.com/file/d/1RDvMFZn57UFrkeeojRHXwR7YbvXSGR5i/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1PJrRTDsG4vL4XwhNd8kfHg?pwd=5g4p) (5g4p) | 31 | | DDN-Data | [Google Drive](https://drive.google.com/file/d/1p7FVQuZSw4n0nXEvLrsJPtYxzlMyOCK0/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1Y3YRkNO40m6bII-R3-Hi4g?pwd=b0b5) (b0b5) | 32 | | SPA-Data | [Google Drive](https://drive.google.com/file/d/1hEpYFrFG0qhKassfYAZmXwUnNUYmGMLs/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1CO7wlaZyhu2egjfdaavFeQ?pwd=x0i5) (x0i5) | 33 | 34 | 35 | ## 🚨 Performance Evaluation 36 | See the folder "evaluations". 37 | 38 | 1) *for Rain200L/H and SPA-Data datasets*: 39 | PSNR and SSIM results are computed using this [Matlab Code](https://github.com/sauchm/MSDT/tree/main/evaluations/Evalution_Rain200L_Rain200H_SPA-Data). 
40 | 41 | 2) *for DID-Data and DDN-Data datasets*: 42 | PSNR and SSIM results are computed using this [Matlab Code](https://github.com/sauchm/MSDT/tree/main/evaluations/Evaluation_DID-Data_DDN-Data). 43 | 44 | 45 | 46 | ## 🚀 Visual Deraining Results 47 | 48 | | Methods | MSDT | 49 | |:-----: |:-----: | 50 | | Rain200L | [Baidu Netdisk](https://pan.baidu.com/s/1us3smvwhAe3azJPnunWs8w?pwd=1xkc) (1xkc) | 51 | | Rain200H | [Baidu Netdisk](https://pan.baidu.com/s/1S__NNB0jV2ING2ngR0PjiA?pwd=yr3n) (yr3n) | 52 | | DID-Data | [Baidu Netdisk](https://pan.baidu.com/s/1Rif4QC1AuDF4ccHteg_A4A?pwd=242e) (242e) | 53 | | DDN-Data | [Baidu Netdisk](https://pan.baidu.com/s/1JFHyrTMSdsFotOJ6pKokow?pwd=2pwk) (2pwk) | 54 | | SPA-Data | [Baidu Netdisk](https://pan.baidu.com/s/14fSFf_T7AOD44ktso56Rxw?pwd=cag0) (cag0) | 55 | 56 | 57 | ## 👍 Acknowledgement 58 | Thanks to the authors of [DeepRFT](https://github.com/INVOKERer/DeepRFT) and [DRSformer](https://github.com/cschenxiang/DRSformer) for their awesome work. 59 | 60 | ## 📘 Citation 61 | Please consider citing our work as follows if you find it helpful. 
62 | ``` 63 | @inproceedings{chen2024rethinking, 64 | title={Rethinking Multi-Scale Representations in Deep Deraining Transformer}, 65 | author={Chen, Hongming and Chen, Xiang and Lu, Jiyang and Li, Yufeng}, 66 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 67 | volume={38}, 68 | number={2}, 69 | pages={1046--1053}, 70 | year={2024} 71 | } 72 | ``` 73 | 74 | -------------------------------------------------------------------------------- /__pycache__/data_RGB.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/__pycache__/data_RGB.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/dataset_RGB.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/__pycache__/dataset_RGB.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/doconv_pytorch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/__pycache__/doconv_pytorch.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/get_parameter_number.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/__pycache__/get_parameter_number.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/layers.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/__pycache__/layers.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/losses.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/__pycache__/losses.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/mlp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/__pycache__/mlp.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /data_RGB.py: -------------------------------------------------------------------------------- 1 | from dataset_RGB import * 2 | 3 | 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options) 7 | 8 | def get_validation_data(rgb_dir, img_options): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, img_options) 11 | 12 | def get_test_data(rgb_dir, img_options): 13 | assert os.path.exists(rgb_dir) 14 | return DataLoaderTest(rgb_dir, img_options) 15 | -------------------------------------------------------------------------------- /dataset_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torch 5 | from PIL import 
Image 6 | import torchvision.transforms.functional as TF 7 | import random 8 | 9 | 10 | def is_image_file(filename): 11 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 12 | 13 | class DataLoaderTrain(Dataset): 14 | def __init__(self, rgb_dir, img_options=None): 15 | super(DataLoaderTrain, self).__init__() 16 | 17 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input'))) 18 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target'))) 19 | 20 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] 21 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)] 22 | 23 | self.img_options = img_options 24 | self.sizex = len(self.tar_filenames) # get the size of target 25 | 26 | self.ps = self.img_options['patch_size'] 27 | 28 | def __len__(self): 29 | return self.sizex 30 | 31 | def __getitem__(self, index): 32 | index_ = index % self.sizex 33 | ps = self.ps 34 | 35 | inp_path = self.inp_filenames[index_] 36 | tar_path = self.tar_filenames[index_] 37 | 38 | inp_img = Image.open(inp_path) 39 | tar_img = Image.open(tar_path) 40 | 41 | w,h = tar_img.size 42 | padw = ps-w if w 1: 61 | self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul)) 62 | init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32) 63 | self.D.data = torch.from_numpy(init_zero) 64 | 65 | eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N)) 66 | D_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N))) 67 | if self.D_mul % (M * N) != 0: # the cases when D_mul > M * N 68 | zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)]) 69 | self.D_diag = Parameter(torch.cat([D_diag, zeros], dim=2), requires_grad=False) 70 | else: # the case when D_mul = M * N 71 | self.D_diag = Parameter(D_diag, requires_grad=False) 72 | 
################################################################################################## 73 | if simam: 74 | self.simam_block = simam_module() 75 | if bias: 76 | self.bias = Parameter(torch.Tensor(out_channels)) 77 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.W) 78 | bound = 1 / math.sqrt(fan_in) 79 | init.uniform_(self.bias, -bound, bound) 80 | else: 81 | self.register_parameter('bias', None) 82 | 83 | def extra_repr(self): 84 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 85 | ', stride={stride}') 86 | if self.padding != (0,) * len(self.padding): 87 | s += ', padding={padding}' 88 | if self.dilation != (1,) * len(self.dilation): 89 | s += ', dilation={dilation}' 90 | if self.groups != 1: 91 | s += ', groups={groups}' 92 | if self.bias is None: 93 | s += ', bias=False' 94 | if self.padding_mode != 'zeros': 95 | s += ', padding_mode={padding_mode}' 96 | return s.format(**self.__dict__) 97 | 98 | def __setstate__(self, state): 99 | super(DOConv2d, self).__setstate__(state) 100 | if not hasattr(self, 'padding_mode'): 101 | self.padding_mode = 'zeros' 102 | 103 | def _conv_forward(self, input, weight): 104 | if self.padding_mode != 'zeros': 105 | return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), 106 | weight, self.bias, self.stride, 107 | (0, 0), self.dilation, self.groups) 108 | return F.conv2d(input, weight, self.bias, self.stride, 109 | self.padding, self.dilation, self.groups) 110 | 111 | def forward(self, input): 112 | M = self.kernel_size[0] 113 | N = self.kernel_size[1] 114 | DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N) 115 | if M * N > 1: 116 | ######################### Compute DoW ################# 117 | # (input_channels, D_mul, M * N) 118 | D = self.D + self.D_diag 119 | W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul)) 120 | 121 | # einsum outputs (out_channels // groups, in_channels, M * N), 122 | # which is reshaped 
to 123 | # (out_channels, in_channels // groups, M, N) 124 | DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape) 125 | ####################################################### 126 | else: 127 | DoW = torch.reshape(self.W, DoW_shape) 128 | if self.simam: 129 | DoW_h1, DoW_h2 = torch.chunk(DoW, 2, dim=2) 130 | DoW = torch.cat([self.simam_block(DoW_h1), DoW_h2], dim=2) 131 | 132 | return self._conv_forward(input, DoW) 133 | class DOConv2d_eval(Module): 134 | """ 135 | DOConv2d can be used as an alternative for torch.nn.Conv2d. 136 | The interface is similar to that of Conv2d, with one exception: 137 | 1. D_mul: the depth multiplier for the over-parameterization. 138 | Note that the groups parameter switchs between DO-Conv (groups=1), 139 | DO-DConv (groups=in_channels), DO-GConv (otherwise). 140 | """ 141 | __constants__ = ['stride', 'padding', 'dilation', 'groups', 142 | 'padding_mode', 'output_padding', 'in_channels', 143 | 'out_channels', 'kernel_size', 'D_mul'] 144 | __annotations__ = {'bias': Optional[torch.Tensor]} 145 | 146 | def __init__(self, in_channels, out_channels, kernel_size=3, D_mul=None, stride=1, 147 | padding=1, dilation=1, groups=1, bias=False, padding_mode='zeros', simam=False): 148 | super(DOConv2d_eval, self).__init__() 149 | 150 | kernel_size = (kernel_size, kernel_size) 151 | stride = (stride, stride) 152 | padding = (padding, padding) 153 | dilation = (dilation, dilation) 154 | 155 | if in_channels % groups != 0: 156 | raise ValueError('in_channels must be divisible by groups') 157 | if out_channels % groups != 0: 158 | raise ValueError('out_channels must be divisible by groups') 159 | valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'} 160 | if padding_mode not in valid_padding_modes: 161 | raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format( 162 | valid_padding_modes, padding_mode)) 163 | self.in_channels = in_channels 164 | self.out_channels = out_channels 165 | self.kernel_size 
= kernel_size 166 | self.stride = stride 167 | self.padding = padding 168 | self.dilation = dilation 169 | self.groups = groups 170 | self.padding_mode = padding_mode 171 | self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2)) 172 | self.simam = simam 173 | #################################### Initailization of D & W ################################### 174 | M = self.kernel_size[0] 175 | N = self.kernel_size[1] 176 | self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, M, N)) 177 | init.kaiming_uniform_(self.W, a=math.sqrt(5)) 178 | 179 | self.register_parameter('bias', None) 180 | def extra_repr(self): 181 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 182 | ', stride={stride}') 183 | if self.padding != (0,) * len(self.padding): 184 | s += ', padding={padding}' 185 | if self.dilation != (1,) * len(self.dilation): 186 | s += ', dilation={dilation}' 187 | if self.groups != 1: 188 | s += ', groups={groups}' 189 | if self.bias is None: 190 | s += ', bias=False' 191 | if self.padding_mode != 'zeros': 192 | s += ', padding_mode={padding_mode}' 193 | return s.format(**self.__dict__) 194 | 195 | def __setstate__(self, state): 196 | super(DOConv2d, self).__setstate__(state) 197 | if not hasattr(self, 'padding_mode'): 198 | self.padding_mode = 'zeros' 199 | 200 | def _conv_forward(self, input, weight): 201 | if self.padding_mode != 'zeros': 202 | return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), 203 | weight, self.bias, self.stride, 204 | (0, 0), self.dilation, self.groups) 205 | return F.conv2d(input, weight, self.bias, self.stride, 206 | self.padding, self.dilation, self.groups) 207 | 208 | def forward(self, input): 209 | return self._conv_forward(input, self.W) 210 | 211 | class simam_module(torch.nn.Module): 212 | def __init__(self, e_lambda=1e-4): 213 | super(simam_module, self).__init__() 214 | self.activaton = nn.Sigmoid() 215 | self.e_lambda = e_lambda 216 | 217 | def 
forward(self, x): 218 | b, c, h, w = x.size() 219 | n = w * h - 1 220 | x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2) 221 | y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5 222 | return x * self.activaton(y) -------------------------------------------------------------------------------- /evaluations/Evaluation_DID-Data_DDN-Data/psnr.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/evaluations/Evaluation_DID-Data_DDN-Data/psnr.m -------------------------------------------------------------------------------- /evaluations/Evaluation_DID-Data_DDN-Data/ssim.m: -------------------------------------------------------------------------------- 1 | function [mssim, ssim_map] = ssim(img1, img2, K, window, L) 2 | 3 | % ======================================================================== 4 | % SSIM Index with automatic downsampling, Version 1.0 5 | % Copyright(c) 2009 Zhou Wang 6 | % All Rights Reserved. 7 | % 8 | % ---------------------------------------------------------------------- 9 | % Permission to use, copy, or modify this software and its documentation 10 | % for educational and research purposes only and without fee is hereby 11 | % granted, provided that this copyright notice and the original authors' 12 | % names appear on all copies and supporting documentation. This program 13 | % shall not be used, rewritten, or adapted as the basis of a commercial 14 | % software or hardware product without first obtaining permission of the 15 | % authors. The authors make no representations about the suitability of 16 | % this software for any purpose. It is provided "as is" without express 17 | % or implied warranty. 
18 | %---------------------------------------------------------------------- 19 | % 20 | % This is an implementation of the algorithm for calculating the 21 | % Structural SIMilarity (SSIM) index between two images 22 | % 23 | % Please refer to the following paper and the website with suggested usage 24 | % 25 | % Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 26 | % quality assessment: From error visibility to structural similarity," 27 | % IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, 28 | % Apr. 2004. 29 | % 30 | % http://www.ece.uwaterloo.ca/~z70wang/research/ssim/ 31 | % 32 | % Note: This program is different from ssim_index.m, where no automatic 33 | % downsampling is performed. (downsampling was done in the above paper 34 | % and was described as suggested usage in the above website.) 35 | % 36 | % Kindly report any suggestions or corrections to zhouwang@ieee.org 37 | % 38 | %---------------------------------------------------------------------- 39 | % 40 | %Input : (1) img1: the first image being compared 41 | % (2) img2: the second image being compared 42 | % (3) K: constants in the SSIM index formula (see the above 43 | % reference). default value: K = [0.01 0.03] 44 | % (4) window: local window for statistics (see the above 45 | % reference). default window is Gaussian given by 46 | % window = fspecial('gaussian', 11, 1.5); 47 | % (5) L: dynamic range of the images. default: L = 255 48 | % 49 | %Output: (1) mssim: the mean SSIM index value between 2 images. 50 | % If one of the images being compared is regarded as 51 | % perfect quality, then mssim can be considered as the 52 | % quality measure of the other image. 53 | % If img1 = img2, then mssim = 1. 54 | % (2) ssim_map: the SSIM index map of the test image. The map 55 | % has a smaller size than the input images. The actual size 56 | % depends on the window size and the downsampling factor. 
57 | % 58 | %Basic Usage: 59 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 60 | % 61 | % [mssim, ssim_map] = ssim(img1, img2); 62 | % 63 | %Advanced Usage: 64 | % User defined parameters. For example 65 | % 66 | % K = [0.05 0.05]; 67 | % window = ones(8); 68 | % L = 100; 69 | % [mssim, ssim_map] = ssim(img1, img2, K, window, L); 70 | % 71 | %Visualize the results: 72 | % 73 | % mssim %Gives the mssim value 74 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 75 | %======================================================================== 76 | 77 | 78 | if (nargin < 2 || nargin > 5) 79 | mssim = -Inf; 80 | ssim_map = -Inf; 81 | return; 82 | end 83 | 84 | if (size(img1) ~= size(img2)) 85 | mssim = -Inf; 86 | ssim_map = -Inf; 87 | return; 88 | end 89 | 90 | [M N] = size(img1); 91 | 92 | if (nargin == 2) 93 | if ((M < 11) || (N < 11)) 94 | mssim = -Inf; 95 | ssim_map = -Inf; 96 | return 97 | end 98 | window = fspecial('gaussian', 11, 1.5); % 99 | K(1) = 0.01; % default settings 100 | K(2) = 0.03; % 101 | L = 255; % 102 | end 103 | 104 | if (nargin == 3) 105 | if ((M < 11) || (N < 11)) 106 | mssim = -Inf; 107 | ssim_map = -Inf; 108 | return 109 | end 110 | window = fspecial('gaussian', 11, 1.5); 111 | L = 255; 112 | if (length(K) == 2) 113 | if (K(1) < 0 || K(2) < 0) 114 | mssim = -Inf; 115 | ssim_map = -Inf; 116 | return; 117 | end 118 | else 119 | mssim = -Inf; 120 | ssim_map = -Inf; 121 | return; 122 | end 123 | end 124 | 125 | if (nargin == 4) 126 | [H W] = size(window); 127 | if ((H*W) < 4 || (H > M) || (W > N)) 128 | mssim = -Inf; 129 | ssim_map = -Inf; 130 | return 131 | end 132 | L = 255; 133 | if (length(K) == 2) 134 | if (K(1) < 0 || K(2) < 0) 135 | mssim = -Inf; 136 | ssim_map = -Inf; 137 | return; 138 | end 139 | else 140 | mssim = -Inf; 141 | ssim_map = -Inf; 142 | return; 143 | end 144 | end 145 | 146 | if (nargin == 5) 147 | [H W] = size(window); 148 | if ((H*W) < 4 || (H > M) || (W > N)) 149 | mssim = -Inf; 150 | ssim_map = 
-Inf; 151 | return 152 | end 153 | if (length(K) == 2) 154 | if (K(1) < 0 || K(2) < 0) 155 | mssim = -Inf; 156 | ssim_map = -Inf; 157 | return; 158 | end 159 | else 160 | mssim = -Inf; 161 | ssim_map = -Inf; 162 | return; 163 | end 164 | end 165 | 166 | 167 | img1 = double(img1); 168 | img2 = double(img2); 169 | 170 | % automatic downsampling 171 | f = max(1,round(min(M,N)/256)); 172 | %downsampling by f 173 | %use a simple low-pass filter 174 | if(f>1) 175 | lpf = ones(f,f); 176 | lpf = lpf/sum(lpf(:)); 177 | img1 = imfilter(img1,lpf,'symmetric','same'); 178 | img2 = imfilter(img2,lpf,'symmetric','same'); 179 | 180 | img1 = img1(1:f:end,1:f:end); 181 | img2 = img2(1:f:end,1:f:end); 182 | end 183 | 184 | C1 = (K(1)*L)^2; 185 | C2 = (K(2)*L)^2; 186 | window = window/sum(sum(window)); 187 | 188 | mu1 = filter2(window, img1, 'valid'); 189 | mu2 = filter2(window, img2, 'valid'); 190 | mu1_sq = mu1.*mu1; 191 | mu2_sq = mu2.*mu2; 192 | mu1_mu2 = mu1.*mu2; 193 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 194 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 195 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 196 | 197 | if (C1 > 0 && C2 > 0) 198 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 199 | else 200 | numerator1 = 2*mu1_mu2 + C1; 201 | numerator2 = 2*sigma12 + C2; 202 | denominator1 = mu1_sq + mu2_sq + C1; 203 | denominator2 = sigma1_sq + sigma2_sq + C2; 204 | ssim_map = ones(size(mu1)); 205 | index = (denominator1.*denominator2 > 0); 206 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 207 | index = (denominator1 ~= 0) & (denominator2 == 0); 208 | ssim_map(index) = numerator1(index)./denominator1(index); 209 | end 210 | 211 | mssim = mean2(ssim_map); 212 | 213 | return -------------------------------------------------------------------------------- /evaluations/Evaluation_DID-Data_DDN-Data/statistic.m: 
-------------------------------------------------------------------------------- 1 | clear all; 2 | ts =0; 3 | tp =0; 4 | % for i=1:1200 % the number of testing samples DID-Data 5 | for i=1:1400 % the number of testing samples DDN-Data 6 | x_true=im2double(imread(strcat('./gt/DDN-Data/target/',sprintf('%d.jpg',i)))); % groundtruth 7 | % x_true=im2double(imread(strcat('./gt/DID-Data/target/',sprintf('%d.jpg',i)))); % groundtruth 8 | x_true = rgb2ycbcr(x_true); 9 | x_true = x_true(:,:,1); 10 | x = im2double(imread(strcat('./results/DDN-Data/',sprintf('%d.png',i)))); %reconstructed image 11 | % x = im2double(imread(strcat('./results/DID-Data/',sprintf('%d.png',i)))); %reconstructed image 12 | x = rgb2ycbcr(x); 13 | x = x(:,:,1); 14 | tp= tp+ psnr(x,x_true); 15 | ts= ts+ssim(x*255,x_true*255); 16 | end 17 | % fprintf('psnr=%6.4f, ssim=%6.4f\n',tp/1200,ts/1200) % the number of testing samples DID-Data 18 | fprintf('psnr=%6.4f, ssim=%6.4f\n',tp/1400,ts/1400) % the number of testing samples DDN-Data 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /evaluations/Evalution_Rain200L_Rain200H_SPA-Data/evaluate_PSNR_SSIM.m: -------------------------------------------------------------------------------- 1 | clc;close all;clear all;addpath(genpath('./')); 2 | 3 | datasets = {'Rain200L'}; 4 | % datasets = {'Rain200L', 'Rain200H', 'SPA-Data'}; 5 | num_set = length(datasets); 6 | 7 | psnr_alldatasets = 0; 8 | ssim_alldatasets = 0; 9 | for idx_set = 1:num_set 10 | file_path = strcat('./results/', datasets{idx_set}, '/'); 11 | gt_path = strcat('./Datasets/', datasets{idx_set}, '/'); 12 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))]; 13 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))]; 14 | img_num = length(path_list); 15 | 16 | total_psnr = 0; 17 | total_ssim = 0; 18 | if img_num > 0 19 | for j = 1:img_num 20 | image_name = path_list(j).name; 21 | gt_name = gt_list(j).name; 22 
| input = imread(strcat(file_path,image_name)); 23 | gt = imread(strcat(gt_path, gt_name)); 24 | ssim_val = compute_ssim(input, gt); 25 | psnr_val = compute_psnr(input, gt); 26 | total_ssim = total_ssim + ssim_val; 27 | total_psnr = total_psnr + psnr_val; 28 | end 29 | end 30 | qm_psnr = total_psnr / img_num; 31 | qm_ssim = total_ssim / img_num; 32 | 33 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim); 34 | 35 | psnr_alldatasets = psnr_alldatasets + qm_psnr; 36 | ssim_alldatasets = ssim_alldatasets + qm_ssim; 37 | 38 | end 39 | 40 | fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_set, ssim_alldatasets/num_set); 41 | 42 | function ssim_mean=compute_ssim(img1,img2) 43 | if size(img1, 3) == 3 44 | img1 = rgb2ycbcr(img1); 45 | img1 = img1(:, :, 1); 46 | end 47 | 48 | if size(img2, 3) == 3 49 | img2 = rgb2ycbcr(img2); 50 | img2 = img2(:, :, 1); 51 | end 52 | ssim_mean = SSIM_index(img1, img2); 53 | end 54 | 55 | function psnr=compute_psnr(img1,img2) 56 | if size(img1, 3) == 3 57 | img1 = rgb2ycbcr(img1); 58 | img1 = img1(:, :, 1); 59 | end 60 | 61 | if size(img2, 3) == 3 62 | img2 = rgb2ycbcr(img2); 63 | img2 = img2(:, :, 1); 64 | end 65 | 66 | imdff = double(img1) - double(img2); 67 | imdff = imdff(:); 68 | rmse = sqrt(mean(imdff.^2)); 69 | psnr = 20*log10(255/rmse); 70 | 71 | end 72 | 73 | function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L) 74 | 75 | %======================================================================== 76 | %SSIM Index, Version 1.0 77 | %Copyright(c) 2003 Zhou Wang 78 | %All Rights Reserved. 79 | % 80 | %The author is with Howard Hughes Medical Institute, and Laboratory 81 | %for Computational Vision at Center for Neural Science and Courant 82 | %Institute of Mathematical Sciences, New York University. 
83 | % 84 | %---------------------------------------------------------------------- 85 | %Permission to use, copy, or modify this software and its documentation 86 | %for educational and research purposes only and without fee is hereby 87 | %granted, provided that this copyright notice and the original authors' 88 | %names appear on all copies and supporting documentation. This program 89 | %shall not be used, rewritten, or adapted as the basis of a commercial 90 | %software or hardware product without first obtaining permission of the 91 | %authors. The authors make no representations about the suitability of 92 | %this software for any purpose. It is provided "as is" without express 93 | %or implied warranty. 94 | %---------------------------------------------------------------------- 95 | % 96 | %This is an implementation of the algorithm for calculating the 97 | %Structural SIMilarity (SSIM) index between two images. Please refer 98 | %to the following paper: 99 | % 100 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 101 | %quality assessment: From error visibility to structural similarity," 102 | %IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, Apr. 2004. 103 | % 104 | %Kindly report any suggestions or corrections to zhouwang@ieee.org 105 | % 106 | %---------------------------------------------------------------------- 107 | % 108 | %Input : (1) img1: the first image being compared 109 | % (2) img2: the second image being compared 110 | % (3) K: constants in the SSIM index formula (see the above 111 | % reference). default value: K = [0.01 0.03] 112 | % (4) window: local window for statistics (see the above 113 | % reference). default window is Gaussian given by 114 | % window = fspecial('gaussian', 11, 1.5); 115 | % (5) L: dynamic range of the images. default: L = 255 116 | % 117 | %Output: (1) mssim: the mean SSIM index value between 2 images. 
118 | % If one of the images being compared is regarded as 119 | % perfect quality, then mssim can be considered as the 120 | % quality measure of the other image. 121 | % If img1 = img2, then mssim = 1. 122 | % (2) ssim_map: the SSIM index map of the test image. The map 123 | % has a smaller size than the input images. The actual size: 124 | % size(img1) - size(window) + 1. 125 | % 126 | %Default Usage: 127 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 128 | % 129 | % [mssim ssim_map] = ssim_index(img1, img2); 130 | % 131 | %Advanced Usage: 132 | % User defined parameters. For example 133 | % 134 | % K = [0.05 0.05]; 135 | % window = ones(8); 136 | % L = 100; 137 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L); 138 | % 139 | %See the results: 140 | % 141 | % mssim %Gives the mssim value 142 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 143 | % 144 | %======================================================================== 145 | 146 | 147 | if (nargin < 2 || nargin > 5) 148 | ssim_index = -Inf; 149 | ssim_map = -Inf; 150 | return; 151 | end 152 | 153 | if (size(img1) ~= size(img2)) 154 | ssim_index = -Inf; 155 | ssim_map = -Inf; 156 | return; 157 | end 158 | 159 | [M N] = size(img1); 160 | 161 | if (nargin == 2) 162 | if ((M < 11) || (N < 11)) 163 | ssim_index = -Inf; 164 | ssim_map = -Inf; 165 | return 166 | end 167 | window = fspecial('gaussian', 11, 1.5); % 168 | K(1) = 0.01; % default settings 169 | K(2) = 0.03; % 170 | L = 255; % 171 | end 172 | 173 | if (nargin == 3) 174 | if ((M < 11) || (N < 11)) 175 | ssim_index = -Inf; 176 | ssim_map = -Inf; 177 | return 178 | end 179 | window = fspecial('gaussian', 11, 1.5); 180 | L = 255; 181 | if (length(K) == 2) 182 | if (K(1) < 0 || K(2) < 0) 183 | ssim_index = -Inf; 184 | ssim_map = -Inf; 185 | return; 186 | end 187 | else 188 | ssim_index = -Inf; 189 | ssim_map = -Inf; 190 | return; 191 | end 192 | end 193 | 194 | if (nargin == 4) 195 | [H W] = size(window); 196 | if 
((H*W) < 4 || (H > M) || (W > N)) 197 | ssim_index = -Inf; 198 | ssim_map = -Inf; 199 | return 200 | end 201 | L = 255; 202 | if (length(K) == 2) 203 | if (K(1) < 0 || K(2) < 0) 204 | ssim_index = -Inf; 205 | ssim_map = -Inf; 206 | return; 207 | end 208 | else 209 | ssim_index = -Inf; 210 | ssim_map = -Inf; 211 | return; 212 | end 213 | end 214 | 215 | if (nargin == 5) 216 | [H W] = size(window); 217 | if ((H*W) < 4 || (H > M) || (W > N)) 218 | ssim_index = -Inf; 219 | ssim_map = -Inf; 220 | return 221 | end 222 | if (length(K) == 2) 223 | if (K(1) < 0 || K(2) < 0) 224 | ssim_index = -Inf; 225 | ssim_map = -Inf; 226 | return; 227 | end 228 | else 229 | ssim_index = -Inf; 230 | ssim_map = -Inf; 231 | return; 232 | end 233 | end 234 | 235 | C1 = (K(1)*L)^2; 236 | C2 = (K(2)*L)^2; 237 | window = window/sum(sum(window)); 238 | img1 = double(img1); 239 | img2 = double(img2); 240 | 241 | mu1 = filter2(window, img1, 'valid'); 242 | mu2 = filter2(window, img2, 'valid'); 243 | mu1_sq = mu1.*mu1; 244 | mu2_sq = mu2.*mu2; 245 | mu1_mu2 = mu1.*mu2; 246 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 247 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 248 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 249 | 250 | if (C1 > 0 & C2 > 0) 251 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 252 | else 253 | numerator1 = 2*mu1_mu2 + C1; 254 | numerator2 = 2*sigma12 + C2; 255 | denominator1 = mu1_sq + mu2_sq + C1; 256 | denominator2 = sigma1_sq + sigma2_sq + C2; 257 | ssim_map = ones(size(mu1)); 258 | index = (denominator1.*denominator2 > 0); 259 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 260 | index = (denominator1 ~= 0) & (denominator2 == 0); 261 | ssim_map(index) = numerator1(index)./denominator1(index); 262 | end 263 | 264 | mssim = mean2(ssim_map); 265 | 266 | end 267 | 268 | 
-------------------------------------------------------------------------------- /get_parameter_number.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_parameter_number(net): 4 | total_num = sum(np.prod(p.size()) for p in net.parameters()) 5 | trainable_num = sum(np.prod(p.size()) for p in net.parameters() if p.requires_grad) 6 | print('Total: ', total_num) 7 | print('Trainable: ', trainable_num) 8 | 9 | 10 | if __name__=='__main__': 11 | from DeepRFT_MIMO import DeepRFT_flops as Net 12 | import torch 13 | from ptflops import get_model_complexity_info 14 | with torch.cuda.device(0): 15 | net = Net() 16 | macs, params = get_model_complexity_info(net, (3, 256, 256), as_strings=True, 17 | print_per_layer_stat=True, verbose=True) 18 | print('{:<30} {:<8}'.format('Computational complexity: ', macs)) 19 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 20 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | from doconv_pytorch import * 2 | 3 | 4 | class BasicConv(nn.Module): 5 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=False, norm=False, relu=True, transpose=False, 6 | channel_shuffle_g=0, norm_method=nn.BatchNorm2d, groups=1): 7 | super(BasicConv, self).__init__() 8 | self.channel_shuffle_g = channel_shuffle_g 9 | self.norm = norm 10 | if bias and norm: 11 | bias = False 12 | 13 | padding = kernel_size // 2 14 | layers = list() 15 | if transpose: 16 | padding = kernel_size // 2 - 1 17 | layers.append( 18 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups)) 19 | else: 20 | layers.append( 21 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups)) 22 | if norm: 23 | layers.append(norm_method(out_channel)) 24 | elif 
relu: 25 | layers.append(nn.ReLU(inplace=True)) 26 | 27 | self.main = nn.Sequential(*layers) 28 | 29 | def forward(self, x): 30 | return self.main(x) 31 | 32 | class BasicConv_do(nn.Module): 33 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, bias=False, norm=False, relu=True, transpose=False, 34 | relu_method=nn.ReLU, groups=1, norm_method=nn.BatchNorm2d): 35 | super(BasicConv_do, self).__init__() 36 | if bias and norm: 37 | bias = False 38 | 39 | padding = kernel_size // 2 40 | layers = list() 41 | if transpose: 42 | padding = kernel_size // 2 - 1 43 | layers.append( 44 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 45 | else: 46 | layers.append( 47 | DOConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups)) 48 | if norm: 49 | layers.append(norm_method(out_channel)) 50 | if relu: 51 | if relu_method == nn.ReLU: 52 | layers.append(nn.ReLU(inplace=True)) 53 | elif relu_method == nn.LeakyReLU: 54 | layers.append(nn.LeakyReLU(inplace=True)) 55 | else: 56 | layers.append(relu_method()) 57 | self.main = nn.Sequential(*layers) 58 | 59 | def forward(self, x): 60 | return self.main(x) 61 | 62 | class BasicConv_do_eval(nn.Module): 63 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=False, norm=False, relu=True, transpose=False, 64 | relu_method=nn.ReLU, groups=1, norm_method=nn.BatchNorm2d): 65 | super(BasicConv_do_eval, self).__init__() 66 | if bias and norm: 67 | bias = False 68 | 69 | padding = kernel_size // 2 70 | layers = list() 71 | if transpose: 72 | padding = kernel_size // 2 - 1 73 | layers.append( 74 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 75 | else: 76 | layers.append( 77 | DOConv2d_eval(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups)) 78 | if norm: 79 | layers.append(norm_method(out_channel)) 80 | if relu: 
81 | if relu_method == nn.ReLU: 82 | layers.append(nn.ReLU(inplace=True)) 83 | elif relu_method == nn.LeakyReLU: 84 | layers.append(nn.LeakyReLU(inplace=True)) 85 | else: 86 | layers.append(relu_method()) 87 | self.main = nn.Sequential(*layers) 88 | 89 | def forward(self, x): 90 | return self.main(x) 91 | 92 | class ResBlock(nn.Module): 93 | def __init__(self, out_channel): 94 | super(ResBlock, self).__init__() 95 | self.main = nn.Sequential( 96 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=True, norm=False), 97 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False, norm=False) 98 | ) 99 | 100 | def forward(self, x): 101 | return self.main(x) + x 102 | 103 | class ResBlock_do(nn.Module): 104 | def __init__(self, out_channel): 105 | super(ResBlock_do, self).__init__() 106 | self.main = nn.Sequential( 107 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=True), 108 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 109 | ) 110 | 111 | def forward(self, x): 112 | return self.main(x) + x 113 | 114 | class ResBlock_do_eval(nn.Module): 115 | def __init__(self, out_channel): 116 | super(ResBlock_do_eval, self).__init__() 117 | self.main = nn.Sequential( 118 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=True), 119 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 120 | ) 121 | 122 | def forward(self, x): 123 | return self.main(x) + x 124 | 125 | 126 | class ResBlock_do_FECB_bench(nn.Module): 127 | def __init__(self, out_channel, norm='backward'): 128 | super(ResBlock_do_FECB_bench, self).__init__() 129 | self.main = nn.Sequential( 130 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=True), 131 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 132 | ) 133 | self.main_fft = nn.Sequential( 134 | BasicConv_do(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=True), 
135 | BasicConv_do(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=False) 136 | ) 137 | self.dim = out_channel 138 | self.norm = norm 139 | def forward(self, x): 140 | _, _, H, W = x.shape 141 | dim = 1 142 | y = torch.fft.rfft2(x, norm=self.norm) 143 | y_imag = y.imag 144 | y_real = y.real 145 | y_f = torch.cat([y_real, y_imag], dim=dim) 146 | y = self.main_fft(y_f) 147 | y_real, y_imag = torch.chunk(y, 2, dim=dim) 148 | y = torch.complex(y_real, y_imag) 149 | y = torch.fft.irfft2(y, s=(H, W), norm=self.norm) 150 | return self.main(x) + x + y 151 | 152 | class ResBlock_FECB_bench(nn.Module): 153 | def __init__(self, n_feat, norm='backward'): # 'ortho' 154 | super(ResBlock_FECB_bench, self).__init__() 155 | self.main = nn.Sequential( 156 | BasicConv(n_feat, n_feat, kernel_size=3, stride=1, relu=True), 157 | BasicConv(n_feat, n_feat, kernel_size=3, stride=1, relu=False) 158 | ) 159 | self.main_fft = nn.Sequential( 160 | BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=True), 161 | BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=False) 162 | ) 163 | self.dim = n_feat 164 | self.norm = norm 165 | def forward(self, x): 166 | _, _, H, W = x.shape 167 | dim = 1 168 | y = torch.fft.rfft2(x, norm=self.norm) 169 | y_imag = y.imag 170 | y_real = y.real 171 | y_f = torch.cat([y_real, y_imag], dim=dim) 172 | y = self.main_fft(y_f) 173 | y_real, y_imag = torch.chunk(y, 2, dim=dim) 174 | y = torch.complex(y_real, y_imag) 175 | y = torch.fft.irfft2(y, s=(H, W), norm=self.norm) 176 | return self.main(x) + x + y 177 | class ResBlock_do_FECB_bench_eval(nn.Module): 178 | def __init__(self, out_channel, norm='backward'): 179 | super(ResBlock_do_FECB_bench_eval, self).__init__() 180 | self.main = nn.Sequential( 181 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=True), 182 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 183 | ) 184 | self.main_fft = nn.Sequential( 185 | 
BasicConv_do_eval(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=True), 186 | BasicConv_do_eval(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=False) 187 | ) 188 | self.dim = out_channel 189 | self.norm = norm 190 | def forward(self, x): 191 | _, _, H, W = x.shape 192 | dim = 1 193 | y = torch.fft.rfft2(x, norm=self.norm) 194 | y_imag = y.imag 195 | y_real = y.real 196 | y_f = torch.cat([y_real, y_imag], dim=dim) 197 | y = self.main_fft(y_f) 198 | y_real, y_imag = torch.chunk(y, 2, dim=dim) 199 | y = torch.complex(y_real, y_imag) 200 | y = torch.fft.irfft2(y, s=(H, W), norm=self.norm) 201 | return self.main(x) + x + y 202 | 203 | def window_partitions(x, window_size): 204 | """ 205 | Args: 206 | x: (B, C, H, W) 207 | window_size (int): window size 208 | Returns: 209 | windows: (num_windows*B, C, window_size, window_size) 210 | """ 211 | if isinstance(window_size, int): 212 | window_size = [window_size, window_size] 213 | B, C, H, W = x.shape 214 | x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1]) 215 | windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1]) 216 | return windows 217 | 218 | 219 | def window_reverses(windows, window_size, H, W): 220 | """ 221 | Args: 222 | windows: (num_windows*B, C, window_size, window_size) 223 | window_size (int): Window size 224 | H (int): Height of image 225 | W (int): Width of image 226 | Returns: 227 | x: (B, C, H, W) 228 | """ 229 | # B = int(windows.shape[0] / (H * W / window_size / window_size)) 230 | # print('B: ', B) 231 | # print(H // window_size) 232 | # print(W // window_size) 233 | if isinstance(window_size, int): 234 | window_size = [window_size, window_size] 235 | C = windows.shape[1] 236 | # print('C: ', C) 237 | x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1]) 238 | x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W) 239 | return x 240 | 241 | def 
window_partitionx(x, window_size): 242 | _, _, H, W = x.shape 243 | h, w = window_size * (H // window_size), window_size * (W // window_size) 244 | x_main = window_partitions(x[:, :, :h, :w], window_size) 245 | b_main = x_main.shape[0] 246 | if h == H and w == W: 247 | return x_main, [b_main] 248 | if h != H and w != W: 249 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size) 250 | b_r = x_r.shape[0] + b_main 251 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size) 252 | b_d = x_d.shape[0] + b_r 253 | x_dd = x[:, :, -window_size:, -window_size:] 254 | b_dd = x_dd.shape[0] + b_d 255 | # batch_list = [b_main, b_r, b_d, b_dd] 256 | return torch.cat([x_main, x_r, x_d, x_dd], dim=0), [b_main, b_r, b_d, b_dd] 257 | if h == H and w != W: 258 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size) 259 | b_r = x_r.shape[0] + b_main 260 | return torch.cat([x_main, x_r], dim=0), [b_main, b_r] 261 | if h != H and w == W: 262 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size) 263 | b_d = x_d.shape[0] + b_main 264 | return torch.cat([x_main, x_d], dim=0), [b_main, b_d] 265 | def window_reversex(windows, window_size, H, W, batch_list): 266 | h, w = window_size * (H // window_size), window_size * (W // window_size) 267 | # print(windows[:batch_list[0], ...].shape) 268 | x_main = window_reverses(windows[:batch_list[0], ...], window_size, h, w) 269 | B, C, _, _ = x_main.shape 270 | # print('windows: ', windows.shape) 271 | # print('batch_list: ', batch_list) 272 | if torch.is_complex(windows): 273 | res = torch.complex(torch.zeros([B, C, H, W]), torch.zeros([B, C, H, W])) 274 | res = res.to(windows.device) 275 | else: 276 | res = torch.zeros([B, C, H, W], device=windows.device) 277 | 278 | res[:, :, :h, :w] = x_main 279 | if h == H and w == W: 280 | return res 281 | if h != H and w != W and len(batch_list) == 4: 282 | x_dd = window_reverses(windows[batch_list[2]:, ...], window_size, window_size, window_size) 283 | res[:, :, h:, 
w:] = x_dd[:, :, h - H:, w - W:] 284 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size) 285 | res[:, :, :h, w:] = x_r[:, :, :, w - W:] 286 | x_d = window_reverses(windows[batch_list[1]:batch_list[2], ...], window_size, window_size, w) 287 | res[:, :, h:, :w] = x_d[:, :, h - H:, :] 288 | return res 289 | if w != W and len(batch_list) == 2: 290 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size) 291 | res[:, :, :h, w:] = x_r[:, :, :, w - W:] 292 | if h != H and len(batch_list) == 2: 293 | x_d = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, window_size, w) 294 | res[:, :, h:, :w] = x_d[:, :, h - H:, :] 295 | return res 296 | 297 | def window_partitions_old(x, window_size): 298 | """ 299 | Args: 300 | x: (B, C, H, W) 301 | window_size (int): window size 302 | 303 | Returns: 304 | windows: (num_windows*B, C, window_size, window_size) 305 | """ 306 | B, C, H, W = x.shape 307 | x = x.view(B, C, H // window_size, window_size, W // window_size, window_size) 308 | windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size) 309 | return windows 310 | 311 | 312 | def window_reverses_old(windows, window_size, H, W): 313 | """ 314 | Args: 315 | windows: (num_windows*B, C, window_size, window_size) 316 | window_size (int): Window size 317 | H (int): Height of image 318 | W (int): Width of image 319 | 320 | Returns: 321 | x: (B, C, H, W) 322 | """ 323 | # B = int(windows.shape[0] / (H * W / window_size / window_size)) 324 | # print('B: ', B) 325 | # print(H // window_size) 326 | # print(W // window_size) 327 | C = windows.shape[1] 328 | # print('C: ', C) 329 | x = windows.view(-1, H // window_size, W // window_size, C, window_size, window_size) 330 | x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W) 331 | return x 332 | 333 | def window_partitionx_old(x, window_size): 334 | _, _, H, W = x.shape 335 | h, w = window_size * (H // 
window_size), window_size * (W // window_size) 336 | x_main = window_partitions(x[:, :, :h, :w], window_size) 337 | b_main = x_main.shape[0] 338 | if h == H and w == W: 339 | return x_main, [b_main] 340 | if h != H and w != W: 341 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size) 342 | b_r = x_r.shape[0] + b_main 343 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size) 344 | b_d = x_d.shape[0] + b_r 345 | x_dd = x[:, :, -window_size:, -window_size:] 346 | b_dd = x_dd.shape[0] + b_d 347 | # batch_list = [b_main, b_r, b_d, b_dd] 348 | return torch.cat([x_main, x_r, x_d, x_dd], dim=0), [b_main, b_r, b_d, b_dd] 349 | if h == H and w != W: 350 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size) 351 | b_r = x_r.shape[0] + b_main 352 | return torch.cat([x_main, x_r], dim=0), [b_main, b_r] 353 | if h != H and w == W: 354 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size) 355 | b_d = x_d.shape[0] + b_main 356 | return torch.cat([x_main, x_d], dim=0), [b_main, b_d] 357 | 358 | def window_reversex_old(windows, window_size, H, W, batch_list): 359 | h, w = window_size * (H // window_size), window_size * (W // window_size) 360 | x_main = window_reverses(windows[:batch_list[0], ...], window_size, h, w) 361 | B, C, _, _ = x_main.shape 362 | # print('windows: ', windows.shape) 363 | # print('batch_list: ', batch_list) 364 | res = torch.zeros([B, C, H, W],device=windows.device) 365 | res[:, :, :h, :w] = x_main 366 | if h == H and w == W: 367 | return res 368 | if h != H and w != W and len(batch_list) == 4: 369 | x_dd = window_reverses(windows[batch_list[2]:, ...], window_size, window_size, window_size) 370 | res[:, :, h:, w:] = x_dd[:, :, h - H:, w - W:] 371 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size) 372 | res[:, :, :h, w:] = x_r[:, :, :, w - W:] 373 | x_d = window_reverses(windows[batch_list[1]:batch_list[2], ...], window_size, window_size, w) 374 | res[:, :, h:, :w] = 
x_d[:, :, h - H:, :] 375 | return res 376 | if w != W and len(batch_list) == 2: 377 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size) 378 | res[:, :, :h, w:] = x_r[:, :, :, w - W:] 379 | if h != H and len(batch_list) == 2: 380 | x_d = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, window_size, w) 381 | res[:, :, h:, :w] = x_d[:, :, h - H:, :] 382 | return res -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-3): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x.to('cuda:0') - y.to('cuda:0') 14 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 15 | return loss 16 | 17 | class EdgeLoss(nn.Module): 18 | def __init__(self): 19 | super(EdgeLoss, self).__init__() 20 | k = torch.Tensor([[.05, .25, .4, .25, .05]]) 21 | self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1) 22 | if torch.cuda.is_available(): 23 | self.kernel = self.kernel.to('cuda:0') 24 | self.loss = CharbonnierLoss() 25 | 26 | def conv_gauss(self, img): 27 | n_channels, _, kw, kh = self.kernel.shape 28 | img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate') 29 | return F.conv2d(img, self.kernel, groups=n_channels) 30 | 31 | def laplacian_kernel(self, current): 32 | filtered = self.conv_gauss(current) 33 | down = filtered[:,:,::2,::2] 34 | new_filter = torch.zeros_like(filtered) 35 | new_filter[:,:,::2,::2] = down*4 36 | filtered = self.conv_gauss(new_filter) 37 | diff = current - filtered 38 | return diff 39 | 40 | def forward(self, x, y): 41 | loss = self.loss(self.laplacian_kernel(x.to('cuda:0')), 
self.laplacian_kernel(y.to('cuda:0'))) 42 | return loss 43 | 44 | class fftLoss(nn.Module): 45 | def __init__(self): 46 | super(fftLoss, self).__init__() 47 | 48 | def forward(self, x, y): 49 | diff = torch.fft.fft2(x.to('cuda:0')) - torch.fft.fft2(y.to('cuda:0')) 50 | loss = torch.mean(abs(diff)) 51 | return loss 52 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from layers import * 2 | import numbers 3 | from einops import rearrange 4 | 5 | class Downsample(nn.Module): 6 | def __init__(self, n_feat): 7 | super(Downsample, self).__init__() 8 | 9 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False), 10 | nn.PixelUnshuffle(2)) 11 | 12 | def forward(self, x): 13 | return self.body(x) 14 | 15 | class Upsample(nn.Module): 16 | def __init__(self, n_feat): 17 | super(Upsample, self).__init__() 18 | 19 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False), 20 | nn.PixelShuffle(2)) 21 | 22 | def forward(self, x): 23 | return self.body(x) 24 | 25 | def to_3d(x): 26 | return rearrange(x, 'b c h w -> b (h w) c') 27 | 28 | def to_4d(x, h, w): 29 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 30 | 31 | class BiasFree_LayerNorm(nn.Module): 32 | def __init__(self, normalized_shape): 33 | super(BiasFree_LayerNorm, self).__init__() 34 | if isinstance(normalized_shape, numbers.Integral): 35 | normalized_shape = (normalized_shape,) 36 | normalized_shape = torch.Size(normalized_shape) 37 | 38 | assert len(normalized_shape) == 1 39 | 40 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 41 | self.normalized_shape = normalized_shape 42 | 43 | def forward(self, x): 44 | sigma = x.var(-1, keepdim=True, unbiased=False) 45 | return x / torch.sqrt(sigma + 1e-5) * self.weight 46 | 47 | class WithBias_LayerNorm(nn.Module): 48 | def 
__init__(self, normalized_shape): 49 | super(WithBias_LayerNorm, self).__init__() 50 | if isinstance(normalized_shape, numbers.Integral): 51 | normalized_shape = (normalized_shape,) 52 | normalized_shape = torch.Size(normalized_shape) 53 | 54 | assert len(normalized_shape) == 1 55 | 56 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 57 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 58 | self.normalized_shape = normalized_shape 59 | 60 | def forward(self, x): 61 | mu = x.mean(-1, keepdim=True) 62 | sigma = x.var(-1, keepdim=True, unbiased=False) 63 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 64 | 65 | class LayerNorm(nn.Module): 66 | def __init__(self, dim, LayerNorm_type): 67 | super(LayerNorm, self).__init__() 68 | if LayerNorm_type == 'BiasFree': 69 | self.body = BiasFree_LayerNorm(dim) 70 | else: 71 | self.body = WithBias_LayerNorm(dim) 72 | 73 | def forward(self, x): 74 | h, w = x.shape[-2:] 75 | return to_4d(self.body(to_3d(x)), h, w) 76 | 77 | class FeedForward(nn.Module): 78 | def __init__(self, dim, ffn_expansion_factor, bias, BasicConv=BasicConv): 79 | super(FeedForward, self).__init__() 80 | 81 | hidden_features = int(dim * ffn_expansion_factor) 82 | 83 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) 84 | 85 | self.dwconv = BasicConv(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, bias=bias, relu=False, groups=hidden_features * 2) 86 | 87 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 88 | 89 | def forward(self, x): 90 | x = self.project_in(x) 91 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 92 | x = F.gelu(x1) * x2 93 | x = self.project_out(x) 94 | return x 95 | 96 | class Attention(nn.Module): 97 | def __init__(self,scale, dim, num_heads, bias): 98 | super(Attention, self).__init__() 99 | self.num_heads = num_heads 100 | 101 | self.sacle = scale 102 | 103 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 104 | 
105 | self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) 106 | self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias) 107 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 108 | self.attn_drop = nn.Dropout(0.) 109 | 110 | def forward(self, x): 111 | b, c, h, w = x.shape 112 | 113 | qkv = self.qkv_dwconv(self.qkv(x)) 114 | q, k, v = qkv.chunk(3, dim=1) 115 | 116 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 117 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 118 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 119 | 120 | q = torch.nn.functional.normalize(q, dim=-1) 121 | k = torch.nn.functional.normalize(k, dim=-1) 122 | 123 | _, _, C, _ = q.shape 124 | 125 | mask1 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False) 126 | 127 | attn = (q @ k.transpose(-2, -1)) * self.temperature 128 | 129 | if self.sacle == 1: 130 | index = torch.topk(attn, k=int(C*6/10), dim=-1, largest=True)[1] 131 | elif self.sacle == 0.5: 132 | index = torch.topk(attn, k=int(C * 7 / 10), dim=-1, largest=True)[1] 133 | elif self.sacle == 0.25: 134 | index = torch.topk(attn, k=int(C * 8 / 10), dim=-1, largest=True)[1] 135 | mask1.scatter_(-1, index, 1.) 
136 | attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf'))) 137 | 138 | 139 | attn1 = attn1.softmax(dim=-1) 140 | 141 | out = (attn1 @ v) 142 | 143 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 144 | 145 | out = self.project_out(out) 146 | return out 147 | 148 | class TransformerBlock(nn.Module): 149 | def __init__(self, scale,dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, BasicConv=BasicConv): 150 | super(TransformerBlock, self).__init__() 151 | 152 | self.norm1 = LayerNorm(dim, LayerNorm_type) 153 | self.attn = Attention(scale,dim, num_heads, bias) 154 | self.norm2 = LayerNorm(dim, LayerNorm_type) 155 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias, BasicConv=BasicConv) 156 | 157 | def forward(self, x): 158 | x = x + self.attn(self.norm1(x)) 159 | x = x + self.ffn(self.norm2(x)) 160 | 161 | return x 162 | 163 | class FECB_SCTB(nn.Module): 164 | def __init__(self , out_channel, num_res=8, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=1, ffn_expansion_factor=1, bias=False, LayerNorm_type='WithBias', scale=1): 165 | super(FECB_SCTB, self).__init__() 166 | 167 | layers = [] 168 | for _ in range(num_res): 169 | layers.append(ResBlock(out_channel)) 170 | layers.append(TransformerBlock(scale = scale,dim=out_channel, num_heads=num_heads, ffn_expansion_factor=ffn_expansion_factor, bias=bias, 171 | LayerNorm_type=LayerNorm_type, BasicConv=BasicConv)) 172 | 173 | self.layers = nn.Sequential(*layers) 174 | 175 | def forward(self, x): 176 | return self.layers(x) 177 | 178 | class GFM(nn.Module): 179 | def __init__(self, in_channel, out_channel, BasicConv=BasicConv): 180 | super(GFM, self).__init__() 181 | self.conv_max = nn.Sequential( 182 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True), 183 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 184 | ) 185 | self.conv_mid = nn.Sequential( 186 | BasicConv(in_channel, out_channel, 
kernel_size=1, stride=1, relu=True), 187 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 188 | ) 189 | self.conv_small = nn.Sequential( 190 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True), 191 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 192 | ) 193 | 194 | self.conv1 =BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True) 195 | self.conv2 = BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True) 196 | 197 | 198 | def forward(self, x_max,x_mid,x_small): 199 | 200 | y_max=x_max +x_mid +x_small 201 | 202 | x_max = self.conv_max(x_max) 203 | x_mid = self.conv_max(x_mid) 204 | x_small = self.conv_max(x_small) 205 | 206 | x =F.tanh(x_mid) * x_max 207 | x = self.conv1(x) 208 | 209 | x =F.tanh(x_small) * x 210 | x = self.conv2(x) 211 | 212 | return x+y_max 213 | 214 | class SCM(nn.Module): 215 | def __init__(self, out_plane, BasicConv=BasicConv, inchannel=3): 216 | super(SCM, self).__init__() 217 | self.main = nn.Sequential( 218 | BasicConv(inchannel, out_plane//4, kernel_size=3, stride=1, relu=True), 219 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True), 220 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True), 221 | BasicConv(out_plane // 2, out_plane-inchannel, kernel_size=1, stride=1, relu=True) 222 | ) 223 | 224 | self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False) 225 | 226 | def forward(self, x): 227 | x = torch.cat([x, self.main(x)], dim=1) 228 | return self.conv(x) 229 | 230 | class FAM(nn.Module): 231 | def __init__(self, channel, BasicConv=BasicConv): 232 | super(FAM, self).__init__() 233 | self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False) 234 | 235 | def forward(self, x1, x2): 236 | x = x1 * x2 237 | out = x1 + self.merge(x) 238 | return out 239 | 240 | class MSDT(nn.Module): 241 | def __init__(self, num_res=8, inference=False): 242 | 
super(MSDT, self).__init__() 243 | self.inference = inference 244 | if not inference: 245 | BasicConv = BasicConv_do 246 | ResBlock = ResBlock_do_FECB_bench 247 | else: 248 | BasicConv = BasicConv_do_eval 249 | ResBlock = ResBlock_do_FECB_bench_eval 250 | base_channel = 32 251 | 252 | heads = [1, 2, 4] 253 | ffn_expansion_factor = 2.66 254 | bias = False 255 | LayerNorm_type = 'WithBias' 256 | scale = [1,0.5,0.25] 257 | 258 | self.Encoder = nn.ModuleList([ 259 | FECB_SCTB(base_channel, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[0], 260 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= scale[0]), 261 | FECB_SCTB(base_channel * 2, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[1], 262 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= scale[1]), 263 | FECB_SCTB(base_channel * 4, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[2], 264 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= scale[2]), 265 | ]) 266 | 267 | self.feat_extract = nn.ModuleList([ 268 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 269 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 270 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 271 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 272 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 273 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 274 | ]) 275 | 276 | self.Decoder = nn.ModuleList([ 277 | FECB_SCTB(base_channel * 4, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[2], 278 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= 0.25), 279 | FECB_SCTB(base_channel * 2, num_res, ResBlock=ResBlock, BasicConv=BasicConv, 
def forward(self, x):
    """Run the multi-scale encoder/decoder and return the restored image(s).

    Args:
        x: input batch. Assumed to be (N, 3, H, W) because the stem conv
            takes 3 input channels; H and W presumably need to be
            divisible by 4 since the features are down-sampled twice by
            a stride of 2 -- TODO confirm against the data pipeline.

    Returns:
        A list of tensors ordered from full resolution to coarsest.
        In training mode (``self.inference`` is False) the list holds the
        full-, half- and quarter-resolution predictions, each added to the
        matching (down-sampled) input. In inference mode only the
        full-resolution prediction is returned, as a one-element list, so
        callers that index ``restored[0]`` work in both modes.
    """
    # Down-sampled copies of the input feed the shallow side branches and
    # serve as residual bases for the coarse outputs.
    x_2 = F.interpolate(x, scale_factor=0.5)
    x_4 = F.interpolate(x_2, scale_factor=0.5)
    z2 = self.SCM2(x_2)
    z4 = self.SCM1(x_4)

    outputs = []

    # ---- encoder: three scales, side-branch features fused at scale 2 and 3
    x_ = self.feat_extract[0](x)
    res1 = self.Encoder[0](x_)

    z = self.feat_extract[1](res1)
    z = self.FAM2(z, z2)
    res2 = self.Encoder[1](z)

    z = self.feat_extract[2](res2)
    z = self.FAM1(z, z4)
    z = self.Encoder[2](z)

    # ---- cross-scale fusion of the encoder features used as skip inputs
    z21 = self.up_1(res2)
    z42 = self.up_2(z)
    z41 = self.up_3(z42)
    z12 = self.down_1(res1)

    res1 = self.GFMs[0](res1, z21, z41)
    res2 = self.GFMs[1](z12, res2, z42)

    # ---- decoder, coarsest scale first
    z = self.Decoder[0](z)
    if not self.inference:
        # The coarse side outputs are only needed for the multi-scale
        # training loss, so they are not computed in inference mode.
        outputs.append(self.ConvsOut[0](z) + x_4)
    z = self.feat_extract[3](z)

    z = torch.cat([z, res2], dim=1)
    z = self.Convs[0](z)

    z = self.Decoder[1](z)
    if not self.inference:
        outputs.append(self.ConvsOut[1](z) + x_2)
    z = self.feat_extract[4](z)

    z = torch.cat([z, res1], dim=1)
    z = self.Convs[1](z)

    z = self.Decoder[2](z)
    z = self.feat_extract[5](z)

    # The original only appended (and returned) the final result when
    # ``inference`` was False, so inference mode produced no image at all.
    outputs.append(z + x)
    return outputs[::-1]
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate of an optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if
            multiplier > 1.0. If multiplier = 1.0, lr starts from 0 and ends
            up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch,
            gradually.
        after_scheduler: after total_epoch, use this scheduler
            (e.g. ReduceLROnPlateau).

    Raises:
        ValueError: if ``multiplier`` is smaller than 1.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        # All attributes above must exist before the base constructor runs,
        # because it already performs a first ``step()`` on this object.
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the learning rates for the current ``last_epoch``."""
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand over: the follow-up scheduler starts from the
                    # warmed-up learning rate.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                # NOTE(review): newer PyTorch releases recommend
                # ``get_last_lr()`` here -- confirm against the torch
                # version this project pins before changing it.
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Ramp linearly from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        # Ramp linearly from base_lr up to base_lr * multiplier.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Advance one epoch when the follow-up scheduler is ReduceLROnPlateau.

        Args:
            metrics: value forwarded to the plateau scheduler once the
                warm-up is over; unused during the warm-up itself.
            epoch: explicit epoch index, or None to advance by one.
        """
        # Remember whether the caller supplied an epoch *before* filling in
        # the default; the original re-tested ``epoch is None`` afterwards,
        # which could never be true.
        epoch_given = epoch is not None
        if not epoch_given:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of epoch, whereas others
        # are called at beginning.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                         for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        elif epoch_given:
            self.after_scheduler.step(metrics, epoch - self.total_epoch)
        else:
            self.after_scheduler.step(metrics, None)

    def step(self, epoch=None, metrics=None):
        """Advance the schedule by one epoch (or jump to ``epoch``).

        Args:
            epoch: optional explicit epoch index.
            metrics: only used when ``after_scheduler`` is a
                ReduceLROnPlateau (or a subclass of it).
        """
        # isinstance (not an exact type comparison) so that subclasses of
        # ReduceLROnPlateau take the metrics-driven path as well.
        if isinstance(self.after_scheduler, ReduceLROnPlateau):
            self.step_ReduceLROnPlateau(metrics, epoch)
        elif self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
        else:
            return super(GradualWarmupScheduler, self).step(epoch)
-------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: warmup-scheduler 3 | Version: 0.3 4 | Summary: Gradually Warm-up LR Scheduler for Pytorch 5 | Home-page: https://github.com/ildoonet/pytorch-gradual-warmup-lr 6 | License: MIT License 7 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | setup.py 2 | warmup_scheduler/__init__.py 3 | warmup_scheduler/run.py 4 | warmup_scheduler/scheduler.py 5 | warmup_scheduler.egg-info/PKG-INFO 6 | warmup_scheduler.egg-info/SOURCES.txt 7 | warmup_scheduler.egg-info/dependency_links.txt 8 | warmup_scheduler.egg-info/top_level.txt -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | warmup_scheduler 2 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from warmup_scheduler.scheduler import GradualWarmupScheduler 3 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate of an optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if
            multiplier > 1.0. If multiplier = 1.0, lr starts from 0 and ends
            up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch,
            gradually.
        after_scheduler: after total_epoch, use this scheduler
            (e.g. ReduceLROnPlateau).

    Raises:
        ValueError: if ``multiplier`` is smaller than 1.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater thant or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        # All attributes above must exist before the base constructor runs,
        # because it already performs a first ``step()`` on this object.
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the learning rates for the current ``last_epoch``."""
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand over: the follow-up scheduler starts from the
                    # warmed-up learning rate.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                # NOTE(review): newer PyTorch releases recommend
                # ``get_last_lr()`` here -- confirm against the torch
                # version this project pins before changing it.
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Ramp linearly from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        # Ramp linearly from base_lr up to base_lr * multiplier.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Advance one epoch when the follow-up scheduler is ReduceLROnPlateau.

        Args:
            metrics: value forwarded to the plateau scheduler once the
                warm-up is over; unused during the warm-up itself.
            epoch: explicit epoch index, or None to advance by one.
        """
        # Remember whether the caller supplied an epoch *before* filling in
        # the default; the original re-tested ``epoch is None`` afterwards,
        # which could never be true.
        epoch_given = epoch is not None
        if not epoch_given:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of epoch, whereas others
        # are called at beginning.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                         for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        elif epoch_given:
            self.after_scheduler.step(metrics, epoch - self.total_epoch)
        else:
            self.after_scheduler.step(metrics, None)

    def step(self, epoch=None, metrics=None):
        """Advance the schedule by one epoch (or jump to ``epoch``).

        Args:
            epoch: optional explicit epoch index.
            metrics: only used when ``after_scheduler`` is a
                ReduceLROnPlateau (or a subclass of it).
        """
        # isinstance (not an exact type comparison) so that subclasses of
        # ReduceLROnPlateau take the metrics-driven path as well.
        if isinstance(self.after_scheduler, ReduceLROnPlateau):
            self.step_ReduceLROnPlateau(metrics, epoch)
        elif self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
        else:
            return super(GradualWarmupScheduler, self).step(epoch)
args.gpus 25 | model_restoration = mynet() 26 | get_parameter_number(model_restoration) 27 | utils.load_checkpoint(model_restoration, args.weights) 28 | print("===>Testing using weights: ",args.weights) 29 | model_restoration.cuda() 30 | model_restoration = nn.DataParallel(model_restoration) 31 | model_restoration.eval() 32 | 33 | # dataset = args.dataset 34 | rgb_dir_test = args.input_dir 35 | test_dataset = get_test_data(rgb_dir_test, img_options={}) 36 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True) 37 | 38 | utils.mkdir(result_dir) 39 | 40 | with torch.no_grad(): 41 | psnr_list = [] 42 | ssim_list = [] 43 | for ii, data_test in enumerate(tqdm(test_loader), 0): 44 | 45 | torch.cuda.ipc_collect() 46 | torch.cuda.empty_cache() 47 | input_ = data_test[0].cuda() 48 | filenames = data_test[1] 49 | _, _, Hx, Wx = input_.shape 50 | filenames = data_test[1] 51 | input_re, batch_list = window_partitionx(input_, win) 52 | restored = model_restoration(input_re) 53 | restored = window_reversex(restored[0], win, Hx, Wx, batch_list) 54 | 55 | restored = torch.clamp(restored, 0, 1) 56 | restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy() 57 | 58 | for batch in range(len(restored)): 59 | restored_img = restored[batch] 60 | restored_img = img_as_ubyte(restored[batch]) 61 | utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img) 62 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 4 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 5 | 6 | import torch 7 | 8 | torch.backends.cudnn.benchmark = True 9 | 10 | import torch.nn as nn 11 | import torch.optim as optim 12 | from torch.utils.data import DataLoader 13 | 14 | import random 15 | import time 16 | import numpy as np 17 | 18 | 
import utils
from data_RGB import get_training_data, get_validation_data
from model import MSDT as myNet
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from get_parameter_number import get_parameter_number
import kornia
from torch.utils.tensorboard import SummaryWriter
import argparse

from skimage import img_as_ubyte

######### Set Seeds ###########
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)

start_epoch = 1

parser = argparse.ArgumentParser(description='Image Deraininig')

parser.add_argument('--train_dir', default='/home/user/data/21/chm/dataset/Rain200H/Rain200H/train', type=str, help='Directory of train images')
parser.add_argument('--val_dir', default='/home/user/data/21/chm/dataset/Rain200H/Rain200H/test', type=str, help='Directory of validation images')
parser.add_argument('--model_save_dir', default='./checkpoints', type=str, help='Path to save weights')
parser.add_argument('--pretrain_weights', default='', type=str, help='Path to pretrain-weights')
parser.add_argument('--mode', default='Deraininig', type=str)
parser.add_argument('--session', default='Multiscale', type=str, help='session')
parser.add_argument('--patch_size', default=256, type=int, help='patch size')
parser.add_argument('--num_epochs', default=500, type=int, help='num_epochs')
parser.add_argument('--batch_size', default=1, type=int, help='batch_size')
parser.add_argument('--val_epochs', default=1, type=int, help='val_epochs')
args = parser.parse_args()

mode = args.mode
session = args.session
patch_size = args.patch_size

model_dir = os.path.join(args.model_save_dir, mode, 'models', session)
utils.mkdir(model_dir)

train_dir = args.train_dir
val_dir = args.val_dir

num_epochs = args.num_epochs
batch_size = args.batch_size
val_epochs = args.val_epochs

start_lr = 1e-4
end_lr = 1e-6

######### Model ###########
model_restoration = myNet()

# print number of model
get_parameter_number(model_restoration)

model_restoration.cuda()

device_ids = [i for i in range(torch.cuda.device_count())]
if torch.cuda.device_count() > 1:
    print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")

optimizer = optim.Adam(model_restoration.parameters(), lr=start_lr, betas=(0.9, 0.999), eps=1e-8)

######### Scheduler ###########
warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs - warmup_epochs, eta_min=end_lr)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)

RESUME = False
# The --pretrain_weights option used to be parsed and then ignored; it now
# enables the pre-train branch. The default ('') keeps it switched off.
model_pre_dir = args.pretrain_weights
Pretrain = bool(model_pre_dir)
######### Pretrain ###########
if Pretrain:
    utils.load_checkpoint(model_restoration, model_pre_dir)

    print('------------------------------------------------------------------------------')
    print("==> Retrain Training with: " + model_pre_dir)
    print('------------------------------------------------------------------------------')

######### Resume ###########
if RESUME:
    path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
    utils.load_checkpoint(model_restoration, path_chk_rest)
    start_epoch = utils.load_start_epoch(path_chk_rest) + 1
    utils.load_optim(optimizer, path_chk_rest)

    # Replay the schedule up to the epoch we resume from.
    for i in range(1, start_epoch):
        scheduler.step()
    # Read the rate the optimizer really uses. Calling scheduler.get_lr()
    # outside of step() re-applies the follow-up scheduler's update rule
    # and therefore reports a value the optimizer never had.
    new_lr = optimizer.param_groups[0]['lr']
    print('------------------------------------------------------------------------------')
    print("==> Resuming Training with learning rate:", new_lr)
    print('------------------------------------------------------------------------------')

if len(device_ids) > 1:
    model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)

######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()
criterion_fft = losses.fftLoss()
######### DataLoaders ###########
train_dataset = get_training_data(train_dir, {'patch_size': patch_size})
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False,
                          pin_memory=True)

val_dataset = get_validation_data(val_dir, {'patch_size': patch_size})
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False,
                        pin_memory=True)

print('===> Start Epoch {} End Epoch {}'.format(start_epoch, num_epochs + 1))
print('===> Loading datasets')

best_psnr = 0
best_epoch = 0
writer = SummaryWriter(model_dir)
# Running count of optimizer updates, used as the x-axis of the per-iteration
# scalars (previously named ``iter``, which shadowed the builtin).
global_step = 0

for epoch in range(start_epoch, num_epochs + 1):
    epoch_start_time = time.time()
    epoch_loss = 0

    model_restoration.train()
    for i, data in enumerate(tqdm(train_loader), 0):

        # zero_grad
        for param in model_restoration.parameters():
            param.grad = None

        target_ = data[0].cuda()
        input_ = data[1].cuda()
        # Three-level pyramid of the ground truth, one level per model output.
        target = kornia.geometry.transform.build_pyramid(target_, 3)
        restored = model_restoration(input_)

        loss_fft = criterion_fft(restored[0], target[0]) + criterion_fft(restored[1], target[1]) + criterion_fft(restored[2], target[2])
        loss_char = criterion_char(restored[0], target[0]) + criterion_char(restored[1], target[1]) + criterion_char(restored[2], target[2])
        loss_edge = criterion_edge(restored[0], target[0]) + criterion_edge(restored[1], target[1]) + criterion_edge(restored[2], target[2])
        loss = loss_char + 0.01 * loss_fft + 0.05 * loss_edge
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        global_step += 1
        writer.add_scalar('loss/fft_loss', loss_fft, global_step)
        writer.add_scalar('loss/char_loss', loss_char, global_step)
        writer.add_scalar('loss/edge_loss', loss_edge, global_step)
        writer.add_scalar('loss/iter_loss', loss, global_step)
    writer.add_scalar('loss/epoch_loss', epoch_loss, epoch)
    #### Evaluation ####
    if epoch % val_epochs == 0:
        model_restoration.eval()
        psnr_val_rgb = []
        for ii, data_val in enumerate((val_loader), 0):
            target = data_val[0].cuda()
            input_ = data_val[1].cuda()

            with torch.no_grad():
                restored = model_restoration(input_)

            # Only the full-resolution output is scored.
            for res, tar in zip(restored[0], target):
                psnr_val_rgb.append(utils.torchPSNR(res, tar))

        psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
        writer.add_scalar('val/psnr', psnr_val_rgb, epoch)
        if psnr_val_rgb > best_psnr:
            best_psnr = psnr_val_rgb
            best_epoch = epoch
            torch.save({'epoch': epoch,
                        'state_dict': model_restoration.state_dict(),
                        'optimizer': optimizer.state_dict()
                        }, os.path.join(model_dir, "model_best.pth"))

        print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))

        torch.save({'epoch': epoch,
                    'state_dict': model_restoration.state_dict(),
                    'optimizer': optimizer.state_dict()
                    }, os.path.join(model_dir, f"model_epoch_{epoch}.pth"))

    scheduler.step()
    # Report the rate the optimizer will use next (see the note in the
    # resume branch about scheduler.get_lr()).
    current_lr = optimizer.param_groups[0]['lr']

    print("------------------------------------------------------------------")
    print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time,
                                                                              epoch_loss, current_lr))
    print("------------------------------------------------------------------")

    torch.save({'epoch': epoch,
                'state_dict': model_restoration.state_dict(),
                'optimizer': optimizer.state_dict()
                }, os.path.join(model_dir, "model_latest.pth"))

writer.close()
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/dataset_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dir_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/dir_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dir_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/dir_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/image_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/image_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/image_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/image_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/model_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-38.pyc: 
def mkdirs(paths):
    """Create every directory in ``paths``.

    Args:
        paths: a single path, or a list/tuple of paths.
    """
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """Create ``path`` (including parents) unless something already exists there."""
    if not os.path.exists(path):
        # exist_ok closes the window in which another process creates the
        # directory between the check above and this call.
        os.makedirs(path, exist_ok=True)


def get_last_path(path, session):
    """Return the naturally-sorted last entry in ``path`` ending with ``session``.

    Args:
        path: directory to search.
        session: file-name suffix to match (used as ``*<session>``).

    Raises:
        IndexError: if no entry matches.
    """
    pattern = os.path.join(path, '*%s' % session)
    matches = glob(pattern)
    if not matches:
        # Same exception type as the bare ``[-1]`` lookup raised before,
        # but with a message that says what was being searched for.
        raise IndexError('no file matching %r' % pattern)
    return natsorted(matches)[-1]
_DATA_PARALLEL_PREFIX = 'module.'


def _read_checkpoint(weights):
    """Deserialize the checkpoint file ``weights`` onto the CPU.

    Loading onto the CPU lets a checkpoint written on one device be read on
    a machine that does not have that device; ``load_state_dict`` copies
    the values onto whatever device the target already lives on.
    """
    return torch.load(weights, map_location='cpu')


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` without a leading ``module.`` on its keys.

    Keys that do not start with the prefix are kept unchanged.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(_DATA_PARALLEL_PREFIX):
            key = key[len(_DATA_PARALLEL_PREFIX):]
        cleaned[key] = value
    return cleaned


def _load_state_dict_with_fallback(model, state_dict):
    """Load ``state_dict``; on a key mismatch retry without ``module.`` prefixes."""
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        stripped = _strip_module_prefix(state_dict)
        if list(stripped) == list(state_dict):
            # Nothing to strip: the mismatch has another cause, so surface
            # the original error instead of hiding it.
            raise
        model.load_state_dict(stripped)


def freeze(model):
    """Disable gradients for every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad = False


def unfreeze(model):
    """Enable gradients for every parameter of ``model``."""
    for p in model.parameters():
        p.requires_grad = True


def is_frozen(model):
    """Return True if at least one parameter of ``model`` has gradients disabled."""
    return not all(p.requires_grad for p in model.parameters())


def save_checkpoint(model_dir, state, session):
    """Save ``state`` as ``model_epoch_<epoch>_<session>.pth`` inside ``model_dir``.

    Args:
        model_dir: target directory.
        state: mapping to serialize; must contain the key ``'epoch'``.
        session: tag appended to the file name.
    """
    epoch = state['epoch']
    model_out_path = os.path.join(model_dir, "model_epoch_{}_{}.pth".format(epoch, session))
    torch.save(state, model_out_path)


def load_checkpoint(model, weights):
    """Load ``checkpoint['state_dict']`` from the file ``weights`` into ``model``.

    If the keys do not match, the load is retried with ``module.`` prefixes
    removed.

    Raises:
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: if the weights fit neither with nor without the prefix.
    """
    checkpoint = _read_checkpoint(weights)
    _load_state_dict_with_fallback(model, checkpoint["state_dict"])


def load_checkpoint_compress_doconv(model, weights):
    """Load a checkpoint while folding each ``W``/``D``/``D_diag`` triple into one kernel.

    For a key ending in ``W`` that has a sibling ending in ``D``, the two
    tensors (plus the ``D_diag`` sibling) are contracted into a single
    4-D weight stored under the ``W`` key; the ``D`` and ``D_diag`` entries
    are dropped. Other keys ending in ``W`` are only reshaped to 4-D, and
    all remaining entries are copied as they are.
    """
    checkpoint = _read_checkpoint(weights)
    state_dict = _strip_module_prefix(checkpoint["state_dict"])
    do_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.endswith('W') and k[:-1] + 'D' in state_dict:
            k_D = k[:-1] + 'D'
            k_D_diag = k_D + '_diag'
            W = v
            D = state_dict[k_D] + state_dict[k_D_diag]
            out_channels, in_channels, MN = W.shape
            # The last axis is treated as a flattened square kernel.
            M = int(np.sqrt(MN))
            DoW_shape = (out_channels, in_channels, M, M)
            DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
            do_state_dict[k] = DoW
        elif k.endswith('D') or k.endswith('D_diag'):
            # Already merged into the matching W entry (or not needed).
            continue
        elif k.endswith('W'):
            out_channels, in_channels, MN = v.shape
            M = int(np.sqrt(MN))
            W_shape = (out_channels, in_channels, M, M)
            do_state_dict[k] = torch.reshape(v, W_shape)
        else:
            do_state_dict[k] = v
    model.load_state_dict(do_state_dict)


def load_checkpoint_hin(model, weights):
    """Load a checkpoint file that *is* the state dict (no wrapping mapping).

    Falls back to ``module.``-stripped keys on a mismatch, like
    ``load_checkpoint``.
    """
    checkpoint = _read_checkpoint(weights)
    _load_state_dict_with_fallback(model, checkpoint)


def load_checkpoint_multigpu(model, weights):
    """Load ``checkpoint['state_dict']`` after removing ``module.`` key prefixes.

    Keys without the prefix are left as they are instead of losing their
    first seven characters.
    """
    checkpoint = _read_checkpoint(weights)
    model.load_state_dict(_strip_module_prefix(checkpoint["state_dict"]))


def load_start_epoch(weights):
    """Return the ``'epoch'`` entry stored in the checkpoint file ``weights``."""
    checkpoint = _read_checkpoint(weights)
    return checkpoint["epoch"]


def load_optim(optimizer, weights):
    """Restore ``optimizer`` from the ``'optimizer'`` entry of the checkpoint file."""
    checkpoint = _read_checkpoint(weights)
    optimizer.load_state_dict(checkpoint['optimizer'])