├── .idea
    ├── .gitignore
    ├── MSDT.iml
    ├── inspectionProfiles
    │   ├── Project_Default.xml
    │   └── profiles_settings.xml
    ├── misc.xml
    ├── modules.xml
    └── vcs.xml
├── README.md
├── __pycache__
    ├── data_RGB.cpython-36.pyc
    ├── dataset_RGB.cpython-36.pyc
    ├── doconv_pytorch.cpython-36.pyc
    ├── get_parameter_number.cpython-36.pyc
    ├── layers.cpython-36.pyc
    ├── losses.cpython-36.pyc
    ├── mlp.cpython-36.pyc
    └── model.cpython-36.pyc
├── data_RGB.py
├── dataset_RGB.py
├── doconv_pytorch.py
├── evaluations
    ├── Evaluation_DID-Data_DDN-Data
    │   ├── psnr.m
    │   ├── ssim.m
    │   └── statistic.m
    └── Evalution_Rain200L_Rain200H_SPA-Data
    │   └── evaluate_PSNR_SSIM.m
├── get_parameter_number.py
├── layers.py
├── losses.py
├── model.py
├── pytorch-gradual-warmup-lr
    ├── build
    │   └── lib
    │   │   └── warmup_scheduler
    │   │       ├── __init__.py
    │   │       ├── run.py
    │   │       └── scheduler.py
    ├── dist
    │   └── warmup_scheduler-0.3-py3.8.egg
    ├── setup.py
    ├── warmup_scheduler.egg-info
    │   ├── PKG-INFO
    │   ├── SOURCES.txt
    │   ├── dependency_links.txt
    │   └── top_level.txt
    └── warmup_scheduler
    │   ├── __init__.py
    │   ├── run.py
    │   └── scheduler.py
├── test.py
├── train.py
├── train.sh
└── utils
    ├── __init__.py
    ├── __pycache__
        ├── __init__.cpython-36.pyc
        ├── __init__.cpython-38.pyc
        ├── dataset_utils.cpython-36.pyc
        ├── dataset_utils.cpython-38.pyc
        ├── dir_utils.cpython-36.pyc
        ├── dir_utils.cpython-38.pyc
        ├── image_utils.cpython-36.pyc
        ├── image_utils.cpython-38.pyc
        ├── model_utils.cpython-36.pyc
        └── model_utils.cpython-38.pyc
    ├── dataset_utils.py
    ├── dir_utils.py
    ├── image_utils.py
    └── model_utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Rethinking Multi-Scale Representations in Deep Deraining Transformer
4 |
5 |
6 |
7 |
8 |
9 |
10 | ## 🛠️ Training and Testing
11 | 1. Put the datasets in the `Datasets/` folder.
12 | 2. Start training with the command below.
13 | ```
14 | bash train.sh
15 | ```
16 | The script writes its experiment logs to the `checkpoints` folder.
17 |
18 | 3. Test the model with the command below.
19 | ```
20 | python test.py
21 | ```
22 | The script writes the derained output images to the `results/` folder.
23 |
24 |
25 | ## 🤖 Pre-trained Models
26 | | Models | MSDT |
27 | |:-----: |:-----: |
28 | | Rain200L | [Google Drive](https://drive.google.com/file/d/1qk8pUq7oM4Z4v2X-qmWJpE2LmUuweL4_/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1jikJhCuv51bvkl9vF2AkKw?pwd=8ajd) (8ajd)
29 | | Rain200H | [Google Drive](https://drive.google.com/file/d/1y8gjAvnt0kkf1dSEyauVFu2weLi53LmF/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1jr01T_hzl8K_h2VksrmlFQ?pwd=97lm) (97lm)
30 | | DID-Data | [Google Drive](https://drive.google.com/file/d/1RDvMFZn57UFrkeeojRHXwR7YbvXSGR5i/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1PJrRTDsG4vL4XwhNd8kfHg?pwd=5g4p) (5g4p)
31 | | DDN-Data | [Google Drive](https://drive.google.com/file/d/1p7FVQuZSw4n0nXEvLrsJPtYxzlMyOCK0/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1Y3YRkNO40m6bII-R3-Hi4g?pwd=b0b5) (b0b5)
32 | | SPA-Data | [Google Drive](https://drive.google.com/file/d/1hEpYFrFG0qhKassfYAZmXwUnNUYmGMLs/view?usp=drive_link) / [Baidu Netdisk](https://pan.baidu.com/s/1CO7wlaZyhu2egjfdaavFeQ?pwd=x0i5) (x0i5)
33 |
34 |
35 | ## 🚨 Performance Evaluation
36 | The evaluation scripts are in the `evaluations/` folder.
37 |
38 | 1) *for Rain200L/H and SPA-Data datasets*:
39 | PSNR and SSIM are computed with this [MATLAB code](https://github.com/sauchm/MSDT/tree/main/evaluations/Evalution_Rain200L_Rain200H_SPA-Data).
40 |
41 | 2) *for DID-Data and DDN-Data datasets*:
42 | PSNR and SSIM are computed with this [MATLAB code](https://github.com/sauchm/MSDT/tree/main/evaluations/Evaluation_DID-Data_DDN-Data).
43 |
44 |
45 |
46 | ## 🚀 Visual Deraining Results
47 |
48 | | Methods | MSDT |
49 | |:-----: |:-----: |
50 | | Rain200L | [Baidu Netdisk](https://pan.baidu.com/s/1us3smvwhAe3azJPnunWs8w?pwd=1xkc) (1xkc)
51 | | Rain200H | [Baidu Netdisk](https://pan.baidu.com/s/1S__NNB0jV2ING2ngR0PjiA?pwd=yr3n) (yr3n)
52 | | DID-Data | [Baidu Netdisk](https://pan.baidu.com/s/1Rif4QC1AuDF4ccHteg_A4A?pwd=242e) (242e)
53 | | DDN-Data | [Baidu Netdisk](https://pan.baidu.com/s/1JFHyrTMSdsFotOJ6pKokow?pwd=2pwk) (2pwk)
54 | | SPA-Data | [Baidu Netdisk](https://pan.baidu.com/s/14fSFf_T7AOD44ktso56Rxw?pwd=cag0) (cag0)
55 |
56 |
57 | ## 👍 Acknowledgement
58 | Thanks to the authors of [DeepRFT](https://github.com/INVOKERer/DeepRFT) and [DRSformer](https://github.com/cschenxiang/DRSformer) for their excellent work.
59 |
60 | ## 📘 Citation
61 | If you find this work helpful, please consider citing it:
62 | ```
63 | @inproceedings{chen2024rethinking,
64 | title={Rethinking Multi-Scale Representations in Deep Deraining Transformer},
65 | author={Chen, Hongming and Chen, Xiang and Lu, Jiyang and Li, Yufeng},
66 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
67 | volume={38},
68 | number={2},
69 | pages={1046--1053},
70 | year={2024}
71 | }
72 | ```
73 |
74 |
--------------------------------------------------------------------------------
/data_RGB.py:
--------------------------------------------------------------------------------
1 | from dataset_RGB import *
2 |
3 |
4 | def get_training_data(rgb_dir, img_options):
5 |     assert os.path.exists(rgb_dir)
6 |     return DataLoaderTrain(rgb_dir, img_options)
7 |
8 | def get_validation_data(rgb_dir, img_options):
9 |     assert os.path.exists(rgb_dir)
10 |     return DataLoaderVal(rgb_dir, img_options)
11 |
12 | def get_test_data(rgb_dir, img_options):
13 |     assert os.path.exists(rgb_dir)
14 |     return DataLoaderTest(rgb_dir, img_options)
15 |
--------------------------------------------------------------------------------
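The loaders wrapped above read image pairs from `input/` and `target/` subfolders of `rgb_dir`, pairing files by sorted order (see `DataLoaderTrain` in `dataset_RGB.py`). The following is a minimal standard-library sketch of that pairing, useful for checking a dataset folder before training. `list_pairs` and `make_demo_dataset` are hypothetical helpers, and the length check is an added assumption that the repository code does not perform.

```python
import os
import tempfile

IMAGE_EXTENSIONS = ('jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif')


def is_image_file(filename):
    # Same suffix test as is_image_file in dataset_RGB.py.
    return filename.endswith(IMAGE_EXTENSIONS)


def list_pairs(rgb_dir):
    """Return (input_path, target_path) pairs, matched by sorted file order."""
    inp_dir = os.path.join(rgb_dir, 'input')
    tar_dir = os.path.join(rgb_dir, 'target')
    inp = [os.path.join(inp_dir, f) for f in sorted(os.listdir(inp_dir)) if is_image_file(f)]
    tar = [os.path.join(tar_dir, f) for f in sorted(os.listdir(tar_dir)) if is_image_file(f)]
    # Assumption: a mismatch means the dataset is incomplete, so fail early.
    if len(inp) != len(tar):
        raise ValueError('input/ has %d images but target/ has %d' % (len(inp), len(tar)))
    return list(zip(inp, tar))


def make_demo_dataset(root, names):
    # Hypothetical helper: create empty placeholder files for the demo.
    for sub in ('input', 'target'):
        os.makedirs(os.path.join(root, sub))
        for name in names:
            open(os.path.join(root, sub, name), 'wb').close()


with tempfile.TemporaryDirectory() as root:
    make_demo_dataset(root, ['2.png', '1.png', 'notes.txt'])
    pairs = list_pairs(root)
    pair_names = [(os.path.basename(a), os.path.basename(b)) for a, b in pairs]
    print(pair_names)
```

Because pairing relies on sorted order rather than on matching names, the rainy and clean images need file names that sort identically in both folders.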
/dataset_RGB.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import torch
5 | from PIL import Image
6 | import torchvision.transforms.functional as TF
7 | import random
8 |
9 |
10 | def is_image_file(filename):
11 |     return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
12 |
13 | class DataLoaderTrain(Dataset):
14 |     def __init__(self, rgb_dir, img_options=None):
15 |         super(DataLoaderTrain, self).__init__()
16 |
17 |         inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input')))
18 |         tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target')))
19 |
20 |         self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)]
21 |         self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)]
22 |
23 |         self.img_options = img_options
24 |         self.sizex = len(self.tar_filenames)  # get the size of target
25 |
26 |         self.ps = self.img_options['patch_size']
27 |
28 |     def __len__(self):
29 |         return self.sizex
30 |
31 |     def __getitem__(self, index):
32 |         index_ = index % self.sizex
33 |         ps = self.ps
34 |
35 |         inp_path = self.inp_filenames[index_]
36 |         tar_path = self.tar_filenames[index_]
37 |
38 |         inp_img = Image.open(inp_path)
39 |         tar_img = Image.open(tar_path)
40 |
41 |         w,h = tar_img.size
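42 |         padw = ps-w if w<ps else 0
[... remainder of dataset_RGB.py not captured in this dump ...]
--------------------------------------------------------------------------------
/doconv_pytorch.py:
--------------------------------------------------------------------------------
[... lines 1-59 of doconv_pytorch.py not captured in this dump ...]
60 |         if M * N > 1: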
61 |             self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul))
62 |             init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32)
63 |             self.D.data = torch.from_numpy(init_zero)
64 |
65 |             eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N))
66 |             D_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N)))
67 |             if self.D_mul % (M * N) != 0:  # the cases when D_mul > M * N
68 |                 zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)])
69 |                 self.D_diag = Parameter(torch.cat([D_diag, zeros], dim=2), requires_grad=False)
70 |             else:  # the case when D_mul = M * N
71 |                 self.D_diag = Parameter(D_diag, requires_grad=False)
72 |         ##################################################################################################
73 |         if simam:
74 |             self.simam_block = simam_module()
75 |         if bias:
76 |             self.bias = Parameter(torch.Tensor(out_channels))
77 |             fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
78 |             bound = 1 / math.sqrt(fan_in)
79 |             init.uniform_(self.bias, -bound, bound)
80 |         else:
81 |             self.register_parameter('bias', None)
82 |
83 |     def extra_repr(self):
84 |         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
85 |              ', stride={stride}')
86 |         if self.padding != (0,) * len(self.padding):
87 |             s += ', padding={padding}'
88 |         if self.dilation != (1,) * len(self.dilation):
89 |             s += ', dilation={dilation}'
90 |         if self.groups != 1:
91 |             s += ', groups={groups}'
92 |         if self.bias is None:
93 |             s += ', bias=False'
94 |         if self.padding_mode != 'zeros':
95 |             s += ', padding_mode={padding_mode}'
96 |         return s.format(**self.__dict__)
97 |
98 |     def __setstate__(self, state):
99 |         super(DOConv2d, self).__setstate__(state)
100 |         if not hasattr(self, 'padding_mode'):
101 |             self.padding_mode = 'zeros'
102 |
103 |     def _conv_forward(self, input, weight):
104 |         if self.padding_mode != 'zeros':
105 |             return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
106 |                             weight, self.bias, self.stride,
107 |                             (0, 0), self.dilation, self.groups)
108 |         return F.conv2d(input, weight, self.bias, self.stride,
109 |                         self.padding, self.dilation, self.groups)
110 |
111 |     def forward(self, input):
112 |         M = self.kernel_size[0]
113 |         N = self.kernel_size[1]
114 |         DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
115 |         if M * N > 1:
116 |             ######################### Compute DoW #################
117 |             # D: (in_channels, M * N, D_mul)
118 |             D = self.D + self.D_diag
119 |             W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))
120 |
121 |             # einsum outputs (out_channels // groups, in_channels, M * N),
122 |             # which is reshaped to
123 |             # (out_channels, in_channels // groups, M, N)
124 |             DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
125 |             #######################################################
126 |         else:
127 |             DoW = torch.reshape(self.W, DoW_shape)
128 |         if self.simam:
129 |             DoW_h1, DoW_h2 = torch.chunk(DoW, 2, dim=2)
130 |             DoW = torch.cat([self.simam_block(DoW_h1), DoW_h2], dim=2)
131 |
132 |         return self._conv_forward(input, DoW)
133 | class DOConv2d_eval(Module):
134 |     """
135 |     DOConv2d can be used as an alternative for torch.nn.Conv2d.
136 |     The interface is similar to that of Conv2d, with one exception:
137 |     1. D_mul: the depth multiplier for the over-parameterization.
138 |     Note that the groups parameter switches between DO-Conv (groups=1),
139 |     DO-DConv (groups=in_channels), DO-GConv (otherwise).
140 |     """
141 |     __constants__ = ['stride', 'padding', 'dilation', 'groups',
142 |                      'padding_mode', 'output_padding', 'in_channels',
143 |                      'out_channels', 'kernel_size', 'D_mul']
144 |     __annotations__ = {'bias': Optional[torch.Tensor]}
145 |
146 |     def __init__(self, in_channels, out_channels, kernel_size=3, D_mul=None, stride=1,
147 |                  padding=1, dilation=1, groups=1, bias=False, padding_mode='zeros', simam=False):
148 |         super(DOConv2d_eval, self).__init__()
149 |
150 |         kernel_size = (kernel_size, kernel_size)
151 |         stride = (stride, stride)
152 |         padding = (padding, padding)
153 |         dilation = (dilation, dilation)
154 |
155 |         if in_channels % groups != 0:
156 |             raise ValueError('in_channels must be divisible by groups')
157 |         if out_channels % groups != 0:
158 |             raise ValueError('out_channels must be divisible by groups')
159 |         valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
160 |         if padding_mode not in valid_padding_modes:
161 |             raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
162 |                 valid_padding_modes, padding_mode))
163 |         self.in_channels = in_channels
164 |         self.out_channels = out_channels
165 |         self.kernel_size = kernel_size
166 |         self.stride = stride
167 |         self.padding = padding
168 |         self.dilation = dilation
169 |         self.groups = groups
170 |         self.padding_mode = padding_mode
171 |         self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))
172 |         self.simam = simam
173 |         #################################### Initialization of D & W ###################################
174 |         M = self.kernel_size[0]
175 |         N = self.kernel_size[1]
176 |         self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, M, N))
177 |         init.kaiming_uniform_(self.W, a=math.sqrt(5))
178 |
179 |         self.register_parameter('bias', None)
180 |     def extra_repr(self):
181 |         s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
182 |              ', stride={stride}')
183 |         if self.padding != (0,) * len(self.padding):
184 |             s += ', padding={padding}'
185 |         if self.dilation != (1,) * len(self.dilation):
186 |             s += ', dilation={dilation}'
187 |         if self.groups != 1:
188 |             s += ', groups={groups}'
189 |         if self.bias is None:
190 |             s += ', bias=False'
191 |         if self.padding_mode != 'zeros':
192 |             s += ', padding_mode={padding_mode}'
193 |         return s.format(**self.__dict__)
194 |
195 |     def __setstate__(self, state):
196 |         super(DOConv2d_eval, self).__setstate__(state)
197 |         if not hasattr(self, 'padding_mode'):
198 |             self.padding_mode = 'zeros'
199 |
200 |     def _conv_forward(self, input, weight):
201 |         if self.padding_mode != 'zeros':
202 |             return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
203 |                             weight, self.bias, self.stride,
204 |                             (0, 0), self.dilation, self.groups)
205 |         return F.conv2d(input, weight, self.bias, self.stride,
206 |                         self.padding, self.dilation, self.groups)
207 |
208 |     def forward(self, input):
209 |         return self._conv_forward(input, self.W)
210 |
211 | class simam_module(torch.nn.Module):
212 |     def __init__(self, e_lambda=1e-4):
213 |         super(simam_module, self).__init__()
214 |         self.activaton = nn.Sigmoid()
215 |         self.e_lambda = e_lambda
216 |
217 |     def forward(self, x):
218 |         b, c, h, w = x.size()
219 |         n = w * h - 1
220 |         x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
221 |         y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
222 |         return x * self.activaton(y)
--------------------------------------------------------------------------------
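`DOConv2d.forward` folds the over-parameterizing tensor `D` into `W` with `einsum('ims,ois->oim', D, W)` before calling a regular convolution. The sketch below reproduces that contraction in plain Python on nested lists, assuming `groups=1` and `D_mul == M * N`. `compose_dow` and `identity_d` are hypothetical names used only for illustration. It shows that at initialization, where `D` is zero and `D_diag` is the identity, the composed kernel equals `W`.

```python
def compose_dow(D, W):
    """DoW[o][i][m] = sum_s D[i][m][s] * W[o][i][s], i.e. einsum 'ims,ois->oim'.

    D: nested list of shape (in_channels, M*N, D_mul)
    W: nested list of shape (out_channels, in_channels, D_mul)
    """
    return [
        [
            [sum(d * w for d, w in zip(D[i][m], W[o][i])) for m in range(len(D[i]))]
            for i in range(len(D))
        ]
        for o in range(len(W))
    ]


def identity_d(in_channels, mn):
    # Value of D + D_diag at initialization when D_mul == M * N:
    # D is all zeros and D_diag is an identity matrix per input channel.
    return [
        [[1.0 if m == s else 0.0 for s in range(mn)] for m in range(mn)]
        for _ in range(in_channels)
    ]


in_channels, mn = 2, 4  # a 2x2 kernel, flattened to M * N = 4 positions
W = [[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]  # one output channel

# At initialization the composition is a no-op.
dow_init = compose_dow(identity_d(in_channels, mn), W)
print(dow_init == W)

# After training D is no longer zero; here kernel position 0 of input
# channel 0 also draws on depth slot 1 of W.
D = identity_d(in_channels, mn)
D[0][0][1] = 0.5
dow = compose_dow(D, W)
print(dow[0][0])
```

Because the composition yields an ordinary kernel of shape `(out_channels, in_channels // groups, M, N)`, it can be computed once after training and stored in a plain convolution, which appears to be the purpose of `DOConv2d_eval`.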
/evaluations/Evaluation_DID-Data_DDN-Data/psnr.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/evaluations/Evaluation_DID-Data_DDN-Data/psnr.m
--------------------------------------------------------------------------------
/evaluations/Evaluation_DID-Data_DDN-Data/ssim.m:
--------------------------------------------------------------------------------
1 | function [mssim, ssim_map] = ssim(img1, img2, K, window, L)
2 |
3 | % ========================================================================
4 | % SSIM Index with automatic downsampling, Version 1.0
5 | % Copyright(c) 2009 Zhou Wang
6 | % All Rights Reserved.
7 | %
8 | % ----------------------------------------------------------------------
9 | % Permission to use, copy, or modify this software and its documentation
10 | % for educational and research purposes only and without fee is hereby
11 | % granted, provided that this copyright notice and the original authors'
12 | % names appear on all copies and supporting documentation. This program
13 | % shall not be used, rewritten, or adapted as the basis of a commercial
14 | % software or hardware product without first obtaining permission of the
15 | % authors. The authors make no representations about the suitability of
16 | % this software for any purpose. It is provided "as is" without express
17 | % or implied warranty.
18 | %----------------------------------------------------------------------
19 | %
20 | % This is an implementation of the algorithm for calculating the
21 | % Structural SIMilarity (SSIM) index between two images
22 | %
23 | % Please refer to the following paper and the website with suggested usage
24 | %
25 | % Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
26 | % quality assessment: From error visibility to structural similarity,"
27 | % IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612,
28 | % Apr. 2004.
29 | %
30 | % http://www.ece.uwaterloo.ca/~z70wang/research/ssim/
31 | %
32 | % Note: This program is different from ssim_index.m, where no automatic
33 | % downsampling is performed. (downsampling was done in the above paper
34 | % and was described as suggested usage in the above website.)
35 | %
36 | % Kindly report any suggestions or corrections to zhouwang@ieee.org
37 | %
38 | %----------------------------------------------------------------------
39 | %
40 | %Input : (1) img1: the first image being compared
41 | % (2) img2: the second image being compared
42 | % (3) K: constants in the SSIM index formula (see the above
43 | % reference). default value: K = [0.01 0.03]
44 | % (4) window: local window for statistics (see the above
45 | % reference). default window is Gaussian given by
46 | % window = fspecial('gaussian', 11, 1.5);
47 | % (5) L: dynamic range of the images. default: L = 255
48 | %
49 | %Output: (1) mssim: the mean SSIM index value between 2 images.
50 | % If one of the images being compared is regarded as
51 | % perfect quality, then mssim can be considered as the
52 | % quality measure of the other image.
53 | % If img1 = img2, then mssim = 1.
54 | % (2) ssim_map: the SSIM index map of the test image. The map
55 | % has a smaller size than the input images. The actual size
56 | % depends on the window size and the downsampling factor.
57 | %
58 | %Basic Usage:
59 | % Given 2 test images img1 and img2, whose dynamic range is 0-255
60 | %
61 | % [mssim, ssim_map] = ssim(img1, img2);
62 | %
63 | %Advanced Usage:
64 | % User defined parameters. For example
65 | %
66 | % K = [0.05 0.05];
67 | % window = ones(8);
68 | % L = 100;
69 | % [mssim, ssim_map] = ssim(img1, img2, K, window, L);
70 | %
71 | %Visualize the results:
72 | %
73 | % mssim %Gives the mssim value
74 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map
75 | %========================================================================
76 |
77 |
78 | if (nargin < 2 || nargin > 5)
79 | mssim = -Inf;
80 | ssim_map = -Inf;
81 | return;
82 | end
83 |
84 | if (size(img1) ~= size(img2))
85 | mssim = -Inf;
86 | ssim_map = -Inf;
87 | return;
88 | end
89 |
90 | [M N] = size(img1);
91 |
92 | if (nargin == 2)
93 | if ((M < 11) || (N < 11))
94 | mssim = -Inf;
95 | ssim_map = -Inf;
96 | return
97 | end
98 | window = fspecial('gaussian', 11, 1.5); %
99 | K(1) = 0.01; % default settings
100 | K(2) = 0.03; %
101 | L = 255; %
102 | end
103 |
104 | if (nargin == 3)
105 | if ((M < 11) || (N < 11))
106 | mssim = -Inf;
107 | ssim_map = -Inf;
108 | return
109 | end
110 | window = fspecial('gaussian', 11, 1.5);
111 | L = 255;
112 | if (length(K) == 2)
113 | if (K(1) < 0 || K(2) < 0)
114 | mssim = -Inf;
115 | ssim_map = -Inf;
116 | return;
117 | end
118 | else
119 | mssim = -Inf;
120 | ssim_map = -Inf;
121 | return;
122 | end
123 | end
124 |
125 | if (nargin == 4)
126 | [H W] = size(window);
127 | if ((H*W) < 4 || (H > M) || (W > N))
128 | mssim = -Inf;
129 | ssim_map = -Inf;
130 | return
131 | end
132 | L = 255;
133 | if (length(K) == 2)
134 | if (K(1) < 0 || K(2) < 0)
135 | mssim = -Inf;
136 | ssim_map = -Inf;
137 | return;
138 | end
139 | else
140 | mssim = -Inf;
141 | ssim_map = -Inf;
142 | return;
143 | end
144 | end
145 |
146 | if (nargin == 5)
147 | [H W] = size(window);
148 | if ((H*W) < 4 || (H > M) || (W > N))
149 | mssim = -Inf;
150 | ssim_map = -Inf;
151 | return
152 | end
153 | if (length(K) == 2)
154 | if (K(1) < 0 || K(2) < 0)
155 | mssim = -Inf;
156 | ssim_map = -Inf;
157 | return;
158 | end
159 | else
160 | mssim = -Inf;
161 | ssim_map = -Inf;
162 | return;
163 | end
164 | end
165 |
166 |
167 | img1 = double(img1);
168 | img2 = double(img2);
169 |
170 | % automatic downsampling
171 | f = max(1,round(min(M,N)/256));
172 | %downsampling by f
173 | %use a simple low-pass filter
174 | if(f>1)
175 | lpf = ones(f,f);
176 | lpf = lpf/sum(lpf(:));
177 | img1 = imfilter(img1,lpf,'symmetric','same');
178 | img2 = imfilter(img2,lpf,'symmetric','same');
179 |
180 | img1 = img1(1:f:end,1:f:end);
181 | img2 = img2(1:f:end,1:f:end);
182 | end
183 |
184 | C1 = (K(1)*L)^2;
185 | C2 = (K(2)*L)^2;
186 | window = window/sum(sum(window));
187 |
188 | mu1 = filter2(window, img1, 'valid');
189 | mu2 = filter2(window, img2, 'valid');
190 | mu1_sq = mu1.*mu1;
191 | mu2_sq = mu2.*mu2;
192 | mu1_mu2 = mu1.*mu2;
193 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
194 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
195 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
196 |
197 | if (C1 > 0 && C2 > 0)
198 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
199 | else
200 | numerator1 = 2*mu1_mu2 + C1;
201 | numerator2 = 2*sigma12 + C2;
202 | denominator1 = mu1_sq + mu2_sq + C1;
203 | denominator2 = sigma1_sq + sigma2_sq + C2;
204 | ssim_map = ones(size(mu1));
205 | index = (denominator1.*denominator2 > 0);
206 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
207 | index = (denominator1 ~= 0) & (denominator2 == 0);
208 | ssim_map(index) = numerator1(index)./denominator1(index);
209 | end
210 |
211 | mssim = mean2(ssim_map);
212 |
213 | return
--------------------------------------------------------------------------------
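As a rough illustration of what `ssim.m` computes, the sketch below evaluates the SSIM formula for a single window and the automatic downsampling factor `f`. It is a simplification under stated assumptions: the window uses uniform weights instead of the default 11x11 Gaussian, the inputs are flat lists of pixel values, and `ssim_window` and `downsample_factor` are hypothetical names.

```python
import math


def downsample_factor(height, width):
    # ssim.m uses f = max(1, round(min(M, N) / 256)). MATLAB's round() rounds
    # halves away from zero, while Python's round() rounds halves to even,
    # so floor(x + 0.5) is used for this positive ratio.
    return max(1, math.floor(min(height, width) / 256 + 0.5))


def ssim_window(x, y, K=(0.01, 0.03), L=255):
    """SSIM of two equally sized windows (flat lists), uniform weights."""
    n = len(x)
    c1 = (K[0] * L) ** 2
    c2 = (K[1] * L) ** 2
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum(v * v for v in x) / n - mu_x * mu_x
    var_y = sum(v * v for v in y) / n - mu_y * mu_y
    cov = sum(a * b for a, b in zip(x, y)) / n - mu_x * mu_y
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


patch = [10.0, 50.0, 90.0, 130.0]
same = ssim_window(patch, patch)                      # identical windows
different = ssim_window([0.0] * 4, [255.0] * 4)       # black vs. white window
factor_small = downsample_factor(200, 300)            # below 256: no downsampling
factor_large = downsample_factor(640, 640)            # 640 / 256 = 2.5
print(same, different, factor_small, factor_large)
```

The full script applies this formula at every window position and averages the resulting map, so the value returned as `mssim` is a mean over many such local scores.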
/evaluations/Evaluation_DID-Data_DDN-Data/statistic.m:
--------------------------------------------------------------------------------
1 | clear all;
2 | ts =0;
3 | tp =0;
4 | % for i=1:1200 % the number of testing samples DID-Data
5 | for i=1:1400 % the number of testing samples DDN-Data
6 | x_true=im2double(imread(strcat('./gt/DDN-Data/target/',sprintf('%d.jpg',i)))); % groundtruth
7 | % x_true=im2double(imread(strcat('./gt/DID-Data/target/',sprintf('%d.jpg',i)))); % groundtruth
8 | x_true = rgb2ycbcr(x_true);
9 | x_true = x_true(:,:,1);
10 | x = im2double(imread(strcat('./results/DDN-Data/',sprintf('%d.png',i)))); %reconstructed image
11 | % x = im2double(imread(strcat('./results/DID-Data/',sprintf('%d.png',i)))); %reconstructed image
12 | x = rgb2ycbcr(x);
13 | x = x(:,:,1);
14 | tp= tp+ psnr(x,x_true);
15 | ts= ts+ssim(x*255,x_true*255);
16 | end
17 | % fprintf('psnr=%6.4f, ssim=%6.4f\n',tp/1200,ts/1200) % the number of testing samples DID-Data
18 | fprintf('psnr=%6.4f, ssim=%6.4f\n',tp/1400,ts/1400) % the number of testing samples DDN-Data
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
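Both evaluation scripts convert images to YCbCr and score only the luma (Y) channel. The following Python sketch shows that idea for 8-bit RGB pixels. It assumes the standard ITU-R BT.601 studio-range luma weights, which are taken here to match MATLAB's `rgb2ycbcr`, and it skips the rounding to `uint8` that MATLAB applies to integer images, so values can differ slightly from the MATLAB output. `rgb_to_y` and `psnr_y` are hypothetical names.

```python
import math


def rgb_to_y(r, g, b):
    """Luma of one 8-bit RGB pixel, BT.601 studio range (16 to 235)."""
    return 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0


def psnr_y(img_a, img_b, peak=255.0):
    """PSNR in dB on the Y channel; images are equal-length lists of (r, g, b)."""
    squared = [(rgb_to_y(*p) - rgb_to_y(*q)) ** 2 for p, q in zip(img_a, img_b)]
    mse = sum(squared) / len(squared)
    if mse == 0:
        return math.inf  # identical images
    return 20 * math.log10(peak / math.sqrt(mse))


white = [(255, 255, 255)]
black = [(0, 0, 0)]
y_white = rgb_to_y(255, 255, 255)
y_black = rgb_to_y(0, 0, 0)
psnr_same = psnr_y(white, white)
psnr_extreme = psnr_y(black, white)
print(y_white, y_black, psnr_same, psnr_extreme)
```

Scoring on Y rather than RGB is a common convention in deraining benchmarks, so numbers computed on RGB are not directly comparable with results reported this way.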
/evaluations/Evalution_Rain200L_Rain200H_SPA-Data/evaluate_PSNR_SSIM.m:
--------------------------------------------------------------------------------
1 | clc;close all;clear all;addpath(genpath('./'));
2 |
3 | datasets = {'Rain200L'};
4 | % datasets = {'Rain200L', 'Rain200H', 'SPA-Data'};
5 | num_set = length(datasets);
6 |
7 | psnr_alldatasets = 0;
8 | ssim_alldatasets = 0;
9 | for idx_set = 1:num_set
10 | file_path = strcat('./results/', datasets{idx_set}, '/');
11 | gt_path = strcat('./Datasets/', datasets{idx_set}, '/');
12 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))];
13 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
14 | img_num = length(path_list);
15 |
16 | total_psnr = 0;
17 | total_ssim = 0;
18 | if img_num > 0
19 | for j = 1:img_num
20 | image_name = path_list(j).name;
21 | gt_name = gt_list(j).name;
22 | input = imread(strcat(file_path,image_name));
23 | gt = imread(strcat(gt_path, gt_name));
24 | ssim_val = compute_ssim(input, gt);
25 | psnr_val = compute_psnr(input, gt);
26 | total_ssim = total_ssim + ssim_val;
27 | total_psnr = total_psnr + psnr_val;
28 | end
29 | end
30 | qm_psnr = total_psnr / img_num;
31 | qm_ssim = total_ssim / img_num;
32 |
33 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);
34 |
35 | psnr_alldatasets = psnr_alldatasets + qm_psnr;
36 | ssim_alldatasets = ssim_alldatasets + qm_ssim;
37 |
38 | end
39 |
40 | fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_set, ssim_alldatasets/num_set);
41 |
42 | function ssim_mean=compute_ssim(img1,img2)
43 | if size(img1, 3) == 3
44 | img1 = rgb2ycbcr(img1);
45 | img1 = img1(:, :, 1);
46 | end
47 |
48 | if size(img2, 3) == 3
49 | img2 = rgb2ycbcr(img2);
50 | img2 = img2(:, :, 1);
51 | end
52 | ssim_mean = SSIM_index(img1, img2);
53 | end
54 |
55 | function psnr=compute_psnr(img1,img2)
56 | if size(img1, 3) == 3
57 | img1 = rgb2ycbcr(img1);
58 | img1 = img1(:, :, 1);
59 | end
60 |
61 | if size(img2, 3) == 3
62 | img2 = rgb2ycbcr(img2);
63 | img2 = img2(:, :, 1);
64 | end
65 |
66 | imdff = double(img1) - double(img2);
67 | imdff = imdff(:);
68 | rmse = sqrt(mean(imdff.^2));
69 | psnr = 20*log10(255/rmse);
70 |
71 | end
72 |
73 | function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L)
74 |
75 | %========================================================================
76 | %SSIM Index, Version 1.0
77 | %Copyright(c) 2003 Zhou Wang
78 | %All Rights Reserved.
79 | %
80 | %The author is with Howard Hughes Medical Institute, and Laboratory
81 | %for Computational Vision at Center for Neural Science and Courant
82 | %Institute of Mathematical Sciences, New York University.
83 | %
84 | %----------------------------------------------------------------------
85 | %Permission to use, copy, or modify this software and its documentation
86 | %for educational and research purposes only and without fee is hereby
87 | %granted, provided that this copyright notice and the original authors'
88 | %names appear on all copies and supporting documentation. This program
89 | %shall not be used, rewritten, or adapted as the basis of a commercial
90 | %software or hardware product without first obtaining permission of the
91 | %authors. The authors make no representations about the suitability of
92 | %this software for any purpose. It is provided "as is" without express
93 | %or implied warranty.
94 | %----------------------------------------------------------------------
95 | %
96 | %This is an implementation of the algorithm for calculating the
97 | %Structural SIMilarity (SSIM) index between two images. Please refer
98 | %to the following paper:
99 | %
100 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
101 | %quality assessment: From error measurement to structural similarity"
102 | %IEEE Transactions on Image Processing, vol. 13, no. 1, Jan. 2004.
103 | %
104 | %Kindly report any suggestions or corrections to zhouwang@ieee.org
105 | %
106 | %----------------------------------------------------------------------
107 | %
108 | %Input : (1) img1: the first image being compared
109 | % (2) img2: the second image being compared
110 | % (3) K: constants in the SSIM index formula (see the above
111 | % reference). default value: K = [0.01 0.03]
112 | % (4) window: local window for statistics (see the above
113 | % reference). default window is Gaussian given by
114 | % window = fspecial('gaussian', 11, 1.5);
115 | % (5) L: dynamic range of the images. default: L = 255
116 | %
117 | %Output: (1) mssim: the mean SSIM index value between 2 images.
118 | % If one of the images being compared is regarded as
119 | % perfect quality, then mssim can be considered as the
120 | % quality measure of the other image.
121 | % If img1 = img2, then mssim = 1.
122 | % (2) ssim_map: the SSIM index map of the test image. The map
123 | % has a smaller size than the input images. The actual size:
124 | % size(img1) - size(window) + 1.
125 | %
126 | %Default Usage:
127 | % Given 2 test images img1 and img2, whose dynamic range is 0-255
128 | %
129 | % [mssim ssim_map] = ssim_index(img1, img2);
130 | %
131 | %Advanced Usage:
132 | % User defined parameters. For example
133 | %
134 | % K = [0.05 0.05];
135 | % window = ones(8);
136 | % L = 100;
137 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
138 | %
139 | %See the results:
140 | %
141 | % mssim %Gives the mssim value
142 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map
143 | %
144 | %========================================================================
145 |
146 |
147 | if (nargin < 2 || nargin > 5)
148 | ssim_index = -Inf;
149 | ssim_map = -Inf;
150 | return;
151 | end
152 |
153 | if (size(img1) ~= size(img2))
154 | ssim_index = -Inf;
155 | ssim_map = -Inf;
156 | return;
157 | end
158 |
159 | [M N] = size(img1);
160 |
161 | if (nargin == 2)
162 | if ((M < 11) || (N < 11))
163 | ssim_index = -Inf;
164 | ssim_map = -Inf;
165 | return
166 | end
167 | window = fspecial('gaussian', 11, 1.5); %
168 | K(1) = 0.01; % default settings
169 | K(2) = 0.03; %
170 | L = 255; %
171 | end
172 |
173 | if (nargin == 3)
174 | if ((M < 11) || (N < 11))
175 | ssim_index = -Inf;
176 | ssim_map = -Inf;
177 | return
178 | end
179 | window = fspecial('gaussian', 11, 1.5);
180 | L = 255;
181 | if (length(K) == 2)
182 | if (K(1) < 0 || K(2) < 0)
183 | ssim_index = -Inf;
184 | ssim_map = -Inf;
185 | return;
186 | end
187 | else
188 | ssim_index = -Inf;
189 | ssim_map = -Inf;
190 | return;
191 | end
192 | end
193 |
194 | if (nargin == 4)
195 | [H W] = size(window);
196 | if ((H*W) < 4 || (H > M) || (W > N))
197 | ssim_index = -Inf;
198 | ssim_map = -Inf;
199 | return
200 | end
201 | L = 255;
202 | if (length(K) == 2)
203 | if (K(1) < 0 || K(2) < 0)
204 | ssim_index = -Inf;
205 | ssim_map = -Inf;
206 | return;
207 | end
208 | else
209 | ssim_index = -Inf;
210 | ssim_map = -Inf;
211 | return;
212 | end
213 | end
214 |
215 | if (nargin == 5)
216 | [H W] = size(window);
217 | if ((H*W) < 4 || (H > M) || (W > N))
218 | ssim_index = -Inf;
219 | ssim_map = -Inf;
220 | return
221 | end
222 | if (length(K) == 2)
223 | if (K(1) < 0 || K(2) < 0)
224 | ssim_index = -Inf;
225 | ssim_map = -Inf;
226 | return;
227 | end
228 | else
229 | ssim_index = -Inf;
230 | ssim_map = -Inf;
231 | return;
232 | end
233 | end
234 |
235 | C1 = (K(1)*L)^2;
236 | C2 = (K(2)*L)^2;
237 | window = window/sum(sum(window));
238 | img1 = double(img1);
239 | img2 = double(img2);
240 |
241 | mu1 = filter2(window, img1, 'valid');
242 | mu2 = filter2(window, img2, 'valid');
243 | mu1_sq = mu1.*mu1;
244 | mu2_sq = mu2.*mu2;
245 | mu1_mu2 = mu1.*mu2;
246 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
247 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
248 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
249 |
250 | if (C1 > 0 & C2 > 0)
251 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
252 | else
253 | numerator1 = 2*mu1_mu2 + C1;
254 | numerator2 = 2*sigma12 + C2;
255 | denominator1 = mu1_sq + mu2_sq + C1;
256 | denominator2 = sigma1_sq + sigma2_sq + C2;
257 | ssim_map = ones(size(mu1));
258 | index = (denominator1.*denominator2 > 0);
259 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
260 | index = (denominator1 ~= 0) & (denominator2 == 0);
261 | ssim_map(index) = numerator1(index)./denominator1(index);
262 | end
263 |
264 | mssim = mean2(ssim_map);
265 |
266 | end
267 |
268 |
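For readers who do not have MATLAB at hand, the SSIM formula used for ssim_map above can be sketched in plain Python. This is a simplified illustration, not a port of the script: it evaluates the formula once over a whole patch with uniform weights instead of sliding an 11x11 Gaussian window, and the function name ssim_single_window is hypothetical.

```python
def ssim_single_window(img1, img2, K=(0.01, 0.03), L=255):
    """SSIM of two equal-size grayscale patches using one uniform window.

    Same formula as ssim_map in the MATLAB code, evaluated a single time
    over the whole patch (assumption: no Gaussian weighting, no sliding).
    """
    a = [float(v) for row in img1 for v in row]
    b = [float(v) for row in img2 for v in row]
    if not a or len(a) != len(b):
        raise ValueError("patches must be non-empty and the same size")
    n = len(a)

    # Stabilising constants, as in C1 = (K(1)*L)^2 and C2 = (K(2)*L)^2.
    c1 = (K[0] * L) ** 2
    c2 = (K[1] * L) ** 2

    mu1 = sum(a) / n
    mu2 = sum(b) / n
    # Biased second moments: E[x^2] - mu^2 and E[xy] - mu1*mu2.
    sigma1_sq = sum(v * v for v in a) / n - mu1 * mu1
    sigma2_sq = sum(v * v for v in b) / n - mu2 * mu2
    sigma12 = sum(x * y for x, y in zip(a, b)) / n - mu1 * mu2

    numerator = (2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)
    denominator = (mu1 * mu1 + mu2 * mu2 + c1) * (sigma1_sq + sigma2_sq + c2)
    return numerator / denominator


patch = [[0, 64], [128, 255]]
print(ssim_single_window(patch, patch))
```

Identical patches give a value of 1, which matches the "If img1 = img2, then mssim = 1" note in the header comment.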
--------------------------------------------------------------------------------
/get_parameter_number.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def get_parameter_number(net):
4 |     total_num = sum(np.prod(p.size()) for p in net.parameters())
5 |     trainable_num = sum(np.prod(p.size()) for p in net.parameters() if p.requires_grad)
6 |     print('Total: ', total_num)
7 |     print('Trainable: ', trainable_num)
8 |
9 |
10 | if __name__=='__main__':
11 |     from DeepRFT_MIMO import DeepRFT_flops as Net
12 |     import torch
13 |     from ptflops import get_model_complexity_info
14 |     with torch.cuda.device(0):
15 |         net = Net()
16 |         macs, params = get_model_complexity_info(net, (3, 256, 256), as_strings=True,
17 |                                                  print_per_layer_stat=True, verbose=True)
18 |         print('{:<30} {:<8}'.format('Computational complexity: ', macs))
19 |         print('{:<30} {:<8}'.format('Number of parameters: ', params))
20 |
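The counting done by get_parameter_number amounts to summing the product of each parameter's shape. A minimal sketch of that arithmetic without torch, assuming parameters are described as hypothetical (shape, requires_grad) pairs; count_parameters is not part of this repository:

```python
from math import prod


def count_parameters(params):
    """Return (total, trainable) element counts.

    `params` is an iterable of (shape, requires_grad) pairs, a stand-in
    for iterating over net.parameters() in the script above.
    """
    params = list(params)
    total = sum(prod(shape) for shape, _ in params)
    trainable = sum(prod(shape) for shape, requires_grad in params if requires_grad)
    return total, trainable


# Hypothetical example: a 3x3 convolution from 3 to 32 channels with a bias,
# plus one frozen 32-element parameter.
example = [((32, 3, 3, 3), True), ((32,), True), ((32,), False)]
print(count_parameters(example))
```

With real PyTorch modules, `p.numel()` returns the same per-parameter element count as `np.prod(p.size())`.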
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | from doconv_pytorch import *
2 |
3 |
4 | class BasicConv(nn.Module):
5 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=False, norm=False, relu=True, transpose=False,
6 | channel_shuffle_g=0, norm_method=nn.BatchNorm2d, groups=1):
7 | super(BasicConv, self).__init__()
8 | self.channel_shuffle_g = channel_shuffle_g
9 | self.norm = norm
10 | if bias and norm:
11 | bias = False
12 |
13 | padding = kernel_size // 2
14 | layers = list()
15 | if transpose:
16 | padding = kernel_size // 2 - 1
17 | layers.append(
18 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
19 | else:
20 | layers.append(
21 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
22 | if norm:
23 | layers.append(norm_method(out_channel))
24 | elif relu:
25 | layers.append(nn.ReLU(inplace=True))
26 |
27 | self.main = nn.Sequential(*layers)
28 |
29 | def forward(self, x):
30 | return self.main(x)
31 |
32 | class BasicConv_do(nn.Module):
33 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, bias=False, norm=False, relu=True, transpose=False,
34 | relu_method=nn.ReLU, groups=1, norm_method=nn.BatchNorm2d):
35 | super(BasicConv_do, self).__init__()
36 | if bias and norm:
37 | bias = False
38 |
39 | padding = kernel_size // 2
40 | layers = list()
41 | if transpose:
42 | padding = kernel_size // 2 - 1
43 | layers.append(
44 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
45 | else:
46 | layers.append(
47 | DOConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
48 | if norm:
49 | layers.append(norm_method(out_channel))
50 | if relu:
51 | if relu_method == nn.ReLU:
52 | layers.append(nn.ReLU(inplace=True))
53 | elif relu_method == nn.LeakyReLU:
54 | layers.append(nn.LeakyReLU(inplace=True))
55 | else:
56 | layers.append(relu_method())
57 | self.main = nn.Sequential(*layers)
58 |
59 | def forward(self, x):
60 | return self.main(x)
61 |
62 | class BasicConv_do_eval(nn.Module):
63 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=False, norm=False, relu=True, transpose=False,
64 | relu_method=nn.ReLU, groups=1, norm_method=nn.BatchNorm2d):
65 | super(BasicConv_do_eval, self).__init__()
66 | if bias and norm:
67 | bias = False
68 |
69 | padding = kernel_size // 2
70 | layers = list()
71 | if transpose:
72 | padding = kernel_size // 2 - 1
73 | layers.append(
74 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
75 | else:
76 | layers.append(
77 | DOConv2d_eval(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
78 | if norm:
79 | layers.append(norm_method(out_channel))
80 | if relu:
81 | if relu_method == nn.ReLU:
82 | layers.append(nn.ReLU(inplace=True))
83 | elif relu_method == nn.LeakyReLU:
84 | layers.append(nn.LeakyReLU(inplace=True))
85 | else:
86 | layers.append(relu_method())
87 | self.main = nn.Sequential(*layers)
88 |
89 | def forward(self, x):
90 | return self.main(x)
91 |
92 | class ResBlock(nn.Module):
93 | def __init__(self, out_channel):
94 | super(ResBlock, self).__init__()
95 | self.main = nn.Sequential(
96 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=True, norm=False),
97 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False, norm=False)
98 | )
99 |
100 | def forward(self, x):
101 | return self.main(x) + x
102 |
103 | class ResBlock_do(nn.Module):
104 | def __init__(self, out_channel):
105 | super(ResBlock_do, self).__init__()
106 | self.main = nn.Sequential(
107 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=True),
108 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
109 | )
110 |
111 | def forward(self, x):
112 | return self.main(x) + x
113 |
114 | class ResBlock_do_eval(nn.Module):
115 | def __init__(self, out_channel):
116 | super(ResBlock_do_eval, self).__init__()
117 | self.main = nn.Sequential(
118 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=True),
119 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
120 | )
121 |
122 | def forward(self, x):
123 | return self.main(x) + x
124 |
125 |
126 | class ResBlock_do_FECB_bench(nn.Module):
127 | def __init__(self, out_channel, norm='backward'):
128 | super(ResBlock_do_FECB_bench, self).__init__()
129 | self.main = nn.Sequential(
130 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=True),
131 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
132 | )
133 | self.main_fft = nn.Sequential(
134 | BasicConv_do(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=True),
135 | BasicConv_do(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=False)
136 | )
137 | self.dim = out_channel
138 | self.norm = norm
139 | def forward(self, x):
140 | _, _, H, W = x.shape
141 | dim = 1
142 | y = torch.fft.rfft2(x, norm=self.norm)
143 | y_imag = y.imag
144 | y_real = y.real
145 | y_f = torch.cat([y_real, y_imag], dim=dim)
146 | y = self.main_fft(y_f)
147 | y_real, y_imag = torch.chunk(y, 2, dim=dim)
148 | y = torch.complex(y_real, y_imag)
149 | y = torch.fft.irfft2(y, s=(H, W), norm=self.norm)
150 | return self.main(x) + x + y
151 |
152 | class ResBlock_FECB_bench(nn.Module):
153 | def __init__(self, n_feat, norm='backward'): # 'ortho'
154 | super(ResBlock_FECB_bench, self).__init__()
155 | self.main = nn.Sequential(
156 | BasicConv(n_feat, n_feat, kernel_size=3, stride=1, relu=True),
157 | BasicConv(n_feat, n_feat, kernel_size=3, stride=1, relu=False)
158 | )
159 | self.main_fft = nn.Sequential(
160 | BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=True),
161 | BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=False)
162 | )
163 | self.dim = n_feat
164 | self.norm = norm
165 | def forward(self, x):
166 | _, _, H, W = x.shape
167 | dim = 1
168 | y = torch.fft.rfft2(x, norm=self.norm)
169 | y_imag = y.imag
170 | y_real = y.real
171 | y_f = torch.cat([y_real, y_imag], dim=dim)
172 | y = self.main_fft(y_f)
173 | y_real, y_imag = torch.chunk(y, 2, dim=dim)
174 | y = torch.complex(y_real, y_imag)
175 | y = torch.fft.irfft2(y, s=(H, W), norm=self.norm)
176 | return self.main(x) + x + y
177 | class ResBlock_do_FECB_bench_eval(nn.Module):
178 | def __init__(self, out_channel, norm='backward'):
179 | super(ResBlock_do_FECB_bench_eval, self).__init__()
180 | self.main = nn.Sequential(
181 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=True),
182 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
183 | )
184 | self.main_fft = nn.Sequential(
185 | BasicConv_do_eval(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=True),
186 | BasicConv_do_eval(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=False)
187 | )
188 | self.dim = out_channel
189 | self.norm = norm
190 | def forward(self, x):
191 | _, _, H, W = x.shape
192 | dim = 1
193 | y = torch.fft.rfft2(x, norm=self.norm)
194 | y_imag = y.imag
195 | y_real = y.real
196 | y_f = torch.cat([y_real, y_imag], dim=dim)
197 | y = self.main_fft(y_f)
198 | y_real, y_imag = torch.chunk(y, 2, dim=dim)
199 | y = torch.complex(y_real, y_imag)
200 | y = torch.fft.irfft2(y, s=(H, W), norm=self.norm)
201 | return self.main(x) + x + y
202 |
203 | def window_partitions(x, window_size):
204 | """
205 | Args:
206 | x: (B, C, H, W)
207 | window_size (int): window size
208 | Returns:
209 | windows: (num_windows*B, C, window_size, window_size)
210 | """
211 | if isinstance(window_size, int):
212 | window_size = [window_size, window_size]
213 | B, C, H, W = x.shape
214 | x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1])
215 | windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1])
216 | return windows
217 |
218 |
219 | def window_reverses(windows, window_size, H, W):
220 | """
221 | Args:
222 | windows: (num_windows*B, C, window_size, window_size)
223 | window_size (int): Window size
224 | H (int): Height of image
225 | W (int): Width of image
226 | Returns:
227 | x: (B, C, H, W)
228 | """
229 | # B = int(windows.shape[0] / (H * W / window_size / window_size))
230 | # print('B: ', B)
231 | # print(H // window_size)
232 | # print(W // window_size)
233 | if isinstance(window_size, int):
234 | window_size = [window_size, window_size]
235 | C = windows.shape[1]
236 | # print('C: ', C)
237 | x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1])
238 | x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W)
239 | return x
240 |
241 | def window_partitionx(x, window_size):
242 | _, _, H, W = x.shape
243 | h, w = window_size * (H // window_size), window_size * (W // window_size)
244 | x_main = window_partitions(x[:, :, :h, :w], window_size)
245 | b_main = x_main.shape[0]
246 | if h == H and w == W:
247 | return x_main, [b_main]
248 | if h != H and w != W:
249 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size)
250 | b_r = x_r.shape[0] + b_main
251 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size)
252 | b_d = x_d.shape[0] + b_r
253 | x_dd = x[:, :, -window_size:, -window_size:]
254 | b_dd = x_dd.shape[0] + b_d
255 | # batch_list = [b_main, b_r, b_d, b_dd]
256 | return torch.cat([x_main, x_r, x_d, x_dd], dim=0), [b_main, b_r, b_d, b_dd]
257 | if h == H and w != W:
258 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size)
259 | b_r = x_r.shape[0] + b_main
260 | return torch.cat([x_main, x_r], dim=0), [b_main, b_r]
261 | if h != H and w == W:
262 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size)
263 | b_d = x_d.shape[0] + b_main
264 | return torch.cat([x_main, x_d], dim=0), [b_main, b_d]
265 | def window_reversex(windows, window_size, H, W, batch_list):
266 | h, w = window_size * (H // window_size), window_size * (W // window_size)
267 | # print(windows[:batch_list[0], ...].shape)
268 | x_main = window_reverses(windows[:batch_list[0], ...], window_size, h, w)
269 | B, C, _, _ = x_main.shape
270 | # print('windows: ', windows.shape)
271 | # print('batch_list: ', batch_list)
272 | if torch.is_complex(windows):
273 | res = torch.complex(torch.zeros([B, C, H, W]), torch.zeros([B, C, H, W]))
274 | res = res.to(windows.device)
275 | else:
276 | res = torch.zeros([B, C, H, W], device=windows.device)
277 |
278 | res[:, :, :h, :w] = x_main
279 | if h == H and w == W:
280 | return res
281 | if h != H and w != W and len(batch_list) == 4:
282 | x_dd = window_reverses(windows[batch_list[2]:, ...], window_size, window_size, window_size)
283 | res[:, :, h:, w:] = x_dd[:, :, h - H:, w - W:]
284 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
285 | res[:, :, :h, w:] = x_r[:, :, :, w - W:]
286 | x_d = window_reverses(windows[batch_list[1]:batch_list[2], ...], window_size, window_size, w)
287 | res[:, :, h:, :w] = x_d[:, :, h - H:, :]
288 | return res
289 | if w != W and len(batch_list) == 2:
290 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
291 | res[:, :, :h, w:] = x_r[:, :, :, w - W:]
292 | if h != H and len(batch_list) == 2:
293 | x_d = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, window_size, w)
294 | res[:, :, h:, :w] = x_d[:, :, h - H:, :]
295 | return res
296 |
297 | def window_partitions_old(x, window_size):
298 | """
299 | Args:
300 | x: (B, C, H, W)
301 | window_size (int): window size
302 |
303 | Returns:
304 | windows: (num_windows*B, C, window_size, window_size)
305 | """
306 | B, C, H, W = x.shape
307 | x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
308 | windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size)
309 | return windows
310 |
311 |
312 | def window_reverses_old(windows, window_size, H, W):
313 | """
314 | Args:
315 | windows: (num_windows*B, C, window_size, window_size)
316 | window_size (int): Window size
317 | H (int): Height of image
318 | W (int): Width of image
319 |
320 | Returns:
321 | x: (B, C, H, W)
322 | """
323 | # B = int(windows.shape[0] / (H * W / window_size / window_size))
324 | # print('B: ', B)
325 | # print(H // window_size)
326 | # print(W // window_size)
327 | C = windows.shape[1]
328 | # print('C: ', C)
329 | x = windows.view(-1, H // window_size, W // window_size, C, window_size, window_size)
330 | x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W)
331 | return x
332 |
333 | def window_partitionx_old(x, window_size):
334 | _, _, H, W = x.shape
335 | h, w = window_size * (H // window_size), window_size * (W // window_size)
336 | x_main = window_partitions(x[:, :, :h, :w], window_size)
337 | b_main = x_main.shape[0]
338 | if h == H and w == W:
339 | return x_main, [b_main]
340 | if h != H and w != W:
341 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size)
342 | b_r = x_r.shape[0] + b_main
343 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size)
344 | b_d = x_d.shape[0] + b_r
345 | x_dd = x[:, :, -window_size:, -window_size:]
346 | b_dd = x_dd.shape[0] + b_d
347 | # batch_list = [b_main, b_r, b_d, b_dd]
348 | return torch.cat([x_main, x_r, x_d, x_dd], dim=0), [b_main, b_r, b_d, b_dd]
349 | if h == H and w != W:
350 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size)
351 | b_r = x_r.shape[0] + b_main
352 | return torch.cat([x_main, x_r], dim=0), [b_main, b_r]
353 | if h != H and w == W:
354 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size)
355 | b_d = x_d.shape[0] + b_main
356 | return torch.cat([x_main, x_d], dim=0), [b_main, b_d]
357 |
358 | def window_reversex_old(windows, window_size, H, W, batch_list):
359 | h, w = window_size * (H // window_size), window_size * (W // window_size)
360 | x_main = window_reverses(windows[:batch_list[0], ...], window_size, h, w)
361 | B, C, _, _ = x_main.shape
362 | # print('windows: ', windows.shape)
363 | # print('batch_list: ', batch_list)
364 | res = torch.zeros([B, C, H, W],device=windows.device)
365 | res[:, :, :h, :w] = x_main
366 | if h == H and w == W:
367 | return res
368 | if h != H and w != W and len(batch_list) == 4:
369 | x_dd = window_reverses(windows[batch_list[2]:, ...], window_size, window_size, window_size)
370 | res[:, :, h:, w:] = x_dd[:, :, h - H:, w - W:]
371 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
372 | res[:, :, :h, w:] = x_r[:, :, :, w - W:]
373 | x_d = window_reverses(windows[batch_list[1]:batch_list[2], ...], window_size, window_size, w)
374 | res[:, :, h:, :w] = x_d[:, :, h - H:, :]
375 | return res
376 | if w != W and len(batch_list) == 2:
377 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
378 | res[:, :, :h, w:] = x_r[:, :, :, w - W:]
379 | if h != H and len(batch_list) == 2:
380 | x_d = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, window_size, w)
381 | res[:, :, h:, :w] = x_d[:, :, h - H:, :]
382 | return res
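The batch_list returned by window_partitionx holds cumulative window counts: main grid first, then the right strip, the bottom strip and the corner window when H or W is not a multiple of window_size. The sketch below reproduces only that bookkeeping in plain Python so the offsets can be checked by hand. The helper name window_batch_offsets is hypothetical, and it assumes H and W are at least window_size.

```python
def window_batch_offsets(B, H, W, window_size):
    """Cumulative window counts, mirroring batch_list in window_partitionx."""
    nh, nw = H // window_size, W // window_size
    h, w = nh * window_size, nw * window_size

    b_main = B * nh * nw              # windows in the evenly divisible region
    if h == H and w == W:
        return [b_main]
    if h != H and w != W:
        b_r = b_main + B * nh         # right strip: last window_size columns
        b_d = b_r + B * nw            # bottom strip: last window_size rows
        b_dd = b_d + B                # one bottom-right corner window per image
        return [b_main, b_r, b_d, b_dd]
    if w != W:
        return [b_main, b_main + B * nh]
    return [b_main, b_main + B * nw]


# A 10x10 feature map with 4x4 windows: 4 main windows, 2 right, 2 bottom, 1 corner.
print(window_batch_offsets(1, 10, 10, 4))  # [4, 6, 8, 9]
```

window_reversex slices the stacked windows at these offsets, which is why the strips overlap the main region: they are cut from the last window_size rows or columns and only their non-overlapping part is written back.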
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class CharbonnierLoss(nn.Module):
6 |     """Charbonnier Loss (L1)"""
7 |
8 |     def __init__(self, eps=1e-3):
9 |         super(CharbonnierLoss, self).__init__()
10 |         self.eps = eps
11 |
12 |     def forward(self, x, y):
13 |         diff = x - y.to(x.device)  # follow x's device instead of hard-coding 'cuda:0'
14 |         loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
15 |         return loss
16 |
17 | class EdgeLoss(nn.Module):
18 |     def __init__(self):
19 |         super(EdgeLoss, self).__init__()
20 |         k = torch.Tensor([[.05, .25, .4, .25, .05]])
21 |         self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
22 |         if torch.cuda.is_available():
23 |             self.kernel = self.kernel.to('cuda:0')
24 |         self.loss = CharbonnierLoss()
25 |
26 |     def conv_gauss(self, img):
27 |         n_channels, _, kw, kh = self.kernel.shape
28 |         img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
29 |         return F.conv2d(img, self.kernel, groups=n_channels)
30 |
31 |     def laplacian_kernel(self, current):
32 |         filtered = self.conv_gauss(current)
33 |         down = filtered[:,:,::2,::2]
34 |         new_filter = torch.zeros_like(filtered)
35 |         new_filter[:,:,::2,::2] = down*4
36 |         filtered = self.conv_gauss(new_filter)
37 |         diff = current - filtered
38 |         return diff
39 |
40 |     def forward(self, x, y):
41 |         loss = self.loss(self.laplacian_kernel(x.to(self.kernel.device)), self.laplacian_kernel(y.to(self.kernel.device)))
42 |         return loss
43 |
44 | class fftLoss(nn.Module):
45 |     def __init__(self):
46 |         super(fftLoss, self).__init__()
47 |
48 |     def forward(self, x, y):
49 |         diff = torch.fft.fft2(x) - torch.fft.fft2(y.to(x.device))
50 |         loss = torch.mean(torch.abs(diff))
51 |         return loss
52 |
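The Charbonnier penalty in CharbonnierLoss is a smooth stand-in for the absolute error: sqrt(diff^2 + eps^2) averaged over all elements. A minimal plain-Python sketch of the same arithmetic on flat lists follows; the function name charbonnier is hypothetical, and no tensors or devices are involved.

```python
import math


def charbonnier(x, y, eps=1e-3):
    """Mean of sqrt((x_i - y_i)^2 + eps^2) over two equal-length sequences."""
    if not x or len(x) != len(y):
        raise ValueError("inputs must be non-empty and the same length")
    return sum(math.sqrt((a - b) ** 2 + eps * eps) for a, b in zip(x, y)) / len(x)


# For identical inputs the loss bottoms out at eps rather than 0;
# for large errors it is close to the absolute error.
print(charbonnier([1.0, 2.0], [1.0, 2.0]))
print(charbonnier([0.0], [3.0]))
```

The eps term keeps the gradient finite at zero error, which is the usual reason for preferring this form over a plain L1 loss.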
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | from layers import *
2 | import numbers
3 | from einops import rearrange
4 |
5 | class Downsample(nn.Module):
6 | def __init__(self, n_feat):
7 | super(Downsample, self).__init__()
8 |
9 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
10 | nn.PixelUnshuffle(2))
11 |
12 | def forward(self, x):
13 | return self.body(x)
14 |
15 | class Upsample(nn.Module):
16 | def __init__(self, n_feat):
17 | super(Upsample, self).__init__()
18 |
19 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
20 | nn.PixelShuffle(2))
21 |
22 | def forward(self, x):
23 | return self.body(x)
24 |
25 | def to_3d(x):
26 | return rearrange(x, 'b c h w -> b (h w) c')
27 |
28 | def to_4d(x, h, w):
29 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
30 |
31 | class BiasFree_LayerNorm(nn.Module):
32 | def __init__(self, normalized_shape):
33 | super(BiasFree_LayerNorm, self).__init__()
34 | if isinstance(normalized_shape, numbers.Integral):
35 | normalized_shape = (normalized_shape,)
36 | normalized_shape = torch.Size(normalized_shape)
37 |
38 | assert len(normalized_shape) == 1
39 |
40 | self.weight = nn.Parameter(torch.ones(normalized_shape))
41 | self.normalized_shape = normalized_shape
42 |
43 | def forward(self, x):
44 | sigma = x.var(-1, keepdim=True, unbiased=False)
45 | return x / torch.sqrt(sigma + 1e-5) * self.weight
46 |
47 | class WithBias_LayerNorm(nn.Module):
48 | def __init__(self, normalized_shape):
49 | super(WithBias_LayerNorm, self).__init__()
50 | if isinstance(normalized_shape, numbers.Integral):
51 | normalized_shape = (normalized_shape,)
52 | normalized_shape = torch.Size(normalized_shape)
53 |
54 | assert len(normalized_shape) == 1
55 |
56 | self.weight = nn.Parameter(torch.ones(normalized_shape))
57 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
58 | self.normalized_shape = normalized_shape
59 |
60 | def forward(self, x):
61 | mu = x.mean(-1, keepdim=True)
62 | sigma = x.var(-1, keepdim=True, unbiased=False)
63 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
64 |
65 | class LayerNorm(nn.Module):
66 | def __init__(self, dim, LayerNorm_type):
67 | super(LayerNorm, self).__init__()
68 | if LayerNorm_type == 'BiasFree':
69 | self.body = BiasFree_LayerNorm(dim)
70 | else:
71 | self.body = WithBias_LayerNorm(dim)
72 |
73 | def forward(self, x):
74 | h, w = x.shape[-2:]
75 | return to_4d(self.body(to_3d(x)), h, w)
76 |
77 | class FeedForward(nn.Module):
78 | def __init__(self, dim, ffn_expansion_factor, bias, BasicConv=BasicConv):
79 | super(FeedForward, self).__init__()
80 |
81 | hidden_features = int(dim * ffn_expansion_factor)
82 |
83 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
84 |
85 | self.dwconv = BasicConv(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, bias=bias, relu=False, groups=hidden_features * 2)
86 |
87 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
88 |
89 | def forward(self, x):
90 | x = self.project_in(x)
91 | x1, x2 = self.dwconv(x).chunk(2, dim=1)
92 | x = F.gelu(x1) * x2
93 | x = self.project_out(x)
94 | return x
95 |
96 | class Attention(nn.Module):
97 |     def __init__(self, scale, dim, num_heads, bias):
98 |         super(Attention, self).__init__()
99 |         self.num_heads = num_heads
100 |
101 |         self.scale = scale
102 |
103 |         self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
104 |
105 |         self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
106 |         self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
107 |         self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
108 |         self.attn_drop = nn.Dropout(0.)
109 |
110 |     def forward(self, x):
111 |         b, c, h, w = x.shape
112 |
113 |         qkv = self.qkv_dwconv(self.qkv(x))
114 |         q, k, v = qkv.chunk(3, dim=1)
115 |
116 |         q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
117 |         k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
118 |         v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
119 |
120 |         q = torch.nn.functional.normalize(q, dim=-1)
121 |         k = torch.nn.functional.normalize(k, dim=-1)
122 |
123 |         _, _, C, _ = q.shape  # C: channels per head
124 |
125 |         mask1 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
126 |
127 |         attn = (q @ k.transpose(-2, -1)) * self.temperature  # channel-wise attention, (b, head, C, C)
128 |
129 |         if self.scale == 1:  # keep the top 60% of entries per row
130 |             index = torch.topk(attn, k=int(C*6/10), dim=-1, largest=True)[1]
131 |         elif self.scale == 0.5:  # keep the top 70%
132 |             index = torch.topk(attn, k=int(C * 7 / 10), dim=-1, largest=True)[1]
133 |         elif self.scale == 0.25:  # keep the top 80%
134 |             index = torch.topk(attn, k=int(C * 8 / 10), dim=-1, largest=True)[1]
135 |         mask1.scatter_(-1, index, 1.)
136 |         attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf')))
137 |
138 |
139 |         attn1 = attn1.softmax(dim=-1)
140 |
141 |         out = (attn1 @ v)
142 |
143 |         out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
144 |
145 |         out = self.project_out(out)
146 |         return out
147 |
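The masking step in Attention.forward keeps only the k largest scores in each row of the channel attention map and sets the rest to -inf before the softmax, so the discarded entries receive exactly zero weight. The sketch below shows that idea for a single row in plain Python. Both function names are hypothetical; keep_count only restates the 6/10, 7/10 and 8/10 ratios used above.

```python
import math


def keep_count(C, scale):
    """Number of entries kept per row for a given scale, as in Attention.forward."""
    tenths = {1: 6, 0.5: 7, 0.25: 8}[scale]
    return int(C * tenths / 10)


def topk_masked_softmax(row, k):
    """Softmax over the k largest entries of `row`; every other entry gets 0."""
    if not 1 <= k <= len(row):
        raise ValueError("k must be between 1 and len(row)")
    keep = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
    top = max(row[i] for i in keep)
    # Subtract the maximum before exponentiating for numerical stability.
    exps = {i: math.exp(row[i] - top) for i in keep}
    z = sum(exps.values())
    return [exps[i] / z if i in exps else 0.0 for i in range(len(row))]


print(keep_count(32, 1), keep_count(32, 0.5), keep_count(32, 0.25))
print(topk_masked_softmax([1.0, 3.0, 2.0], k=2))
```

With base_channel = 32 and heads = [1, 2, 4], each head sees 32 channels at every level, so the three scales keep 19, 22 and 25 entries per row respectively.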
148 | class TransformerBlock(nn.Module):
149 | def __init__(self, scale,dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, BasicConv=BasicConv):
150 | super(TransformerBlock, self).__init__()
151 |
152 | self.norm1 = LayerNorm(dim, LayerNorm_type)
153 | self.attn = Attention(scale,dim, num_heads, bias)
154 | self.norm2 = LayerNorm(dim, LayerNorm_type)
155 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias, BasicConv=BasicConv)
156 |
157 | def forward(self, x):
158 | x = x + self.attn(self.norm1(x))
159 | x = x + self.ffn(self.norm2(x))
160 |
161 | return x
162 |
163 | class FECB_SCTB(nn.Module):
164 | def __init__(self , out_channel, num_res=8, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=1, ffn_expansion_factor=1, bias=False, LayerNorm_type='WithBias', scale=1):
165 | super(FECB_SCTB, self).__init__()
166 |
167 | layers = []
168 | for _ in range(num_res):
169 | layers.append(ResBlock(out_channel))
170 | layers.append(TransformerBlock(scale = scale,dim=out_channel, num_heads=num_heads, ffn_expansion_factor=ffn_expansion_factor, bias=bias,
171 | LayerNorm_type=LayerNorm_type, BasicConv=BasicConv))
172 |
173 | self.layers = nn.Sequential(*layers)
174 |
175 | def forward(self, x):
176 | return self.layers(x)
177 |
178 | class GFM(nn.Module):
179 | def __init__(self, in_channel, out_channel, BasicConv=BasicConv):
180 | super(GFM, self).__init__()
181 | self.conv_max = nn.Sequential(
182 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
183 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
184 | )
185 | self.conv_mid = nn.Sequential(
186 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
187 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
188 | )
189 | self.conv_small = nn.Sequential(
190 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
191 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
192 | )
193 |
194 | self.conv1 =BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True)
195 | self.conv2 = BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True)
196 |
197 |
198 | def forward(self, x_max,x_mid,x_small):
199 |
200 | y_max=x_max +x_mid +x_small
201 |
202 | x_max = self.conv_max(x_max)
203 | x_mid = self.conv_max(x_mid)
204 | x_small = self.conv_max(x_small)
205 |
206 | x = torch.tanh(x_mid) * x_max
207 | x = self.conv1(x)
208 |
209 | x = torch.tanh(x_small) * x
210 | x = self.conv2(x)
211 |
212 | return x+y_max
213 |
214 | class SCM(nn.Module):
215 | def __init__(self, out_plane, BasicConv=BasicConv, inchannel=3):
216 | super(SCM, self).__init__()
217 | self.main = nn.Sequential(
218 | BasicConv(inchannel, out_plane//4, kernel_size=3, stride=1, relu=True),
219 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
220 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
221 | BasicConv(out_plane // 2, out_plane-inchannel, kernel_size=1, stride=1, relu=True)
222 | )
223 |
224 | self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False)
225 |
226 | def forward(self, x):
227 | x = torch.cat([x, self.main(x)], dim=1)
228 | return self.conv(x)
229 |
230 | class FAM(nn.Module):
231 | def __init__(self, channel, BasicConv=BasicConv):
232 | super(FAM, self).__init__()
233 | self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False)
234 |
235 | def forward(self, x1, x2):
236 | x = x1 * x2
237 | out = x1 + self.merge(x)
238 | return out
239 |
240 | class MSDT(nn.Module):
241 | def __init__(self, num_res=8, inference=False):
242 | super(MSDT, self).__init__()
243 | self.inference = inference
244 | if not inference:
245 | BasicConv = BasicConv_do
246 | ResBlock = ResBlock_do_FECB_bench
247 | else:
248 | BasicConv = BasicConv_do_eval
249 | ResBlock = ResBlock_do_FECB_bench_eval
250 | base_channel = 32
251 |
252 | heads = [1, 2, 4]
253 | ffn_expansion_factor = 2.66
254 | bias = False
255 | LayerNorm_type = 'WithBias'
256 | scale = [1,0.5,0.25]
257 |
258 | self.Encoder = nn.ModuleList([
259 | FECB_SCTB(base_channel, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[0],
260 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= scale[0]),
261 | FECB_SCTB(base_channel * 2, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[1],
262 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= scale[1]),
263 | FECB_SCTB(base_channel * 4, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[2],
264 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= scale[2]),
265 | ])
266 |
267 | self.feat_extract = nn.ModuleList([
268 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
269 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
270 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
271 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
272 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
273 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
274 | ])
275 |
276 | self.Decoder = nn.ModuleList([
277 | FECB_SCTB(base_channel * 4, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[2],
278 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= 0.25),
279 | FECB_SCTB(base_channel * 2, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[1],
280 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= 0.5),
281 | FECB_SCTB(base_channel, num_res, ResBlock=ResBlock, BasicConv=BasicConv, num_heads=heads[0],
282 | ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type,scale= 1)
283 | ])
284 |
285 | self.Convs = nn.ModuleList([
286 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
287 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
288 | ])
289 |
290 | self.ConvsOut = nn.ModuleList(
291 | [
292 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
293 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
294 | ]
295 | )
296 |
297 | self.GFMs = nn.ModuleList([
298 | GFM(32, 32, BasicConv=BasicConv),
299 | GFM(64, 64, BasicConv=BasicConv)
300 | ])
301 |
302 | self.FAM1 = FAM(base_channel * 4, BasicConv=BasicConv)
303 | self.SCM1 = SCM(base_channel * 4, BasicConv=BasicConv)
304 | self.FAM2 = FAM(base_channel * 2, BasicConv=BasicConv)
305 | self.SCM2 = SCM(base_channel * 2, BasicConv=BasicConv)
306 |
307 | self.down_1 = Downsample(32)
308 |
309 | self.up_1 = Upsample(64)
310 | self.up_2 = Upsample(128)
311 | self.up_3 = Upsample(64)
312 |
313 | def forward(self, x):
314 | x_2 = F.interpolate(x, scale_factor=0.5)
315 | x_4 = F.interpolate(x_2, scale_factor=0.5)
316 | z2 = self.SCM2(x_2)
317 | z4 = self.SCM1(x_4)
318 |
319 | outputs = list()
320 |
321 | x_ = self.feat_extract[0](x)
322 |
323 | res1 = self.Encoder[0](x_)
324 |
325 | z = self.feat_extract[1](res1)
326 | z = self.FAM2(z, z2)
327 |
328 | res2 = self.Encoder[1](z)
329 |
330 | z = self.feat_extract[2](res2)
331 | z = self.FAM1(z, z4)
332 |
333 | z = self.Encoder[2](z)
334 |
335 | z21 = self.up_1(res2)
336 | z42 = self.up_2(z)
337 | z41 = self.up_3(z42)
338 |
339 | z12 = self.down_1(res1)
340 |
341 | res1 = self.GFMs[0](res1,z21,z41)
342 | res2 = self.GFMs[1](z12,res2,z42)
343 |
344 | z = self.Decoder[0](z)
345 | z_ = self.ConvsOut[0](z)
346 | z = self.feat_extract[3](z)
347 | if not self.inference:
348 | outputs.append(z_+x_4)
349 |
350 | z = torch.cat([z, res2], dim=1)
351 | z = self.Convs[0](z)
352 |
353 | z = self.Decoder[1](z)
354 | z_ = self.ConvsOut[1](z)
355 | z = self.feat_extract[4](z)
356 | if not self.inference:
357 | outputs.append(z_+x_2)
358 |
359 | z = torch.cat([z, res1], dim=1)
360 | z = self.Convs[1](z)
361 |
362 | z = self.Decoder[2](z)
363 | z = self.feat_extract[5](z)
364 | # Always append the full-resolution output so that inference mode does not return an empty list.
365 | outputs.append(z+x)
366 | return outputs[::-1]
367 |
368 |
369 |
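The channel bookkeeping of the four-layer `main` branch followed by `torch.cat` can be checked with plain arithmetic. The sketch below is illustrative only: `scm_channel_plan` is a hypothetical helper, and `inchannel=3` is an assumption (an RGB input), since the constructor signature of the enclosing module is not shown in this part of the listing.

```python
def scm_channel_plan(out_plane, inchannel=3):
    """Return the output widths of the four convs in `main` and the width after concat.

    Mirrors the layer widths used in the listing:
    out_plane // 4, out_plane // 2, out_plane // 2, out_plane - inchannel.
    `inchannel=3` is an assumption (RGB input), not taken from the source.
    """
    main_out = [out_plane // 4, out_plane // 2, out_plane // 2, out_plane - inchannel]
    # forward() concatenates the raw input with the output of `main` along dim=1,
    # so the channel counts simply add up.
    concat_channels = inchannel + main_out[-1]
    return main_out, concat_channels


if __name__ == "__main__":
    for width in (64, 128):
        print(width, scm_channel_plan(width))
```

The last conv emits `out_plane - inchannel` channels precisely so that the concatenation lands on `out_plane`, which is what the final 1x1 `self.conv` expects.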
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/build/lib/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/build/lib/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR
3 | from torch.optim.sgd import SGD
4 |
5 | from warmup_scheduler import GradualWarmupScheduler
6 |
7 |
8 | if __name__ == '__main__':
9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 | optim = SGD(model, 0.1)
11 |
12 | # scheduler_warmup is chained with scheduler_steplr
13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
15 |
16 | # this zero gradient update is needed to avoid a warning message, issue #8.
17 | optim.zero_grad()
18 | optim.step()
19 |
20 | for epoch in range(1, 20):
21 | scheduler_warmup.step(epoch)
22 | print(epoch, optim.param_groups[0]['lr'])
23 |
24 | optim.step() # parameter update (stands in for a real training step)
25 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/build/lib/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 | """ Gradually warm-up(increasing) learning rate in optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 | Args:
10 | optimizer (Optimizer): Wrapped optimizer.
11 | multiplier: target learning rate = base_lr * multiplier if multiplier > 1.0. If multiplier == 1.0, the lr starts from 0 and ramps up to base_lr.
12 | total_epoch: epoch at which the target learning rate is reached; the increase is gradual (linear).
13 | after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau).
14 | """
15 |
16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 | self.multiplier = multiplier
18 | if self.multiplier < 1.:
19 | raise ValueError('multiplier should be greater than or equal to 1.')
20 | self.total_epoch = total_epoch
21 | self.after_scheduler = after_scheduler
22 | self.finished = False
23 | super(GradualWarmupScheduler, self).__init__(optimizer)
24 |
25 | def get_lr(self):
26 | if self.last_epoch > self.total_epoch:
27 | if self.after_scheduler:
28 | if not self.finished:
29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 | self.finished = True
31 | return self.after_scheduler.get_lr()
32 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 |
34 | if self.multiplier == 1.0:
35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 | else:
37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 |
39 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 | if epoch is None:
41 | epoch = self.last_epoch + 1
42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43 | if self.last_epoch <= self.total_epoch:
44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46 | param_group['lr'] = lr
47 | else:
48 | if epoch is None:
49 | self.after_scheduler.step(metrics, None)
50 | else:
51 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
52 |
53 | def step(self, epoch=None, metrics=None):
54 | if type(self.after_scheduler) != ReduceLROnPlateau:
55 | if self.finished and self.after_scheduler:
56 | if epoch is None:
57 | self.after_scheduler.step(None)
58 | else:
59 | self.after_scheduler.step(epoch - self.total_epoch)
60 | else:
61 | return super(GradualWarmupScheduler, self).step(epoch)
62 | else:
63 | self.step_ReduceLROnPlateau(metrics, epoch)
64 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/dist/warmup_scheduler-0.3-py3.8.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/pytorch-gradual-warmup-lr/dist/warmup_scheduler-0.3-py3.8.egg
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import setuptools
6 |
7 | _VERSION = '0.3'
8 |
9 | REQUIRED_PACKAGES = [
10 | ]
11 |
12 | DEPENDENCY_LINKS = [
13 | ]
14 |
15 | setuptools.setup(
16 | name='warmup_scheduler',
17 | version=_VERSION,
18 | description='Gradually Warm-up LR Scheduler for Pytorch',
19 | install_requires=REQUIRED_PACKAGES,
20 | dependency_links=DEPENDENCY_LINKS,
21 | url='https://github.com/ildoonet/pytorch-gradual-warmup-lr',
22 | license='MIT License',
23 | package_dir={},
24 | packages=setuptools.find_packages(exclude=['tests']),
25 | )
26 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: warmup-scheduler
3 | Version: 0.3
4 | Summary: Gradually Warm-up LR Scheduler for Pytorch
5 | Home-page: https://github.com/ildoonet/pytorch-gradual-warmup-lr
6 | License: MIT License
7 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | setup.py
2 | warmup_scheduler/__init__.py
3 | warmup_scheduler/run.py
4 | warmup_scheduler/scheduler.py
5 | warmup_scheduler.egg-info/PKG-INFO
6 | warmup_scheduler.egg-info/SOURCES.txt
7 | warmup_scheduler.egg-info/dependency_links.txt
8 | warmup_scheduler.egg-info/top_level.txt
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | warmup_scheduler
2 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR
3 | from torch.optim.sgd import SGD
4 |
5 | from warmup_scheduler import GradualWarmupScheduler
6 |
7 |
8 | if __name__ == '__main__':
9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 | optim = SGD(model, 0.1)
11 |
12 | # scheduler_warmup is chained with scheduler_steplr
13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
15 |
16 | # this zero gradient update is needed to avoid a warning message, issue #8.
17 | optim.zero_grad()
18 | optim.step()
19 |
20 | for epoch in range(1, 20):
21 | scheduler_warmup.step(epoch)
22 | print(epoch, optim.param_groups[0]['lr'])
23 |
24 | optim.step() # parameter update (stands in for a real training step)
25 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 | """ Gradually warm-up(increasing) learning rate in optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 | Args:
10 | optimizer (Optimizer): Wrapped optimizer.
11 | multiplier: target learning rate = base_lr * multiplier if multiplier > 1.0. If multiplier == 1.0, the lr starts from 0 and ramps up to base_lr.
12 | total_epoch: epoch at which the target learning rate is reached; the increase is gradual (linear).
13 | after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau).
14 | """
15 |
16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 | self.multiplier = multiplier
18 | if self.multiplier < 1.:
19 | raise ValueError('multiplier should be greater than or equal to 1.')
20 | self.total_epoch = total_epoch
21 | self.after_scheduler = after_scheduler
22 | self.finished = False
23 | super(GradualWarmupScheduler, self).__init__(optimizer)
24 |
25 | def get_lr(self):
26 | if self.last_epoch > self.total_epoch:
27 | if self.after_scheduler:
28 | if not self.finished:
29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 | self.finished = True
31 | return self.after_scheduler.get_lr()
32 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 |
34 | if self.multiplier == 1.0:
35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 | else:
37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 |
39 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 | if epoch is None:
41 | epoch = self.last_epoch + 1
42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43 | if self.last_epoch <= self.total_epoch:
44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46 | param_group['lr'] = lr
47 | else:
48 | if epoch is None:
49 | self.after_scheduler.step(metrics, None)
50 | else:
51 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
52 |
53 | def step(self, epoch=None, metrics=None):
54 | if type(self.after_scheduler) != ReduceLROnPlateau:
55 | if self.finished and self.after_scheduler:
56 | if epoch is None:
57 | self.after_scheduler.step(None)
58 | else:
59 | self.after_scheduler.step(epoch - self.total_epoch)
60 | else:
61 | return super(GradualWarmupScheduler, self).step(epoch)
62 | else:
63 | self.step_ReduceLROnPlateau(metrics, epoch)
64 |
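The two warm-up branches of `get_lr()` above reduce to simple linear ramps. The following is a minimal, dependency-free sketch of that arithmetic; `warmup_lr` is a hypothetical helper written for illustration and is not part of the package.

```python
def warmup_lr(base_lr, epoch, total_epoch, multiplier=1.0):
    """Learning rate during warm-up, mirroring the two branches of get_lr().

    Only valid for 0 <= epoch <= total_epoch; after that the wrapped
    after_scheduler (if any) takes over.
    """
    if multiplier == 1.0:
        # Ramp linearly from 0 up to base_lr.
        return base_lr * (float(epoch) / total_epoch)
    # Ramp linearly from base_lr up to base_lr * multiplier.
    return base_lr * ((multiplier - 1.0) * epoch / total_epoch + 1.0)


if __name__ == "__main__":
    # multiplier == 1.0: starts at zero and reaches base_lr at total_epoch.
    for epoch in range(0, 4):
        print(epoch, warmup_lr(1e-4, epoch, total_epoch=3))
    # multiplier > 1.0: starts at base_lr and reaches base_lr * multiplier.
    for epoch in range(0, 6):
        print(epoch, warmup_lr(0.1, epoch, total_epoch=5, multiplier=8.0))
```

Note that with `multiplier == 1.0` the rate at epoch 0 is exactly zero, so any optimizer step taken before the first `scheduler.step()` leaves the parameters unchanged.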
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch.nn as nn
4 | import torch
5 | from torch.utils.data import DataLoader
6 | import utils
7 | from data_RGB import get_test_data
8 | from model import MSDT as mynet
9 | from skimage import img_as_ubyte
10 | from get_parameter_number import get_parameter_number
11 | from tqdm import tqdm
12 | from layers import *
13 |
14 | parser = argparse.ArgumentParser(description='Image Deraining')
15 | parser.add_argument('--input_dir', default='', type=str, help='Directory of validation images')
16 | parser.add_argument('--output_dir', default='', type=str, help='Directory for saving restored images')
17 | parser.add_argument('--weights', default='', type=str, help='Path to weights')
18 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')
19 | parser.add_argument('--win_size', default=256, type=int, help='window size, [GoPro, HIDE, RealBlur]=256, [DPDD]=512')
20 | args = parser.parse_args()
21 | result_dir = args.output_dir
22 | win = args.win_size
23 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
24 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
25 | model_restoration = mynet()
26 | get_parameter_number(model_restoration)
27 | utils.load_checkpoint(model_restoration, args.weights)
28 | print("===>Testing using weights: ",args.weights)
29 | model_restoration.cuda()
30 | model_restoration = nn.DataParallel(model_restoration)
31 | model_restoration.eval()
32 |
33 | # dataset = args.dataset
34 | rgb_dir_test = args.input_dir
35 | test_dataset = get_test_data(rgb_dir_test, img_options={})
36 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)
37 |
38 | utils.mkdir(result_dir)
39 |
40 | with torch.no_grad():
41 | psnr_list = []
42 | ssim_list = []
43 | for ii, data_test in enumerate(tqdm(test_loader), 0):
44 |
45 | torch.cuda.ipc_collect()
46 | torch.cuda.empty_cache()
47 | input_ = data_test[0].cuda()
48 | filenames = data_test[1]
49 | _, _, Hx, Wx = input_.shape
50 | # Split the input into win x win windows so that images of any size can be processed.
51 | input_re, batch_list = window_partitionx(input_, win)
52 | restored = model_restoration(input_re)
53 | restored = window_reversex(restored[0], win, Hx, Wx, batch_list)
54 |
55 | restored = torch.clamp(restored, 0, 1)
56 | restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
57 |
58 | for batch in range(len(restored)):
59 | # Convert the float image in [0, 1] to uint8 before saving.
60 | restored_img = img_as_ubyte(restored[batch])
61 | utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img)
62 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
4 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
5 |
6 | import torch
7 |
8 | torch.backends.cudnn.benchmark = True
9 |
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | from torch.utils.data import DataLoader
13 |
14 | import random
15 | import time
16 | import numpy as np
17 |
18 | import utils
19 | from data_RGB import get_training_data, get_validation_data
20 | from model import MSDT as myNet
21 | import losses
22 | from warmup_scheduler import GradualWarmupScheduler
23 | from tqdm import tqdm
24 | from get_parameter_number import get_parameter_number
25 | import kornia
26 | from torch.utils.tensorboard import SummaryWriter
27 | import argparse
28 |
29 | from skimage import img_as_ubyte
30 |
31 | ######### Set Seeds ###########
32 | random.seed(1234)
33 | np.random.seed(1234)
34 | torch.manual_seed(1234)
35 | torch.cuda.manual_seed_all(1234)
36 |
37 | start_epoch = 1
38 |
39 | parser = argparse.ArgumentParser(description='Image Deraining')
40 |
41 | parser.add_argument('--train_dir', default='/home/user/data/21/chm/dataset/Rain200H/Rain200H/train', type=str, help='Directory of train images')
42 | parser.add_argument('--val_dir', default='/home/user/data/21/chm/dataset/Rain200H/Rain200H/test', type=str, help='Directory of validation images')
43 | parser.add_argument('--model_save_dir', default='./checkpoints', type=str, help='Path to save weights')
44 | parser.add_argument('--pretrain_weights', default='', type=str, help='Path to pretrain-weights')
45 | parser.add_argument('--mode', default='Deraininig', type=str)
46 | parser.add_argument('--session', default='Multiscale', type=str, help='session')
47 | parser.add_argument('--patch_size', default=256, type=int, help='patch size')
48 | parser.add_argument('--num_epochs', default=500, type=int, help='num_epochs')
49 | parser.add_argument('--batch_size', default=1, type=int, help='batch_size')
50 | parser.add_argument('--val_epochs', default=1, type=int, help='val_epochs')
51 | args = parser.parse_args()
52 |
53 | mode = args.mode
54 | session = args.session
55 | patch_size = args.patch_size
56 |
57 | model_dir = os.path.join(args.model_save_dir, mode, 'models', session)
58 | utils.mkdir(model_dir)
59 |
60 | train_dir = args.train_dir
61 | val_dir = args.val_dir
62 |
63 | num_epochs = args.num_epochs
64 | batch_size = args.batch_size
65 | val_epochs = args.val_epochs
66 |
67 | start_lr = 1e-4
68 | end_lr = 1e-6
69 |
70 | ######### Model ###########
71 | model_restoration = myNet()
72 |
73 | # print the number of model parameters
74 | get_parameter_number(model_restoration)
75 |
76 | model_restoration.cuda()
77 |
78 | device_ids = [i for i in range(torch.cuda.device_count())]
79 | if torch.cuda.device_count() > 1:
80 | print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
81 |
82 | optimizer = optim.Adam(model_restoration.parameters(), lr=start_lr, betas=(0.9, 0.999), eps=1e-8)
83 |
84 | ######### Scheduler ###########
85 | warmup_epochs = 3
86 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs - warmup_epochs, eta_min=end_lr)
87 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
88 |
89 | RESUME = False
90 | Pretrain = False
91 | model_pre_dir = ''
92 | ######### Pretrain ###########
93 | if Pretrain:
94 | utils.load_checkpoint(model_restoration, model_pre_dir)
95 |
96 | print('------------------------------------------------------------------------------')
97 | print("==> Training from pretrained weights: " + model_pre_dir)
98 | print('------------------------------------------------------------------------------')
99 |
100 | ######### Resume ###########
101 | if RESUME:
102 | path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
103 | utils.load_checkpoint(model_restoration, path_chk_rest)
104 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1
105 | utils.load_optim(optimizer, path_chk_rest)
106 |
107 | for i in range(1, start_epoch):
108 | scheduler.step()
109 | new_lr = scheduler.get_lr()[0]
110 | print('------------------------------------------------------------------------------')
111 | print("==> Resuming Training with learning rate:", new_lr)
112 | print('------------------------------------------------------------------------------')
113 |
114 | if len(device_ids) > 1:
115 | model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)
116 |
117 | ######### Loss ###########
118 | criterion_char = losses.CharbonnierLoss()
119 | criterion_edge = losses.EdgeLoss()
120 | criterion_fft = losses.fftLoss()
121 | ######### DataLoaders ###########
122 | train_dataset = get_training_data(train_dir, {'patch_size': patch_size})
123 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False,
124 | pin_memory=True)
125 |
126 | val_dataset = get_validation_data(val_dir, {'patch_size': patch_size})
127 | val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False,
128 | pin_memory=True)
129 |
130 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch, num_epochs))
131 | print('===> Loading datasets')
132 |
133 | best_psnr = 0
134 | best_epoch = 0
135 | writer = SummaryWriter(model_dir)
136 | iter = 0
137 |
138 | for epoch in range(start_epoch, num_epochs + 1):
139 | epoch_start_time = time.time()
140 | epoch_loss = 0
141 | train_id = 1
142 |
143 | model_restoration.train()
144 | for i, data in enumerate(tqdm(train_loader), 0):
145 |
146 | # zero_grad
147 | for param in model_restoration.parameters():
148 | param.grad = None
149 |
150 | target_ = data[0].cuda()
151 | input_ = data[1].cuda()
152 | target = kornia.geometry.transform.build_pyramid(target_, 3)
153 | restored = model_restoration(input_)
154 |
155 | loss_fft = criterion_fft(restored[0], target[0]) + criterion_fft(restored[1], target[1]) + criterion_fft(restored[2], target[2])
156 | loss_char = criterion_char(restored[0], target[0]) + criterion_char(restored[1], target[1]) + criterion_char(restored[2], target[2])
157 | loss_edge = criterion_edge(restored[0], target[0]) + criterion_edge(restored[1], target[1]) + criterion_edge(restored[2], target[2])
158 | loss = loss_char + 0.01 * loss_fft + 0.05 * loss_edge
159 | loss.backward()
160 | optimizer.step()
161 | epoch_loss += loss.item()
162 | iter += 1
163 | writer.add_scalar('loss/fft_loss', loss_fft, iter)
164 | writer.add_scalar('loss/char_loss', loss_char, iter)
165 | writer.add_scalar('loss/edge_loss', loss_edge, iter)
166 | writer.add_scalar('loss/iter_loss', loss, iter)
167 | writer.add_scalar('loss/epoch_loss', epoch_loss, epoch)
168 | #### Evaluation ####
169 | if epoch % val_epochs == 0:
170 | model_restoration.eval()
171 | psnr_val_rgb = []
172 | for ii, data_val in enumerate((val_loader), 0):
173 | target = data_val[0].cuda()
174 | input_ = data_val[1].cuda()
175 |
176 | with torch.no_grad():
177 | restored = model_restoration(input_)
178 |
179 | for res, tar in zip(restored[0], target):
180 | psnr_val_rgb.append(utils.torchPSNR(res, tar))
181 |
182 | psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
183 | writer.add_scalar('val/psnr', psnr_val_rgb, epoch)
184 | if psnr_val_rgb > best_psnr:
185 | best_psnr = psnr_val_rgb
186 | best_epoch = epoch
187 | torch.save({'epoch': epoch,
188 | 'state_dict': model_restoration.state_dict(),
189 | 'optimizer': optimizer.state_dict()
190 | }, os.path.join(model_dir, "model_best.pth"))
191 |
192 | print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
193 |
194 | torch.save({'epoch': epoch,
195 | 'state_dict': model_restoration.state_dict(),
196 | 'optimizer': optimizer.state_dict()
197 | }, os.path.join(model_dir, f"model_epoch_{epoch}.pth"))
198 |
199 | scheduler.step()
200 |
201 | print("------------------------------------------------------------------")
202 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time() - epoch_start_time,
203 | epoch_loss, scheduler.get_lr()[0]))
204 | print("------------------------------------------------------------------")
205 |
206 | torch.save({'epoch': epoch,
207 | 'state_dict': model_restoration.state_dict(),
208 | 'optimizer': optimizer.state_dict()
209 | }, os.path.join(model_dir, "model_latest.pth"))
210 |
211 | writer.close()
212 |
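The schedule configured above (linear warm-up for `warmup_epochs`, then cosine annealing from `start_lr` down to `end_lr`) can be approximated in closed form. The sketch below is an idealized model under stated assumptions: `idealized_lr` is a hypothetical helper, it uses the textbook cosine formula, and the values actually produced by the chained PyTorch schedulers may differ slightly around the hand-off.

```python
import math


def idealized_lr(epoch, base_lr=1e-4, eta_min=1e-6, warmup_epochs=3, num_epochs=500):
    """Closed-form approximation of warm-up followed by cosine annealing.

    Defaults copy start_lr, end_lr, warmup_epochs and num_epochs from train.py.
    """
    if epoch <= warmup_epochs:
        # Linear ramp from 0 to base_lr (the multiplier == 1 case).
        return base_lr * epoch / warmup_epochs
    # Cosine decay over the remaining epochs, ending at eta_min.
    t = epoch - warmup_epochs
    t_max = num_epochs - warmup_epochs
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max)) / 2


if __name__ == "__main__":
    for epoch in (0, 1, 3, 100, 250, 500):
        print(epoch, idealized_lr(epoch))
```

Printing a few epochs this way is a cheap sanity check before launching a long run, for example to confirm that the rate peaks at `start_lr` right after warm-up and bottoms out at `end_lr` on the final epoch.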
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -p gpu20
3 | #SBATCH -t 7-00:00:00
4 |
5 | python train.py | tee logs_Rain200H.txt
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dir_utils import *
2 | from .image_utils import *
3 | from .model_utils import *
4 | from .dataset_utils import *
5 |
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataset_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/dataset_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataset_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/dataset_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dir_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/dir_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dir_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/dir_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/image_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/image_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/image_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/image_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/model_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/model_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/model_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cschenhm/MSDT/a3bcd76b98b8427e9d35fa87fc5c37ce9daf58f9/utils/__pycache__/model_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class MixUp_AUG:
4 | def __init__(self):
5 | self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
6 |
7 | def aug(self, rgb_gt, rgb_noisy):
8 | bs = rgb_gt.size(0)
9 | indices = torch.randperm(bs)
10 | rgb_gt2 = rgb_gt[indices]
11 | rgb_noisy2 = rgb_noisy[indices]
12 |
13 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
14 |
15 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
16 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
17 |
18 | return rgb_gt, rgb_noisy
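The essential point of `MixUp_AUG.aug` is that the same mixing coefficient is applied to the ground-truth and to the degraded image, so each blended pair stays aligned. A minimal sketch on plain Python lists, assuming hypothetical helper names (`sample_lambda`, `mixup_pair`) and the standard-library Beta sampler in place of `torch.distributions`:

```python
import random


def sample_lambda(alpha=0.6, rng=random):
    """Draw a mixing coefficient from Beta(alpha, alpha), as in MixUp_AUG."""
    return rng.betavariate(alpha, alpha)


def mixup_pair(a, b, lam):
    """Blend two equally sized sequences element-wise: lam * a + (1 - lam) * b."""
    return [lam * x + (1 - lam) * y for x, y in zip(a, b)]


if __name__ == "__main__":
    lam = sample_lambda()
    gt_a, gt_b = [1.0, 0.0], [0.0, 1.0]
    noisy_a, noisy_b = [0.9, 0.1], [0.2, 0.8]
    # The same lam is reused for both blends, keeping target and input consistent.
    print(lam, mixup_pair(gt_a, gt_b, lam), mixup_pair(noisy_a, noisy_b, lam))
```

With `alpha = 0.6` the Beta distribution is U-shaped, so most samples fall near 0 or 1 and the blended image is usually dominated by one of the two sources.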
--------------------------------------------------------------------------------
/utils/dir_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from natsort import natsorted
3 | from glob import glob
4 |
5 | def mkdirs(paths):
6 | if isinstance(paths, list) and not isinstance(paths, str):
7 | for path in paths:
8 | mkdir(path)
9 | else:
10 | mkdir(paths)
11 |
12 | def mkdir(path):
13 | if not os.path.exists(path):
14 | os.makedirs(path)
15 |
16 | def get_last_path(path, session):
17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
18 | return x
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 |
5 | def torchPSNR(tar_img, prd_img):
6 |     # PSNR in dB for tensors on a [0, 1] scale (peak value 1)
7 |     imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
8 |     rmse = (imdff**2).mean().sqrt()
9 |     ps = 20*torch.log10(1/rmse)
10 |     return ps
11 |
12 | def save_img(filepath, img):
13 |     # img is expected in RGB channel order; OpenCV writes BGR
14 |     cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
15 |
16 | def numpyPSNR(tar_img, prd_img):
17 |     # PSNR in dB for arrays on a [0, 255] scale (peak value 255)
18 |     imdff = np.float32(prd_img) - np.float32(tar_img)
19 |     rmse = np.sqrt(np.mean(imdff**2))
20 |     ps = 20*np.log10(255/rmse)
21 |     return ps
22 |
--------------------------------------------------------------------------------
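`torchPSNR` and `numpyPSNR` differ only in the assumed peak value (1 versus 255). The sketch below, written with the standard library only, illustrates that both conventions give the same score when the data are scaled consistently. The helper `psnr` and the pixel values are hypothetical; unlike the functions above, it returns infinity explicitly for identical inputs.

```python
import math

def psnr(target, prediction, peak):
    """PSNR in dB for two equal-length flat sequences of pixel values."""
    if len(target) != len(prediction):
        raise ValueError("inputs must have the same number of pixels")
    mse = sum((p - t) ** 2 for t, p in zip(target, prediction)) / len(target)
    if mse == 0:
        return math.inf  # identical images: no error to measure
    return 20 * math.log10(peak / math.sqrt(mse))

# Made-up 8-bit pixel values.
target_8bit = [0, 64, 128, 255]
pred_8bit = [2, 60, 130, 250]

# Same data on the [0, 255] scale and rescaled to [0, 1].
psnr_255 = psnr(target_8bit, pred_8bit, peak=255)
psnr_unit = psnr([v / 255 for v in target_8bit],
                 [v / 255 for v in pred_8bit], peak=1.0)
```

Here the squared errors are 4, 16, 4 and 25, so the RMSE is 3.5 and both calls evaluate to `20 * log10(255 / 3.5)` up to floating-point rounding.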
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from collections import OrderedDict
4 | import numpy as np
5 |
6 | def freeze(model):
7 |     for p in model.parameters():
8 |         p.requires_grad=False
9 |
10 | def unfreeze(model):
11 |     for p in model.parameters():
12 |         p.requires_grad=True
13 |
14 | def is_frozen(model):
15 |     # True as soon as at least one parameter has requires_grad=False
16 |     x = [p.requires_grad for p in model.parameters()]
17 |     return not all(x)
18 |
19 | def save_checkpoint(model_dir, state, session):
20 |     epoch = state['epoch']
21 |     model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
22 |     torch.save(state, model_out_path)
23 |
24 | def load_checkpoint(model, weights):
25 |     checkpoint = torch.load(weights)
26 |     try:
27 |         model.load_state_dict(checkpoint["state_dict"])
28 |     except RuntimeError:
29 |         # Key mismatch: assume the checkpoint was saved from a wrapped model,
30 |         # so every key starts with `module.`
31 |         state_dict = checkpoint["state_dict"]
32 |         new_state_dict = OrderedDict()
33 |         for k, v in state_dict.items():
34 |             name = k[7:] # remove `module.`
35 |             new_state_dict[name] = v
36 |
37 |         model.load_state_dict(new_state_dict)
38 |
39 | def load_checkpoint_compress_doconv(model, weights):
40 |     checkpoint = torch.load(weights)
41 |     old_state_dict = checkpoint["state_dict"]
42 |     state_dict = OrderedDict()
43 |     for k, v in old_state_dict.items():
44 |         name = k
45 |         if k[:7] == 'module.':
46 |             name = k[7:] # remove `module.`
47 |         state_dict[name] = v
48 |     # Fold every (D, D_diag, W) triple into a single 4-D convolution kernel
49 |     do_state_dict = OrderedDict()
50 |     for k, v in state_dict.items():
51 |         if k[-1] == 'W' and k[:-1] + 'D' in state_dict:
52 |             k_D = k[:-1] + 'D'
53 |             k_D_diag = k_D + '_diag'
54 |             W = v
55 |             D = state_dict[k_D]
56 |             D_diag = state_dict[k_D_diag]
57 |             D = D + D_diag
58 |             # W = torch.reshape(W, (out_channels, in_channels, D_mul))
59 |             out_channels, in_channels, MN = W.shape
60 |             M = int(np.sqrt(MN))
61 |             DoW_shape = (out_channels, in_channels, M, M)
62 |             DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
63 |             do_state_dict[k] = DoW
64 |         elif k[-1] == 'D' or k[-6:] == 'D_diag':
65 |             # Already folded into the matching W above
66 |             continue
67 |         elif k[-1] == 'W':
68 |             # W without a D partner: only reshape to (out, in, M, M)
69 |             out_channels, in_channels, MN = v.shape
70 |             M = int(np.sqrt(MN))
71 |             W_shape = (out_channels, in_channels, M, M)
72 |             do_state_dict[k] = torch.reshape(v, W_shape)
73 |         else:
74 |             do_state_dict[k] = v
75 |     model.load_state_dict(do_state_dict)
76 |
77 | def load_checkpoint_hin(model, weights):
78 |     # Here the checkpoint file is the state dict itself (no "state_dict" key)
79 |     checkpoint = torch.load(weights)
80 |     try:
81 |         model.load_state_dict(checkpoint)
82 |     except RuntimeError:
83 |         state_dict = checkpoint
84 |         new_state_dict = OrderedDict()
85 |         for k, v in state_dict.items():
86 |             name = k[7:] # remove `module.`
87 |             new_state_dict[name] = v
88 |         model.load_state_dict(new_state_dict)
89 |
90 | def load_checkpoint_multigpu(model, weights):
91 |     # Assumes every key in the checkpoint starts with `module.`
92 |     checkpoint = torch.load(weights)
93 |     state_dict = checkpoint["state_dict"]
94 |     new_state_dict = OrderedDict()
95 |     for k, v in state_dict.items():
96 |         name = k[7:] # remove `module.`
97 |         new_state_dict[name] = v
98 |     model.load_state_dict(new_state_dict)
99 |
100 | def load_start_epoch(weights):
101 |     checkpoint = torch.load(weights)
102 |     epoch = checkpoint["epoch"]
103 |     return epoch
104 |
105 | def load_optim(optimizer, weights):
106 |     checkpoint = torch.load(weights)
107 |     optimizer.load_state_dict(checkpoint['optimizer'])
108 |
--------------------------------------------------------------------------------
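Several loaders in `model_utils.py` strip the `module.` prefix with an unconditional `k[7:]` slice, while `load_checkpoint_compress_doconv` checks for the prefix first. The sketch below uses plain dictionaries with made-up keys and values to show how the two approaches differ when a checkpoint has no prefix. `strip_always` and `strip_if_present` are hypothetical helper names, not functions from the repository.

```python
from collections import OrderedDict

PREFIX = "module."  # 7 characters, hence the k[7:] slice in the loaders

def strip_always(state_dict):
    """Mirror of the unconditional slice: drops the first 7 characters of every key."""
    return OrderedDict((k[7:], v) for k, v in state_dict.items())

def strip_if_present(state_dict):
    """Remove the prefix only from keys that actually carry it."""
    return OrderedDict(
        (k[len(PREFIX):] if k.startswith(PREFIX) else k, v)
        for k, v in state_dict.items()
    )

# Made-up state dicts; integers stand in for weight tensors.
wrapped = OrderedDict([("module.conv1.weight", 1), ("module.conv1.bias", 2)])
plain = OrderedDict([("conv1.weight", 1), ("conv1.bias", 2)])

wrapped_always = list(strip_always(wrapped))    # prefix removed as intended
plain_always = list(strip_always(plain))        # keys are truncated
plain_checked = list(strip_if_present(plain))   # keys left untouched
```

Both helpers agree on prefixed keys, but the unconditional slice mangles unprefixed ones, so it is only safe when the checkpoint is known to carry the prefix on every key.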