22 | parfor j = 1:img_num
23 | image_name = path_list(j).name;
24 | gt_name = gt_list(j).name;
25 | input = imread(strcat(file_path,image_name));
26 | gt = imread(strcat(gt_path, gt_name));
27 | ssim_val = compute_ssim(input, gt);
28 | psnr_val = compute_psnr(input, gt);
29 | total_ssim = total_ssim + ssim_val;
30 | total_psnr = total_psnr + psnr_val;
31 | end
32 | end
33 | qm_psnr = total_psnr / img_num;
34 | qm_ssim = total_ssim / img_num;
35 |
36 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);
37 |
38 | psnr_alldatasets = psnr_alldatasets + qm_psnr;
39 | ssim_alldatasets = ssim_alldatasets + qm_ssim;
40 |
41 | end
42 |
43 | fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_set, ssim_alldatasets/num_set);
44 |
45 | delete(gcp('nocreate'))
46 | toc
47 |
48 | function ssim_mean=compute_ssim(img1,img2)
49 | if size(img1, 3) == 3
50 | img1 = rgb2ycbcr(img1);
51 | img1 = img1(:, :, 1);
52 | end
53 |
54 | if size(img2, 3) == 3
55 | img2 = rgb2ycbcr(img2);
56 | img2 = img2(:, :, 1);
57 | end
58 | ssim_mean = SSIM_index(img1, img2);
59 | end
60 |
61 | function psnr=compute_psnr(img1,img2)
62 | if size(img1, 3) == 3
63 | img1 = rgb2ycbcr(img1);
64 | img1 = img1(:, :, 1);
65 | end
66 |
67 | if size(img2, 3) == 3
68 | img2 = rgb2ycbcr(img2);
69 | img2 = img2(:, :, 1);
70 | end
71 |
72 | imdff = double(img1) - double(img2);
73 | imdff = imdff(:);
74 | rmse = sqrt(mean(imdff.^2));
75 | psnr = 20*log10(255/rmse);
76 |
77 | end
78 |
79 | function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L)
80 |
81 | %========================================================================
82 | %SSIM Index, Version 1.0
83 | %Copyright(c) 2003 Zhou Wang
84 | %All Rights Reserved.
85 | %
86 | %The author is with Howard Hughes Medical Institute, and Laboratory
87 | %for Computational Vision at Center for Neural Science and Courant
88 | %Institute of Mathematical Sciences, New York University.
89 | %
90 | %----------------------------------------------------------------------
91 | %Permission to use, copy, or modify this software and its documentation
92 | %for educational and research purposes only and without fee is hereby
93 | %granted, provided that this copyright notice and the original authors'
94 | %names appear on all copies and supporting documentation. This program
95 | %shall not be used, rewritten, or adapted as the basis of a commercial
96 | %software or hardware product without first obtaining permission of the
97 | %authors. The authors make no representations about the suitability of
98 | %this software for any purpose. It is provided "as is" without express
99 | %or implied warranty.
100 | %----------------------------------------------------------------------
101 | %
102 | %This is an implementation of the algorithm for calculating the
103 | %Structural SIMilarity (SSIM) index between two images. Please refer
104 | %to the following paper:
105 | %
106 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
107 | %quality assessment: From error visibility to structural similarity,"
108 | %IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, Apr. 2004.
109 | %
110 | %Kindly report any suggestions or corrections to zhouwang@ieee.org
111 | %
112 | %----------------------------------------------------------------------
113 | %
114 | %Input : (1) img1: the first image being compared
115 | % (2) img2: the second image being compared
116 | % (3) K: constants in the SSIM index formula (see the above
117 | %            reference). default value: K = [0.01 0.03]
118 | % (4) window: local window for statistics (see the above
119 | %            reference). default window is Gaussian given by
120 | % window = fspecial('gaussian', 11, 1.5);
121 | % (5) L: dynamic range of the images. default: L = 255
122 | %
123 | %Output: (1) mssim: the mean SSIM index value between 2 images.
124 | % If one of the images being compared is regarded as
125 | % perfect quality, then mssim can be considered as the
126 | % quality measure of the other image.
127 | % If img1 = img2, then mssim = 1.
128 | % (2) ssim_map: the SSIM index map of the test image. The map
129 | % has a smaller size than the input images. The actual size:
130 | % size(img1) - size(window) + 1.
131 | %
132 | %Default Usage:
133 | % Given 2 test images img1 and img2, whose dynamic range is 0-255
134 | %
135 | % [mssim ssim_map] = ssim_index(img1, img2);
136 | %
137 | %Advanced Usage:
138 | % User defined parameters. For example
139 | %
140 | % K = [0.05 0.05];
141 | % window = ones(8);
142 | % L = 100;
143 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
144 | %
145 | %See the results:
146 | %
147 | % mssim %Gives the mssim value
148 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map
149 | %
150 | %========================================================================
151 |
152 |
153 | if (nargin < 2 || nargin > 5)
154 | mssim = -Inf;
155 | ssim_map = -Inf;
156 | return;
157 | end
158 |
159 | if (~isequal(size(img1), size(img2)))
160 | mssim = -Inf;
161 | ssim_map = -Inf;
162 | return;
163 | end
164 |
165 | [M N] = size(img1);
166 |
167 | if (nargin == 2)
168 | if ((M < 11) || (N < 11))
169 | mssim = -Inf;
170 | ssim_map = -Inf;
171 | return
172 | end
173 | window = fspecial('gaussian', 11, 1.5); %
174 | K(1) = 0.01; % default settings
175 | K(2) = 0.03; %
176 | L = 255; %
177 | end
178 |
179 | if (nargin == 3)
180 | if ((M < 11) || (N < 11))
181 | mssim = -Inf;
182 | ssim_map = -Inf;
183 | return
184 | end
185 | window = fspecial('gaussian', 11, 1.5);
186 | L = 255;
187 | if (length(K) == 2)
188 | if (K(1) < 0 || K(2) < 0)
189 | mssim = -Inf;
190 | ssim_map = -Inf;
191 | return;
192 | end
193 | else
194 | mssim = -Inf;
195 | ssim_map = -Inf;
196 | return;
197 | end
198 | end
199 |
200 | if (nargin == 4)
201 | [H W] = size(window);
202 | if ((H*W) < 4 || (H > M) || (W > N))
203 | mssim = -Inf;
204 | ssim_map = -Inf;
205 | return
206 | end
207 | L = 255;
208 | if (length(K) == 2)
209 | if (K(1) < 0 || K(2) < 0)
210 | mssim = -Inf;
211 | ssim_map = -Inf;
212 | return;
213 | end
214 | else
215 | mssim = -Inf;
216 | ssim_map = -Inf;
217 | return;
218 | end
219 | end
220 |
221 | if (nargin == 5)
222 | [H W] = size(window);
223 | if ((H*W) < 4 || (H > M) || (W > N))
224 | mssim = -Inf;
225 | ssim_map = -Inf;
226 | return
227 | end
228 | if (length(K) == 2)
229 | if (K(1) < 0 || K(2) < 0)
230 | mssim = -Inf;
231 | ssim_map = -Inf;
232 | return;
233 | end
234 | else
235 | mssim = -Inf;
236 | ssim_map = -Inf;
237 | return;
238 | end
239 | end
240 |
241 | C1 = (K(1)*L)^2;
242 | C2 = (K(2)*L)^2;
243 | window = window/sum(sum(window));
244 | img1 = double(img1);
245 | img2 = double(img2);
246 |
247 | mu1 = filter2(window, img1, 'valid');
248 | mu2 = filter2(window, img2, 'valid');
249 | mu1_sq = mu1.*mu1;
250 | mu2_sq = mu2.*mu2;
251 | mu1_mu2 = mu1.*mu2;
252 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
253 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
254 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
255 |
256 | if (C1 > 0 && C2 > 0)
257 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
258 | else
259 | numerator1 = 2*mu1_mu2 + C1;
260 | numerator2 = 2*sigma12 + C2;
261 | denominator1 = mu1_sq + mu2_sq + C1;
262 | denominator2 = sigma1_sq + sigma2_sq + C2;
263 | ssim_map = ones(size(mu1));
264 | index = (denominator1.*denominator2 > 0);
265 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
266 | index = (denominator1 ~= 0) & (denominator2 == 0);
267 | ssim_map(index) = numerator1(index)./denominator1(index);
268 | end
269 |
270 | mssim = mean2(ssim_map);
271 |
272 | end
273 |
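For a quick cross-check outside MATLAB, the same Y-channel protocol can be approximated in Python. This is an illustrative sketch, not part of the repository; the file paths are placeholders, and PIL's full-range YCbCr conversion differs slightly from MATLAB's rgb2ycbcr, so small deviations are expected.

~~~
# Approximate Python equivalent of the evaluation above: convert to YCbCr,
# keep the Y (luma) channel, then compute PSNR/SSIM at 8-bit range.
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def y_channel(path):
    return np.asarray(Image.open(path).convert('YCbCr'))[:, :, 0].astype(np.float64)

pred, gt = y_channel('pred.png'), y_channel('gt.png')  # placeholder paths
print(peak_signal_noise_ratio(gt, pred, data_range=255))
print(structural_similarity(gt, pred, data_range=255, gaussian_weights=True,
                            sigma=1.5, use_sample_covariance=False))
~~~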
--------------------------------------------------------------------------------
/Image_deraining/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from torch.backends import cudnn
5 | from models.ConvIR import build_net
6 | from train import _train
7 | from eval import _eval
8 |
9 | def main(args):
10 | cudnn.benchmark = True
11 |
12 | if not os.path.exists('results/'):
13 | os.makedirs('results/')
14 | if not os.path.exists('results/' + args.model_name + '/'):
15 | os.makedirs('results/' + args.model_name + '/')
16 | if not os.path.exists(args.result_dir):
17 | os.makedirs(args.result_dir)
18 |
19 | model = build_net()
20 | print(model)
21 |
22 | if torch.cuda.is_available():
23 | model.cuda()
24 | if args.mode == 'train':
25 | _train(model, args)
26 |
27 | elif args.mode == 'test':
28 | _eval(model, args)
29 |
30 |
31 | if __name__ == '__main__':
32 | parser = argparse.ArgumentParser()
33 |
34 | # Directories
35 | parser.add_argument('--model_name', default='ConvIR', type=str)
36 | parser.add_argument('--data_dir', type=str, default='../../dataset/deraining/train/Rain13K/')
37 | parser.add_argument('--valid_data', type=str, default='/Rain100K/')
38 | parser.add_argument('--mode', default='train', choices=['train', 'test'], type=str)
39 | # Train
40 | parser.add_argument('--batch_size', type=int, default=4)
41 | parser.add_argument('--learning_rate', type=float, default=1e-4)
42 | parser.add_argument('--weight_decay', type=float, default=0)
43 | parser.add_argument('--num_epoch', type=int, default=300)
44 | parser.add_argument('--print_freq', type=int, default=100)
45 | parser.add_argument('--num_worker', type=int, default=8)
46 | parser.add_argument('--save_freq', type=int, default=10)
47 |
48 | parser.add_argument('--valid_freq', type=int, default=10)
49 | parser.add_argument('--resume', type=str, default='')
50 | parser.add_argument('--gamma', type=float, default=0.5)
51 | # Test
52 | parser.add_argument('--test_model', type=str, default='')
53 | parser.add_argument('--save_image', type=bool, default=False, choices=[True, False])
54 |
55 | args = parser.parse_args()
56 | args.model_save_dir = os.path.join('results/', args.model_name, 'train_results/')
57 | args.result_dir = os.path.join('results/', args.model_name, 'test')
58 | if not os.path.exists(args.model_save_dir):
59 | os.makedirs(args.model_save_dir)
60 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir
61 | os.system(command)
62 | command = 'cp ' + 'models/ConvIR.py ' + args.model_save_dir
63 | os.system(command)
64 | command = 'cp ' + 'train.py ' + args.model_save_dir
65 | os.system(command)
66 | command = 'cp ' + 'main.py ' + args.model_save_dir
67 | os.system(command)
68 | print(args)
69 | main(args)
70 |
--------------------------------------------------------------------------------
/Image_deraining/models/ConvIR.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .layers import *
6 |
7 |
8 | class EBlock(nn.Module):
9 | def __init__(self, out_channel, num_res=8):
10 | super(EBlock, self).__init__()
11 |
12 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res-1)]
13 | layers.append(ResBlock(out_channel, out_channel, filter=True))
14 |
15 | self.layers = nn.Sequential(*layers)
16 |
17 | def forward(self, x):
18 | return self.layers(x)
19 |
20 |
21 | class DBlock(nn.Module):
22 | def __init__(self, channel, num_res=8):
23 | super(DBlock, self).__init__()
24 |
25 | layers = [ResBlock(channel, channel) for _ in range(num_res-1)]
26 | layers.append(ResBlock(channel, channel, filter=True))
27 | self.layers = nn.Sequential(*layers)
28 |
29 | def forward(self, x):
30 | return self.layers(x)
31 |
32 |
33 | class SCM(nn.Module):
34 | def __init__(self, out_plane):
35 | super(SCM, self).__init__()
36 | self.main = nn.Sequential(
37 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True),
38 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
39 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
40 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False),
41 | nn.InstanceNorm2d(out_plane, affine=True)
42 | )
43 |
44 | def forward(self, x):
45 | x = self.main(x)
46 | return x
47 |
48 | class FAM(nn.Module):
49 | def __init__(self, channel):
50 | super(FAM, self).__init__()
51 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False)
52 |
53 | def forward(self, x1, x2):
54 | return self.merge(torch.cat([x1, x2], dim=1))
55 |
56 | class ConvIR(nn.Module):
57 | def __init__(self, num_res=16):
58 | super(ConvIR, self).__init__()
59 |
60 | base_channel = 32
61 |
62 | self.Encoder = nn.ModuleList([
63 | EBlock(base_channel, num_res),
64 | EBlock(base_channel*2, num_res),
65 | EBlock(base_channel*4, num_res),
66 | ])
67 |
68 | self.feat_extract = nn.ModuleList([
69 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
70 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
71 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
72 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
73 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
74 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
75 | ])
76 |
77 | self.Decoder = nn.ModuleList([
78 | DBlock(base_channel * 4, num_res),
79 | DBlock(base_channel * 2, num_res),
80 | DBlock(base_channel, num_res)
81 | ])
82 |
83 | self.Convs = nn.ModuleList([
84 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
85 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
86 | ])
87 |
88 | self.ConvsOut = nn.ModuleList(
89 | [
90 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
91 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
92 | ]
93 | )
94 |
95 | self.FAM1 = FAM(base_channel * 4)
96 | self.SCM1 = SCM(base_channel * 4)
97 | self.FAM2 = FAM(base_channel * 2)
98 | self.SCM2 = SCM(base_channel * 2)
99 |
100 | def forward(self, x):
101 | x_2 = F.interpolate(x, scale_factor=0.5)
102 | x_4 = F.interpolate(x_2, scale_factor=0.5)
103 | z2 = self.SCM2(x_2)
104 | z4 = self.SCM1(x_4)
105 |
106 | outputs = list()
107 | # 256
108 | x_ = self.feat_extract[0](x)
109 | res1 = self.Encoder[0](x_)
110 | # 128
111 | z = self.feat_extract[1](res1)
112 | z = self.FAM2(z, z2)
113 | res2 = self.Encoder[1](z)
114 | # 64
115 | z = self.feat_extract[2](res2)
116 | z = self.FAM1(z, z4)
117 | z = self.Encoder[2](z)
118 |
119 | z = self.Decoder[0](z)
120 | z_ = self.ConvsOut[0](z)
121 | # 128
122 | z = self.feat_extract[3](z)
123 | outputs.append(z_+x_4)
124 |
125 | z = torch.cat([z, res2], dim=1)
126 | z = self.Convs[0](z)
127 | z = self.Decoder[1](z)
128 | z_ = self.ConvsOut[1](z)
129 | # 256
130 | z = self.feat_extract[4](z)
131 | outputs.append(z_+x_2)
132 |
133 | z = torch.cat([z, res1], dim=1)
134 | z = self.Convs[1](z)
135 | z = self.Decoder[2](z)
136 | z = self.feat_extract[5](z)
137 | outputs.append(z+x)
138 |
139 | return outputs
140 |
141 |
142 | def build_net():
143 | return ConvIR()
144 |
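A minimal smoke test for the coarse-to-fine design above: the network returns three predictions at 1/4, 1/2, and full resolution, each with the corresponding downsampled input added as a residual. A sketch assuming a 256x256 input (eval mode is needed because dynamic_filter's BatchNorm sees 1x1 feature maps):

~~~
import torch
from models.ConvIR import build_net

net = build_net().eval()
with torch.no_grad():
    outs = net(torch.randn(1, 3, 256, 256))
print([tuple(o.shape) for o in outs])
# expected: [(1, 3, 64, 64), (1, 3, 128, 128), (1, 3, 256, 256)]
~~~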
--------------------------------------------------------------------------------
/Image_deraining/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class BasicConv(nn.Module):
6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
7 | super(BasicConv, self).__init__()
8 | if bias and norm:
9 | bias = False
10 |
11 | padding = kernel_size // 2
12 | layers = list()
13 | if transpose:
14 | padding = kernel_size // 2 - 1
15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
16 | else:
17 | layers.append(
18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
19 | if norm:
20 | layers.append(nn.BatchNorm2d(out_channel))
21 | if relu:
22 | layers.append(nn.GELU())
23 | self.main = nn.Sequential(*layers)
24 |
25 | def forward(self, x):
26 | return self.main(x)
27 |
28 | class ResBlock(nn.Module):
29 | def __init__(self, in_channel, out_channel, filter=False):
30 | super(ResBlock, self).__init__()
31 | self.main = nn.Sequential(
32 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
33 | DeepPoolLayer(in_channel, out_channel) if filter else nn.Identity(),
34 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
35 | )
36 |
37 | def forward(self, x):
38 | return self.main(x) + x
39 |
40 | class DeepPoolLayer(nn.Module):
41 | def __init__(self, k, k_out):
42 | super(DeepPoolLayer, self).__init__()
43 | self.pools_sizes = [8,4,2]
44 | dilation = [3,7,9]
45 | pools, convs, dynas = [],[],[]
46 | for j, i in enumerate(self.pools_sizes):
47 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i))
48 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False))
49 | dynas.append(MultiShapeKernel(dim=k, kernel_size=3, dilation=dilation[j]))
50 | self.pools = nn.ModuleList(pools)
51 | self.convs = nn.ModuleList(convs)
52 | self.dynas = nn.ModuleList(dynas)
53 | self.relu = nn.GELU()
54 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False)
55 |
56 | def forward(self, x):
57 | x_size = x.size()
58 | resl = x
59 | for i in range(len(self.pools_sizes)):
60 | if i == 0:
61 | y = self.dynas[i](self.convs[i](self.pools[i](x)))
62 | else:
63 | y = self.dynas[i](self.convs[i](self.pools[i](x)+y_up))
64 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True))
65 | if i != len(self.pools_sizes)-1:
66 | y_up = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
67 | resl = self.relu(resl)
68 | resl = self.conv_sum(resl)
69 |
70 | return resl
71 |
72 | class dynamic_filter(nn.Module):
73 | def __init__(self, inchannels, kernel_size=3, dilation=1, stride=1, group=8):
74 | super(dynamic_filter, self).__init__()
75 | self.stride = stride
76 | self.kernel_size = kernel_size
77 | self.group = group
78 | self.dilation = dilation
79 |
80 | self.conv = nn.Conv2d(inchannels, group*kernel_size**2, kernel_size=1, stride=1, bias=False)
81 | self.bn = nn.BatchNorm2d(group*kernel_size**2)
82 | self.act = nn.Tanh()
83 |
84 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
85 | self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True)
86 | self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True)
87 | self.pad = nn.ReflectionPad2d(self.dilation*(kernel_size-1)//2)
88 |
89 | self.ap = nn.AdaptiveAvgPool2d((1, 1))
90 | self.gap = nn.AdaptiveAvgPool2d(1)
91 |
92 | self.inside_all = nn.Parameter(torch.zeros(inchannels,1,1), requires_grad=True)
93 |
94 | def forward(self, x):
95 | identity_input = x
96 | low_filter = self.ap(x)
97 | low_filter = self.conv(low_filter)
98 | low_filter = self.bn(low_filter)
99 |
100 | n, c, h, w = x.shape
101 | x = F.unfold(self.pad(x), kernel_size=self.kernel_size, dilation=self.dilation).reshape(n, self.group, c//self.group, self.kernel_size**2, h*w)
102 |
103 | n,c1,p,q = low_filter.shape
104 | low_filter = low_filter.reshape(n, c1//self.kernel_size**2, self.kernel_size**2, p*q).unsqueeze(2)
105 |
106 | low_filter = self.act(low_filter)
107 |
108 | low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w)
109 |
110 | out_low = low_part * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input)
111 |
112 | out_low = out_low * self.lamb_l[None,:,None,None]
113 |
114 | out_high = (identity_input) * (self.lamb_h[None,:,None,None] + 1.)
115 |
116 | return out_low + out_high
117 |
118 |
119 | class cubic_attention(nn.Module):
120 | def __init__(self, dim, group, dilation, kernel) -> None:
121 | super().__init__()
122 |
123 | self.H_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel)
124 | self.W_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel, H=False)
125 | self.gamma = nn.Parameter(torch.zeros(dim,1,1))
126 | self.beta = nn.Parameter(torch.ones(dim,1,1))
127 |
128 | def forward(self, x):
129 | out = self.H_spatial_att(x)
130 | out = self.W_spatial_att(out)
131 | return self.gamma * out + x * self.beta
132 |
133 |
134 | class spatial_strip_att(nn.Module):
135 | def __init__(self, dim, kernel=3, dilation=1, group=2, H=True) -> None:
136 | super().__init__()
137 |
138 | self.k = kernel
139 | pad = dilation*(kernel-1) // 2
140 | self.kernel = (1, kernel) if H else (kernel, 1)
141 | self.padding = (kernel//2, 1) if H else (1, kernel//2)
142 | self.dilation = dilation
143 | self.group = group
144 | self.pad = nn.ReflectionPad2d((pad, pad, 0, 0)) if H else nn.ReflectionPad2d((0, 0, pad, pad))
145 | self.conv = nn.Conv2d(dim, group*kernel, kernel_size=1, stride=1, bias=False)
146 | self.ap = nn.AdaptiveAvgPool2d((1, 1))
147 | self.filter_act = nn.Tanh()
148 | self.inside_all = nn.Parameter(torch.zeros(dim,1,1), requires_grad=True)
149 | self.lamb_l = nn.Parameter(torch.zeros(dim), requires_grad=True)
150 | self.lamb_h = nn.Parameter(torch.zeros(dim), requires_grad=True)
151 | gap_kernel = (None,1) if H else (1, None)
152 | self.gap = nn.AdaptiveAvgPool2d(gap_kernel)
153 |
154 | def forward(self, x):
155 | identity_input = x.clone()
156 | filter = self.ap(x)
157 | filter = self.conv(filter)
158 | n, c, h, w = x.shape
159 | x = F.unfold(self.pad(x), kernel_size=self.kernel, dilation=self.dilation).reshape(n, self.group, c//self.group, self.k, h*w)
160 | n, c1, p, q = filter.shape
161 | filter = filter.reshape(n, c1//self.k, self.k, p*q).unsqueeze(2)
162 | filter = self.filter_act(filter)
163 | out = torch.sum(x * filter, dim=3).reshape(n, c, h, w)
164 |
165 | out_low = out * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input)
166 | out_low = out_low * self.lamb_l[None,:,None,None]
167 | out_high = identity_input * (self.lamb_h[None,:,None,None]+1.)
168 |
169 | return out_low + out_high
170 |
171 |
172 | class MultiShapeKernel(nn.Module):
173 | def __init__(self, dim, kernel_size=3, dilation=1, group=8):
174 | super().__init__()
175 |
176 | self.square_att = dynamic_filter(inchannels=dim, dilation=dilation, group=group, kernel_size=kernel_size)
177 | self.strip_att = cubic_attention(dim, group=group, dilation=dilation, kernel=kernel_size)
178 |
179 | def forward(self, x):
180 |
181 | x1 = self.strip_att(x)
182 | x2 = self.square_att(x)
183 |
184 | return x1+x2
185 |
186 |
187 |
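MultiShapeKernel sums two branches: a square dynamic filter and horizontal/vertical strip attention. A standalone shape check (dim must be divisible by the default group of 8; dilation=3 matches the first DeepPoolLayer stage above):

~~~
import torch
from models.layers import MultiShapeKernel

m = MultiShapeKernel(dim=32, kernel_size=3, dilation=3).eval()
with torch.no_grad():
    y = m(torch.randn(1, 32, 64, 64))
print(y.shape)  # torch.Size([1, 32, 64, 64]) -- spatial size is preserved
~~~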
--------------------------------------------------------------------------------
/Image_deraining/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from models.ConvIR import build_net
5 | from data import test_dataloader
6 | from utils import Adder
7 | import time
8 | from torchvision.transforms import functional as F
9 | from skimage.metrics import peak_signal_noise_ratio
10 | from torch.utils.data import Dataset, DataLoader
11 | from PIL import Image
12 | from tqdm import tqdm
13 | import torch.nn.functional as f
14 |
15 | class DeblurDataset(Dataset):
16 | def __init__(self, image_dir, transform=None, is_test=False):
17 | self.image_dir = image_dir
18 | self.image_list = os.listdir(os.path.join(image_dir, 'input/'))
19 | self._check_image(self.image_list)
20 | self.image_list.sort()
21 | self.transform = transform
22 | self.is_test = is_test
23 |
24 | def __len__(self):
25 | return len(self.image_list)
26 |
27 | def __getitem__(self, idx):
28 | image = Image.open(os.path.join(self.image_dir, 'input', self.image_list[idx]))
29 | label = Image.open(os.path.join(self.image_dir, 'target', self.image_list[idx]))
30 |
31 | if self.transform:
32 | image, label = self.transform(image, label)
33 | else:
34 | image = F.to_tensor(image)
35 | label = F.to_tensor(label)
36 | if self.is_test:
37 | name = self.image_list[idx]
38 | return image, label, name
39 | return image, label
40 |
41 | @staticmethod
42 | def _check_image(lst):
43 | for x in lst:
44 | splits = x.split('.')
45 | if splits[-1] not in ['png', 'jpg', 'jpeg']:
46 | raise ValueError
47 |
48 | def test_dataloader(path, batch_size=1, num_workers=0):
49 | dataloader = DataLoader(
50 | DeblurDataset(path, is_test=True),
51 | batch_size=batch_size,
52 | shuffle=False,
53 | num_workers=num_workers,
54 | pin_memory=True
55 | )
56 |
57 | return dataloader
58 |
59 |
60 | parser = argparse.ArgumentParser()
61 |
62 | # Directories
63 | parser.add_argument('--model_name', default='ConvIR', type=str)
64 | parser.add_argument('--data_dir', type=str, default='/root/autodl-tmp/deraining_testset')
65 |
66 | parser.add_argument('--test_model', type=str, default='/root/autodl-tmp/sfnet/deraining.pkl')
67 | parser.add_argument('--save_image', type=bool, default=True, choices=[True, False])
68 | args = parser.parse_args()
69 |
70 | args.result_dir = os.path.join('results/', args.model_name, 'deraining/')
71 |
72 | if not os.path.exists('results/'):
73 | os.makedirs('results/')
74 | if not os.path.exists('results/' + args.model_name + '/'):
75 | os.makedirs('results/' + args.model_name + '/')
76 | if not os.path.exists(args.result_dir):
77 | os.makedirs(args.result_dir)
78 |
79 | model = build_net()
80 |
81 | if torch.cuda.is_available():
82 | model.cuda()
83 |
84 | state_dict = torch.load(args.test_model)
85 | model.load_state_dict(state_dict['model'])
86 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
87 | torch.cuda.empty_cache()
88 | adder = Adder()
89 | model.eval()
90 |
91 | datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800']
92 |
93 | for dataset in datasets:
94 | if not os.path.exists(args.result_dir+dataset+'/'):
95 | os.makedirs(args.result_dir+dataset)
96 | print(args.result_dir+dataset)
97 | dataloader = test_dataloader(os.path.join(args.data_dir, dataset), batch_size=1, num_workers=4)
98 | factor = 32
99 | with torch.no_grad():
100 | psnr_adder = Adder()
101 |
102 |
103 | # Main Evaluation
104 | for iter_idx, data in enumerate(tqdm(dataloader), 0):
105 | input_img, label_img, name = data
106 |
107 | input_img = input_img.to(device)
108 |
109 | h, w = input_img.shape[2], input_img.shape[3]
110 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
111 | padh = H-h if h%factor!=0 else 0
112 | padw = W-w if w%factor!=0 else 0
113 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
114 |
115 |
116 | tm = time.time()
117 |
118 | pred = model(input_img)[2]
119 |
120 | elapsed = time.time() - tm
121 | adder(elapsed)
122 | pred = pred[:, :, :h, :w]  # crop off the reflect padding
123 | pred_clip = torch.clamp(pred, 0, 1)
124 |
125 | pred_numpy = pred_clip.squeeze(0).cpu().numpy()
126 | label_numpy = label_img.squeeze(0).cpu().numpy()
127 | psnr_adder(peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1))
128 | if args.save_image:
129 | save_name = os.path.join(args.result_dir, dataset, name[0])
130 | pred_clip += 0.5 / 255
131 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB')
132 | pred.save(save_name)
133 |
134 | print('==========================================================')
135 | print('%s: average PSNR %.2f dB' % (dataset, psnr_adder.average()))
136 |
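Before inference, each image is reflect-padded so height and width become multiples of factor=32, and the prediction is cropped back to (h, w) afterwards. The padding arithmetic in isolation:

~~~
import torch
import torch.nn.functional as f

factor = 32
x = torch.randn(1, 3, 100, 150)
h, w = x.shape[2], x.shape[3]
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
x = f.pad(x, (0, padw, 0, padh), 'reflect')
print(x.shape)  # torch.Size([1, 3, 128, 160])
~~~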
--------------------------------------------------------------------------------
/Image_deraining/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from data import train_dataloader
4 | from utils import Adder, Timer
5 | from torch.utils.tensorboard import SummaryWriter
6 | from valid import _valid
7 | import torch.nn.functional as F
8 | import torch.nn as nn
9 |
10 | from warmup_scheduler import GradualWarmupScheduler
11 |
12 |
13 | def _train(model, args):
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | criterion = torch.nn.L1Loss()
16 |
17 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8)
18 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker)
19 | max_iter = len(dataloader)
20 | warmup_epochs=3
21 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6)
22 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
23 | scheduler.step()
24 | epoch = 1
25 | if args.resume:
26 | state = torch.load(args.resume)
27 | epoch = state['epoch']
28 | optimizer.load_state_dict(state['optimizer'])
29 | model.load_state_dict(state['model'])
30 | print('Resume from %d'%epoch)
31 | epoch += 1
32 |
33 |
34 |
35 | writer = SummaryWriter()
36 | epoch_pixel_adder = Adder()
37 | epoch_fft_adder = Adder()
38 | iter_pixel_adder = Adder()
39 | iter_fft_adder = Adder()
40 | epoch_timer = Timer('m')
41 | iter_timer = Timer('m')
42 | best_psnr=-1
43 |
44 | for epoch_idx in range(epoch, args.num_epoch + 1):
45 |
46 | epoch_timer.tic()
47 | iter_timer.tic()
48 | for iter_idx, batch_data in enumerate(dataloader):
49 |
50 | input_img, label_img = batch_data
51 | input_img = input_img.to(device)
52 | label_img = label_img.to(device)
53 |
54 | optimizer.zero_grad()
55 | pred_img = model(input_img)
56 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear')
57 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear')
58 | l1 = criterion(pred_img[0], label_img4)
59 | l2 = criterion(pred_img[1], label_img2)
60 | l3 = criterion(pred_img[2], label_img)
61 | loss_content = l1+l2+l3
62 |
63 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1))
64 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1)
65 |
66 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1))
67 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1)
68 |
69 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1))
70 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1)
71 |
72 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1))
73 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1)
74 |
75 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1))
76 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1)
77 |
78 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1))
79 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1)
80 |
81 | f1 = criterion(pred_fft1, label_fft1)
82 | f2 = criterion(pred_fft2, label_fft2)
83 | f3 = criterion(pred_fft3, label_fft3)
84 | loss_fft = f1+f2+f3
85 |
86 | loss = loss_content + 0.1 * loss_fft
87 | loss.backward()
88 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
89 | optimizer.step()
90 |
91 | iter_pixel_adder(loss_content.item())
92 | iter_fft_adder(loss_fft.item())
93 |
94 | epoch_pixel_adder(loss_content.item())
95 | epoch_fft_adder(loss_fft.item())
96 |
97 | if (iter_idx + 1) % args.print_freq == 0:
98 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % (
99 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(),
100 | iter_fft_adder.average()))
101 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter)
102 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter)
103 |
104 | iter_timer.tic()
105 | iter_pixel_adder.reset()
106 | iter_fft_adder.reset()
107 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl')
108 | torch.save({'model': model.state_dict(),
109 | 'optimizer': optimizer.state_dict(),
110 | 'epoch': epoch_idx}, overwrite_name)
111 |
112 | if epoch_idx % args.save_freq == 0:
113 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx)
114 | torch.save({'model': model.state_dict()}, save_name)
115 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % (
116 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average()))
117 | epoch_fft_adder.reset()
118 | epoch_pixel_adder.reset()
119 | scheduler.step()
120 | if epoch_idx % args.valid_freq == 0:
121 | val_rain = _valid(model, args, epoch_idx)
122 | print('%03d epoch \n Average DeRain PSNR %.2f dB' % (epoch_idx, val_rain))
123 | writer.add_scalar('PSNR_DeRain', val_rain, epoch_idx)
124 | if val_rain >= best_psnr:
125 | best_psnr = val_rain
126 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl'))
127 | save_name = os.path.join(args.model_save_dir, 'Final.pkl')
128 | torch.save({'model': model.state_dict()}, save_name)
129 |
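The objective above is a multi-scale L1 content loss plus 0.1x an L1 loss between FFT spectra, with real and imaginary parts stacked. A minimal sketch of the frequency term:

~~~
import torch
import torch.nn.functional as F

def fft_l1(pred, label):
    # L1 between spectra, real/imag stacked as in the training loop above
    p = torch.fft.fft2(pred, dim=(-2, -1))
    l = torch.fft.fft2(label, dim=(-2, -1))
    p = torch.stack((p.real, p.imag), -1)
    l = torch.stack((l.real, l.imag), -1)
    return F.l1_loss(p, l)

print(fft_l1(torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64)))
~~~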
--------------------------------------------------------------------------------
/Image_deraining/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 |
4 |
5 | class Adder(object):
6 | def __init__(self):
7 | self.count = 0
8 | self.num = float(0)
9 |
10 | def reset(self):
11 | self.count = 0
12 | self.num = float(0)
13 |
14 | def __call__(self, num):
15 | self.count += 1
16 | self.num += num
17 |
18 | def average(self):
19 | return self.num / self.count
20 |
21 |
22 | class Timer(object):
23 | def __init__(self, option='s'):
24 | self.tm = 0
25 | self.option = option
26 | if option == 's':
27 | self.divider = 1
28 | elif option == 'm':
29 | self.divider = 60
30 | else:
31 | self.divider = 3600
32 |
33 | def tic(self):
34 | self.tm = time.time()
35 |
36 | def toc(self):
37 | return (time.time() - self.tm) / self.divider
38 |
39 |
40 | def check_lr(optimizer):
41 | for i, param_group in enumerate(optimizer.param_groups):
42 | lr = param_group['lr']
43 | return lr
44 |
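Usage of the two helpers: Adder keeps a running sum and count, Timer converts elapsed wall time to the chosen unit.

~~~
import time
from utils import Adder, Timer

adder = Adder()
for v in (1.0, 2.0, 3.0):
    adder(v)
print(adder.average())  # 2.0

t = Timer('s')
t.tic()
time.sleep(0.1)
print('%.2f s elapsed' % t.toc())
~~~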
--------------------------------------------------------------------------------
/Image_deraining/valid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import functional as F
3 | from data import valid_dataloader
4 | from utils import Adder
5 | import os
6 | from skimage.metrics import peak_signal_noise_ratio
7 | import torch.nn.functional as f
8 |
9 |
10 | def _valid(model, args, ep):
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 | gopro = valid_dataloader(args.valid_data, batch_size=1, num_workers=0)
13 | model.eval()
14 | psnr_adder = Adder()
15 |
16 | with torch.no_grad():
17 | print('Start Derain Evaluation')
18 | factor = 32
19 | for idx, data in enumerate(gopro):
20 | input_img, label_img = data
21 | input_img = input_img.to(device)
22 |
23 | h, w = input_img.shape[2], input_img.shape[3]
24 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
25 | padh = H-h if h%factor!=0 else 0
26 | padw = W-w if w%factor!=0 else 0
27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
28 |
29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))):
30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep)))
31 |
32 | pred = model(input_img)[2]
33 | pred = pred[:,:,:h,:w]
34 |
35 | pred_clip = torch.clamp(pred, 0, 1)
36 | p_numpy = pred_clip.squeeze(0).cpu().numpy()
37 | label_numpy = label_img.squeeze(0).cpu().numpy()
38 |
39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)
40 |
41 | psnr_adder(psnr)
42 | print('\r%03d'%idx, end=' ')
43 |
44 | print('\n')
45 | model.train()
46 | return psnr_adder.average()
47 |
--------------------------------------------------------------------------------
/Image_desnowing/README.md:
--------------------------------------------------------------------------------
1 | ### Download the Datasets
2 | - SRRS [[gdrive](https://drive.google.com/file/d/11h1cZ0NXx6ev35cl5NKOAL3PCgLlWUl2/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1VXqsamkl12fPsI1Qek97TQ?pwd=vcfg)]
3 | - CSD [[gdrive](https://drive.google.com/file/d/1pns-7uWy-0SamxjA40qOCkkhSu7o7ULb/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1N52Jnx0co9udJeYrbd3blA?pwd=sb4a)]
4 | - Snow100K [[gdrive](https://drive.google.com/file/d/19zJs0cJ6F3G3IlDHLU2BO7nHnCTMNrIS/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1QGd5z9uM6vBKPnD5d7jQmA?pwd=aph4)]
5 |
6 | ### Training
7 | Use `--version small`, `--version base`, or `--version large` to select the model size.
8 | ~~~
9 | python main.py --data CSD --version small --mode train --data_dir your_path/CSD
10 | python main.py --data SRRS --version small --mode train --data_dir your_path/SRRS
11 | python main.py --data Snow100K --version small --mode train --data_dir your_path/Snow100K
12 | ~~~
13 |
14 | ### Evaluation
15 | #### Download the model
16 | [gdrive](https://drive.google.com/drive/folders/1_5fO2p5xoWO5cUEVoXJ7x3Uhg1AP18FQ?usp=sharing), [Baidu](https://pan.baidu.com/s/1oYzdxs3FvLJMWx7S5GW0rA?pwd=dvta)
17 | #### Testing
18 | Use `--version small`, `--version base`, or `--version large` to match the downloaded model.
19 | ~~~
20 | python main.py --data CSD --version small --save_image True --mode test --data_dir your_path/CSD --test_model path_to_CSD_model
21 |
22 | python main.py --data SRRS --version small --save_image True --mode test --data_dir your_path/SRRS --test_model path_to_SRRS_model
23 |
24 | python main.py --data Snow100K --version small --save_image True --mode test --data_dir your_path/Snow100K --test_model path_to_Snow100K_model
25 | ~~~
26 |
27 |
28 | For training and testing, your directory structure should look like this:
29 |
30 | `Your path`
31 | `├──CSD`
32 | `├──train2500`
33 | `├──Gt`
34 | `└──Snow`
35 | `└──test2000`
36 | `├──Gt`
37 | `└──Snow`
38 | `├──SRRS`
39 | `├──train2500`
40 | `├──Gt`
41 | `└──Snow`
42 | `└──test2000`
43 | `├──Gt`
44 | `└──Snow`
45 | `└──Snow100K`
46 | `├──train2500`
47 | `├──Gt`
48 | `└──Snow`
49 | `└──test2000`
50 | `├──Gt`
51 | `└──Snow`
52 |
--------------------------------------------------------------------------------
/Image_desnowing/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFlip, PairToTensor
2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader
3 |
--------------------------------------------------------------------------------
/Image_desnowing/data/data_augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms as transforms
3 | import torchvision.transforms.functional as F
4 |
5 |
6 | class PairRandomCrop(transforms.RandomCrop):
7 |
8 | def __call__(self, image, label):
9 |
10 | if self.padding is not None:
11 | image = F.pad(image, self.padding, self.fill, self.padding_mode)
12 | label = F.pad(label, self.padding, self.fill, self.padding_mode)
13 |
14 | # pad the width if needed
15 | if self.pad_if_needed and image.size[0] < self.size[1]:
16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode)
17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
18 | # pad the height if needed
19 | if self.pad_if_needed and image.size[1] < self.size[0]:
20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode)
21 | label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)
22 |
23 | i, j, h, w = self.get_params(image, self.size)
24 |
25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w)
26 |
27 |
28 | class PairCompose(transforms.Compose):
29 | def __call__(self, image, label):
30 | for t in self.transforms:
31 | image, label = t(image, label)
32 | return image, label
33 |
34 |
35 | class PairRandomHorizontalFlip(transforms.RandomHorizontalFlip):
36 | def __call__(self, img, label):
37 | """
38 | Args:
39 | img (PIL Image): Image to be flipped.
40 |
41 | Returns:
42 | PIL Image: Randomly flipped image.
43 | """
44 | if random.random() < self.p:
45 | return F.hflip(img), F.hflip(label)
46 | return img, label
47 |
48 |
49 | class PairToTensor(transforms.ToTensor):
50 | def __call__(self, pic, label):
51 | """
52 | Args:
53 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
54 |
55 | Returns:
56 | Tensor: Converted image.
57 | """
58 | return F.to_tensor(pic), F.to_tensor(label)
59 |
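The paired transforms apply identical crop and flip parameters to an input and its ground truth. Illustrative usage with blank stand-in images:

~~~
from PIL import Image
from data import PairCompose, PairRandomCrop, PairRandomHorizontalFlip, PairToTensor

tf = PairCompose([PairRandomCrop(256), PairRandomHorizontalFlip(), PairToTensor()])
snow = Image.new('RGB', (480, 320))  # stand-ins for a Snow/Gt pair
gt = Image.new('RGB', (480, 320))
x, y = tf(snow, gt)
print(x.shape, y.shape)  # torch.Size([3, 256, 256]) for both
~~~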
--------------------------------------------------------------------------------
/Image_desnowing/data/data_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image as Image
3 | from data import PairCompose, PairRandomCrop, PairRandomHorizontalFlip, PairToTensor
4 | from torchvision.transforms import functional as F
5 | from torch.utils.data import Dataset, DataLoader
6 |
7 |
8 | def train_dataloader(path, batch_size=64, num_workers=0, data='CSD', use_transform=True):
9 | image_dir = os.path.join(path, 'train2500')
10 |
11 | transform = None
12 | if use_transform:
13 | transform = PairCompose(
14 | [
15 | PairRandomCrop(256),
16 | PairRandomHorizontalFlip(),
17 | PairToTensor()
18 | ]
19 | )
20 | dataloader = DataLoader(
21 | DeblurDataset(image_dir, data, transform=transform),
22 | batch_size=batch_size,
23 | shuffle=True,
24 | num_workers=num_workers,
25 | pin_memory=True
26 | )
27 | return dataloader
28 |
29 |
30 | def test_dataloader(path, data, batch_size=1, num_workers=0):
31 | image_dir = os.path.join(path, 'test2000')
32 | dataloader = DataLoader(
33 | DeblurDataset(image_dir, data, is_test=True),
34 |
35 | batch_size=batch_size,
36 | shuffle=False,
37 | num_workers=num_workers,
38 | pin_memory=True
39 | )
40 |
41 | return dataloader
42 |
43 |
44 | def valid_dataloader(path, data, batch_size=1, num_workers=0):
45 | dataloader = DataLoader(
46 | DeblurDataset(os.path.join(path, 'test2000'), data),
47 | batch_size=batch_size,
48 | shuffle=False,
49 | num_workers=num_workers
50 | )
51 |
52 | return dataloader
53 |
54 |
55 | class DeblurDataset(Dataset):
56 | def __init__(self, image_dir, data, transform=None, is_test=False):
57 | self.image_dir = image_dir
58 | self.image_list = os.listdir(os.path.join(image_dir, 'Snow/'))
59 | self.image_list.sort()
60 | self.transform = transform
61 | self.is_test = is_test
62 | self.data = data
63 |
64 | def __len__(self):
65 | return len(self.image_list)
66 |
67 | def __getitem__(self, idx):
68 | image = Image.open(os.path.join(self.image_dir, 'Snow', self.image_list[idx]))
69 | if self.data == 'SRRS':
70 | label = Image.open(os.path.join(self.image_dir, 'Gt', self.image_list[idx].split('.')[0]+'.jpg'))
71 | else:
72 | label = Image.open(os.path.join(self.image_dir, 'Gt', self.image_list[idx]))
73 |
74 | if self.transform:
75 | image, label = self.transform(image, label)
76 | else:
77 | image = F.to_tensor(image)
78 | label = F.to_tensor(label)
79 | if self.is_test:
80 | name = self.image_list[idx]
81 | return image, label, name
82 | return image, label
83 |
84 |
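A minimal loader check, assuming the CSD layout described in the README (`your_path/CSD` is a placeholder):

~~~
from data import train_dataloader

loader = train_dataloader('your_path/CSD', batch_size=8, num_workers=4, data='CSD')
snow, gt = next(iter(loader))
print(snow.shape, gt.shape)  # torch.Size([8, 3, 256, 256]) for both
~~~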
--------------------------------------------------------------------------------
/Image_desnowing/eval.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | from pytorch_msssim import ssim
5 | from torchvision.transforms import functional as F
6 | from utils import Adder
7 | from data import test_dataloader
8 | from skimage.metrics import peak_signal_noise_ratio
9 | import torch.nn.functional as f
10 |
11 | def _eval(model, args):
12 | state_dict = torch.load(args.test_model)
13 | model.load_state_dict(state_dict['model'])
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | dataloader = test_dataloader(args.data_dir, args.data, batch_size=1, num_workers=0)
16 | torch.cuda.empty_cache()
17 | model.eval()
18 | factor = 32
19 | with torch.no_grad():
20 | psnr_adder = Adder()
21 | ssim_adder = Adder()
22 |
23 | for iter_idx, data in enumerate(dataloader):
24 | input_img, label_img, name = data
25 | input_img = input_img.to(device)
26 |
27 | h, w = input_img.shape[2], input_img.shape[3]
28 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
29 | padh = H-h if h%factor!=0 else 0
30 | padw = W-w if w%factor!=0 else 0
31 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
32 |
33 | pred = model(input_img)[2]
34 | pred = pred[:,:,:h,:w]
35 |
36 | pred_clip = torch.clamp(pred, 0, 1)
37 |
38 | pred_numpy = pred_clip.squeeze(0).cpu().numpy()
39 | label_numpy = label_img.squeeze(0).cpu().numpy()
40 |
41 |
42 | if args.save_image:
43 | save_name = os.path.join(args.result_dir, name[0])
44 | pred_clip += 0.5 / 255
45 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB')
46 | pred.save(save_name)
47 |
48 |
49 | label_img = label_img.to(device)
50 | down_ratio = max(1, round(min(H, W) / 256))
51 | ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))),
52 | f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))),
53 | data_range=1, size_average=False)
54 | ssim_adder(ssim_val)
55 |
56 | psnr = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1)
57 | psnr_adder(psnr)
58 |
59 | print('%d iter PSNR: %.2f SSIM: %f' % (iter_idx + 1, psnr, ssim_val))
60 |
61 | print('==========================================================')
62 | print('The average PSNR is %.2f dB' % (psnr_adder.average()))
63 | print('The average SSIM is %.4f' % (ssim_adder.average()))
64 |
65 |
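SSIM here is computed on average-pooled copies of prediction and target, scaled so the shorter side is roughly 256. The down_ratio pooling step in isolation:

~~~
import torch
import torch.nn.functional as f
from pytorch_msssim import ssim

pred = torch.rand(1, 3, 480, 640)
gt = torch.rand(1, 3, 480, 640)
down_ratio = max(1, round(min(480, 640) / 256))  # -> 2
size = (480 // down_ratio, 640 // down_ratio)
print(ssim(f.adaptive_avg_pool2d(pred, size),
           f.adaptive_avg_pool2d(gt, size),
           data_range=1, size_average=False))
~~~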
--------------------------------------------------------------------------------
/Image_desnowing/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from torch.backends import cudnn
5 | from models.ConvIR import build_net
6 | from train import _train
7 | from eval import _eval
8 |
9 | def main(args):
10 | cudnn.benchmark = True
11 |
12 | if not os.path.exists('results/'):
13 | os.makedirs('results/')
14 | if not os.path.exists('results/' + args.model_name + '/'):
15 | os.makedirs('results/' + args.model_name + '/')
16 | if not os.path.exists(args.result_dir):
17 | os.makedirs(args.result_dir)
18 |
19 | model = build_net(args.version)
20 | # print(model)
21 |
22 | if torch.cuda.is_available():
23 | model.cuda()
24 |
25 | if args.mode == 'train':
26 | _train(model, args)
27 |
28 | elif args.mode == 'test':
29 | _eval(model, args)
30 |
31 |
32 | if __name__ == '__main__':
33 | parser = argparse.ArgumentParser()
34 |
35 | # Directories
36 | parser.add_argument('--model_name', default='ConvIR', type=str)
37 | parser.add_argument('--data', type=str, default='CSD', choices=['CSD', 'SRRS', 'Snow100K'])
38 |
39 | parser.add_argument('--data_dir', type=str, default='CSD')
40 | parser.add_argument('--mode', default='train', choices=['train', 'test'], type=str)
41 | parser.add_argument('--version', default='small', choices=['small', 'base', 'large'], type=str)
42 |
43 | # Train
44 | parser.add_argument('--batch_size', type=int, default=8)
45 | parser.add_argument('--learning_rate', type=float, default=2e-4)
46 | parser.add_argument('--weight_decay', type=float, default=0)
47 | parser.add_argument('--num_epoch', type=int, default=2000)
48 | parser.add_argument('--print_freq', type=int, default=100)
49 | parser.add_argument('--num_worker', type=int, default=8)
50 | parser.add_argument('--save_freq', type=int, default=50)
51 | parser.add_argument('--valid_freq', type=int, default=50)
52 | parser.add_argument('--resume', type=str, default='')
53 |
54 | # Test
55 | parser.add_argument('--test_model', type=str, default='CSD.pkl')
56 | parser.add_argument('--save_image', type=bool, default=False, choices=[True, False])
57 |
58 | args = parser.parse_args()
59 | args.model_save_dir = os.path.join('results/', args.model_name, 'Training-Results/')
60 | args.result_dir = os.path.join('results/', args.model_name, 'images', args.data)
61 | if not os.path.exists(args.model_save_dir):
62 | os.makedirs(args.model_save_dir)
63 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir
64 | os.system(command)
65 | command = 'cp ' + 'models/ConvIR.py ' + args.model_save_dir
66 | os.system(command)
67 | command = 'cp ' + 'train.py ' + args.model_save_dir
68 | os.system(command)
69 | command = 'cp ' + 'main.py ' + args.model_save_dir
70 | os.system(command)
71 | print(args)
72 | main(args)
73 |
--------------------------------------------------------------------------------
/Image_desnowing/models/ConvIR.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .layers import *
6 |
7 |
8 | class EBlock(nn.Module):
9 | def __init__(self, out_channel, num_res=8):
10 | super(EBlock, self).__init__()
11 |
12 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res-1)]
13 | layers.append(ResBlock(out_channel, out_channel, filter=True))
14 |
15 | self.layers = nn.Sequential(*layers)
16 |
17 | def forward(self, x):
18 | return self.layers(x)
19 |
20 |
21 | class DBlock(nn.Module):
22 | def __init__(self, channel, num_res=8):
23 | super(DBlock, self).__init__()
24 |
25 | layers = [ResBlock(channel, channel) for _ in range(num_res-1)]
26 | layers.append(ResBlock(channel, channel, filter=True))
27 | self.layers = nn.Sequential(*layers)
28 |
29 | def forward(self, x):
30 | return self.layers(x)
31 |
32 |
33 | class SCM(nn.Module):
34 | def __init__(self, out_plane):
35 | super(SCM, self).__init__()
36 | self.main = nn.Sequential(
37 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True),
38 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
39 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
40 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False),
41 | nn.InstanceNorm2d(out_plane, affine=True)
42 | )
43 |
44 | def forward(self, x):
45 | x = self.main(x)
46 | return x
47 |
48 | class FAM(nn.Module):
49 | def __init__(self, channel):
50 | super(FAM, self).__init__()
51 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False)
52 |
53 | def forward(self, x1, x2):
54 | return self.merge(torch.cat([x1, x2], dim=1))
55 |
56 | class ConvIR(nn.Module):
57 | def __init__(self, version):
58 | super(ConvIR, self).__init__()
59 |
60 | if version == 'small':
61 | num_res = 4
62 | elif version == 'base':
63 | num_res = 8
64 | elif version == 'large':
65 | num_res = 16
66 |
67 | base_channel = 32
68 |
69 | self.Encoder = nn.ModuleList([
70 | EBlock(base_channel, num_res),
71 | EBlock(base_channel*2, num_res),
72 | EBlock(base_channel*4, num_res),
73 | ])
74 |
75 | self.feat_extract = nn.ModuleList([
76 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
77 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
78 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
79 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
80 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
81 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
82 | ])
83 |
84 | self.Decoder = nn.ModuleList([
85 | DBlock(base_channel * 4, num_res),
86 | DBlock(base_channel * 2, num_res),
87 | DBlock(base_channel, num_res)
88 | ])
89 |
90 | self.Convs = nn.ModuleList([
91 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
92 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
93 | ])
94 |
95 | self.ConvsOut = nn.ModuleList(
96 | [
97 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
98 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
99 | ]
100 | )
101 |
102 | self.FAM1 = FAM(base_channel * 4)
103 | self.SCM1 = SCM(base_channel * 4)
104 | self.FAM2 = FAM(base_channel * 2)
105 | self.SCM2 = SCM(base_channel * 2)
106 |
107 | def forward(self, x):
108 | x_2 = F.interpolate(x, scale_factor=0.5)
109 | x_4 = F.interpolate(x_2, scale_factor=0.5)
110 | z2 = self.SCM2(x_2)
111 | z4 = self.SCM1(x_4)
112 |
113 | outputs = list()
114 | # 256
115 | x_ = self.feat_extract[0](x)
116 | res1 = self.Encoder[0](x_)
117 | # 128
118 | z = self.feat_extract[1](res1)
119 | z = self.FAM2(z, z2)
120 | res2 = self.Encoder[1](z)
121 | # 64
122 | z = self.feat_extract[2](res2)
123 | z = self.FAM1(z, z4)
124 | z = self.Encoder[2](z)
125 |
126 | z = self.Decoder[0](z)
127 | z_ = self.ConvsOut[0](z)
128 | # 128
129 | z = self.feat_extract[3](z)
130 | outputs.append(z_+x_4)
131 |
132 | z = torch.cat([z, res2], dim=1)
133 | z = self.Convs[0](z)
134 | z = self.Decoder[1](z)
135 | z_ = self.ConvsOut[1](z)
136 | # 256
137 | z = self.feat_extract[4](z)
138 | outputs.append(z_+x_2)
139 |
140 | z = torch.cat([z, res1], dim=1)
141 | z = self.Convs[1](z)
142 | z = self.Decoder[2](z)
143 | z = self.feat_extract[5](z)
144 | outputs.append(z+x)
145 |
146 | return outputs
147 |
148 |
149 | def build_net(version):
150 | return ConvIR(version)
151 |
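Unlike the deraining model, the desnowing network is built in three sizes; the version string only sets the residual depth per encoder/decoder block:

~~~
from models.ConvIR import build_net

net = build_net('small')  # num_res = 4; 'base' -> 8, 'large' -> 16
~~~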
--------------------------------------------------------------------------------
/Image_desnowing/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class BasicConv(nn.Module):
6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
7 | super(BasicConv, self).__init__()
8 | if bias and norm:
9 | bias = False
10 |
11 | padding = kernel_size // 2
12 | layers = list()
13 | if transpose:
14 | padding = kernel_size // 2 - 1
15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
16 | else:
17 | layers.append(
18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
19 | if norm:
20 | layers.append(nn.BatchNorm2d(out_channel))
21 | if relu:
22 | layers.append(nn.GELU())
23 | self.main = nn.Sequential(*layers)
24 |
25 | def forward(self, x):
26 | return self.main(x)
27 |
28 |
29 | class ResBlock(nn.Module):
30 | def __init__(self, in_channel, out_channel, filter=False):
31 | super(ResBlock, self).__init__()
32 | self.main = nn.Sequential(
33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
34 | DeepPoolLayer(in_channel, out_channel) if filter else nn.Identity(),
35 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
36 | )
37 |
38 | def forward(self, x):
39 | return self.main(x) + x
40 |
41 |
42 | class DeepPoolLayer(nn.Module):
43 | def __init__(self, k, k_out):
44 | super(DeepPoolLayer, self).__init__()
45 | self.pools_sizes = [8,4,2]
46 | dilation = [7,9,11]
47 | pools, convs, dynas = [],[],[]
48 | for j, i in enumerate(self.pools_sizes):
49 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i))
50 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False))
51 | dynas.append(MultiShapeKernel(dim=k, kernel_size=3, dilation=dilation[j]))
52 | self.pools = nn.ModuleList(pools)
53 | self.convs = nn.ModuleList(convs)
54 | self.dynas = nn.ModuleList(dynas)
55 | self.relu = nn.GELU()
56 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False)
57 |
58 | def forward(self, x):
59 | x_size = x.size()
60 | resl = x
61 | for i in range(len(self.pools_sizes)):
62 | if i == 0:
63 | y = self.dynas[i](self.convs[i](self.pools[i](x)))
64 | else:
65 | y = self.dynas[i](self.convs[i](self.pools[i](x)+y_up))
66 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True))
67 | if i != len(self.pools_sizes)-1:
68 | y_up = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
69 | resl = self.relu(resl)
70 | resl = self.conv_sum(resl)
71 |
72 | return resl
73 |
74 |
75 | class dynamic_filter(nn.Module):
76 | def __init__(self, inchannels, kernel_size=3, dilation=1, stride=1, group=8):
77 | super(dynamic_filter, self).__init__()
78 | self.stride = stride
79 | self.kernel_size = kernel_size
80 | self.group = group
81 | self.dilation = dilation
82 |
83 | self.conv = nn.Conv2d(inchannels, group*kernel_size**2, kernel_size=1, stride=1, bias=False)
84 | self.bn = nn.BatchNorm2d(group*kernel_size**2)
85 | self.act = nn.Tanh()
86 |
87 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
88 | self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True)
89 | self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True)
90 | self.pad = nn.ReflectionPad2d(self.dilation*(kernel_size-1)//2)
91 |
92 | self.ap = nn.AdaptiveAvgPool2d((1, 1))
93 | self.gap = nn.AdaptiveAvgPool2d(1)
94 |
95 | self.inside_all = nn.Parameter(torch.zeros(inchannels,1,1), requires_grad=True)
96 |
97 | def forward(self, x):
98 | identity_input = x
99 | low_filter = self.ap(x)
100 | low_filter = self.conv(low_filter)
101 | low_filter = self.bn(low_filter)
102 |
103 | n, c, h, w = x.shape
104 | x = F.unfold(self.pad(x), kernel_size=self.kernel_size, dilation=self.dilation).reshape(n, self.group, c//self.group, self.kernel_size**2, h*w)
105 |
106 | n,c1,p,q = low_filter.shape
107 | low_filter = low_filter.reshape(n, c1//self.kernel_size**2, self.kernel_size**2, p*q).unsqueeze(2)
108 |
109 | low_filter = self.act(low_filter)
110 |
111 | low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w)
112 |
113 | out_low = low_part * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input)
114 |
115 | out_low = out_low * self.lamb_l[None,:,None,None]
116 |
117 | out_high = (identity_input) * (self.lamb_h[None,:,None,None] + 1.)
118 |
119 | return out_low + out_high
120 |
121 |
122 | class cubic_attention(nn.Module):
123 | def __init__(self, dim, group, dilation, kernel) -> None:
124 | super().__init__()
125 |
126 | self.H_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel)
127 | self.W_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel, H=False)
128 | self.gamma = nn.Parameter(torch.zeros(dim,1,1))
129 | self.beta = nn.Parameter(torch.ones(dim,1,1))
130 |
131 | def forward(self, x):
132 | out = self.H_spatial_att(x)
133 | out = self.W_spatial_att(out)
134 | return self.gamma * out + x * self.beta
135 |
136 |
137 | class spatial_strip_att(nn.Module):
138 | def __init__(self, dim, kernel=3, dilation=1, group=2, H=True) -> None:
139 | super().__init__()
140 |
141 | self.k = kernel
142 | pad = dilation*(kernel-1) // 2
143 | self.kernel = (1, kernel) if H else (kernel, 1)
144 | self.padding = (kernel//2, 1) if H else (1, kernel//2)
145 | self.dilation = dilation
146 | self.group = group
147 | self.pad = nn.ReflectionPad2d((pad, pad, 0, 0)) if H else nn.ReflectionPad2d((0, 0, pad, pad))
148 | self.conv = nn.Conv2d(dim, group*kernel, kernel_size=1, stride=1, bias=False)
149 | self.ap = nn.AdaptiveAvgPool2d((1, 1))
150 | self.filter_act = nn.Tanh()
151 | self.inside_all = nn.Parameter(torch.zeros(dim,1,1), requires_grad=True)
152 | self.lamb_l = nn.Parameter(torch.zeros(dim), requires_grad=True)
153 | self.lamb_h = nn.Parameter(torch.zeros(dim), requires_grad=True)
154 | gap_kernel = (None,1) if H else (1, None)
155 | self.gap = nn.AdaptiveAvgPool2d(gap_kernel)
156 |
157 | def forward(self, x):
158 | identity_input = x.clone()
159 | filter = self.ap(x)
160 | filter = self.conv(filter)
161 | n, c, h, w = x.shape
162 | x = F.unfold(self.pad(x), kernel_size=self.kernel, dilation=self.dilation).reshape(n, self.group, c//self.group, self.k, h*w)
163 | n, c1, p, q = filter.shape
164 | filter = filter.reshape(n, c1//self.k, self.k, p*q).unsqueeze(2)
165 | filter = self.filter_act(filter)
166 | out = torch.sum(x * filter, dim=3).reshape(n, c, h, w)
167 |
168 | out_low = out * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input)
169 | out_low = out_low * self.lamb_l[None,:,None,None]
170 | out_high = identity_input * (self.lamb_h[None,:,None,None]+1.)
171 |
172 | return out_low + out_high
173 |
174 |
175 | class MultiShapeKernel(nn.Module):
176 | def __init__(self, dim, kernel_size=3, dilation=1, group=8):
177 | super().__init__()
178 |
179 | self.square_att = dynamic_filter(inchannels=dim, dilation=dilation, group=group, kernel_size=kernel_size)
180 | self.strip_att = cubic_attention(dim, group=group, dilation=dilation, kernel=kernel_size)
181 |
182 | def forward(self, x):
183 |
184 | x1 = self.strip_att(x)
185 | x2 = self.square_att(x)
186 |
187 | return x1+x2
188 |
189 |
190 |
--------------------------------------------------------------------------------
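Taken together, every module in layers.py is resolution-preserving: `dynamic_filter` predicts a per-image low-pass kernel from a global average pool and applies it patch-wise via `unfold`, `spatial_strip_att` does the same with 1×k and k×1 strips, and `DeepPoolLayer` fuses a 1/8, 1/4, 1/2 average-pooling pyramid back to full resolution. A minimal shape check (a sketch assuming the file is importable as `models.layers`; the dummy sizes are chosen divisible by the group count and largest pool size, both 8):

~~~
import torch
from models.layers import DeepPoolLayer, MultiShapeKernel, dynamic_filter

x = torch.randn(2, 32, 64, 64)                 # (N, C, H, W): C divisible by group=8, H/W by pool=8
print(dynamic_filter(inchannels=32)(x).shape)  # torch.Size([2, 32, 64, 64])
print(MultiShapeKernel(dim=32)(x).shape)       # torch.Size([2, 32, 64, 64])
print(DeepPoolLayer(32, 32)(x).shape)          # torch.Size([2, 32, 64, 64])
~~~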
/Image_desnowing/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from data import train_dataloader
4 | from utils import Adder, Timer
5 | from torch.utils.tensorboard import SummaryWriter
6 | from valid import _valid
7 | import torch.nn.functional as F
8 | import torch.nn as nn
9 |
10 | from warmup_scheduler import GradualWarmupScheduler
11 |
12 |
13 | def _train(model, args):
14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15 | criterion = torch.nn.L1Loss()
16 |
17 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8)
18 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker, args.data)
19 | max_iter = len(dataloader)
20 | warmup_epochs=3
21 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6)
22 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
23 | scheduler.step()
24 | epoch = 1
25 | if args.resume:
26 | state = torch.load(args.resume)
27 | epoch = state['epoch']
28 | optimizer.load_state_dict(state['optimizer'])
29 | model.load_state_dict(state['model'])
30 | print('Resume from %d'%epoch)
31 | epoch += 1
32 |
33 |
34 |
35 | writer = SummaryWriter()
36 | epoch_pixel_adder = Adder()
37 | epoch_fft_adder = Adder()
38 | iter_pixel_adder = Adder()
39 | iter_fft_adder = Adder()
40 | epoch_timer = Timer('m')
41 | iter_timer = Timer('m')
42 | best_psnr=-1
43 |
44 | for epoch_idx in range(epoch, args.num_epoch + 1):
45 |
46 | epoch_timer.tic()
47 | iter_timer.tic()
48 | for iter_idx, batch_data in enumerate(dataloader):
49 |
50 | input_img, label_img = batch_data
51 | input_img = input_img.to(device)
52 | label_img = label_img.to(device)
53 |
54 | optimizer.zero_grad()
55 | pred_img = model(input_img)
56 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear')
57 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear')
58 | l1 = criterion(pred_img[0], label_img4)
59 | l2 = criterion(pred_img[1], label_img2)
60 | l3 = criterion(pred_img[2], label_img)
61 | loss_content = l1+l2+l3
62 |
63 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1))
64 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1)
65 |
66 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1))
67 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1)
68 |
69 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1))
70 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1)
71 |
72 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1))
73 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1)
74 |
75 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1))
76 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1)
77 |
78 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1))
79 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1)
80 |
81 | f1 = criterion(pred_fft1, label_fft1)
82 | f2 = criterion(pred_fft2, label_fft2)
83 | f3 = criterion(pred_fft3, label_fft3)
84 | loss_fft = f1+f2+f3
85 |
86 | loss = loss_content + 0.1 * loss_fft
87 | loss.backward()
88 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01)
89 | optimizer.step()
90 |
91 | iter_pixel_adder(loss_content.item())
92 | iter_fft_adder(loss_fft.item())
93 |
94 | epoch_pixel_adder(loss_content.item())
95 | epoch_fft_adder(loss_fft.item())
96 |
97 | if (iter_idx + 1) % args.print_freq == 0:
98 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % (
99 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(),
100 | iter_fft_adder.average()))
101 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter)
102 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter)
103 |
104 | iter_timer.tic()
105 | iter_pixel_adder.reset()
106 | iter_fft_adder.reset()
107 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl')
108 | torch.save({'model': model.state_dict(),
109 | 'optimizer': optimizer.state_dict(),
110 | 'epoch': epoch_idx}, overwrite_name)
111 |
112 | if epoch_idx % args.save_freq == 0:
113 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx)
114 | torch.save({'model': model.state_dict()}, save_name)
115 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % (
116 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average()))
117 | epoch_fft_adder.reset()
118 | epoch_pixel_adder.reset()
119 | scheduler.step()
120 | if epoch_idx % args.valid_freq == 0:
121 | val_snow = _valid(model, args, epoch_idx)
122 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val_snow))
123 | writer.add_scalar('PSNR_Desnowing', val_snow, epoch_idx)
124 | if val_snow >= best_psnr:
125 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl'))
126 | save_name = os.path.join(args.model_save_dir, 'Final.pkl')
127 | torch.save({'model': model.state_dict()}, save_name)
128 |
--------------------------------------------------------------------------------
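The loss in `_train` above supervises all three decoder scales with L1 in both the pixel domain and the frequency domain (real and imaginary FFT components stacked along a new axis), with the frequency term weighted by 0.1. A standalone sketch of the same computation on random tensors (batch size and resolution here are arbitrary):

~~~
import torch
import torch.nn.functional as F

criterion = torch.nn.L1Loss()

label = torch.rand(2, 3, 256, 256)
labels = [F.interpolate(label, scale_factor=1 / s, mode='bilinear') if s > 1 else label
          for s in (4, 2, 1)]
preds = [torch.rand_like(l) for l in labels]   # stand-ins for the network's three outputs

def fft_ri(t):
    # stack real/imaginary parts so L1Loss can compare complex spectra
    spec = torch.fft.fft2(t, dim=(-2, -1))
    return torch.stack((spec.real, spec.imag), -1)

loss_content = sum(criterion(p, l) for p, l in zip(preds, labels))
loss_fft = sum(criterion(fft_ri(p), fft_ri(l)) for p, l in zip(preds, labels))
loss = loss_content + 0.1 * loss_fft
~~~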
/Image_desnowing/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 |
4 |
5 | class Adder(object):
6 | def __init__(self):
7 | self.count = 0
8 | self.num = float(0)
9 |
10 | def reset(self):
11 | self.count = 0
12 | self.num = float(0)
13 |
14 | def __call__(self, num):
15 | self.count += 1
16 | self.num += num
17 |
18 | def average(self):
19 | return self.num / self.count
20 |
21 |
22 | class Timer(object):
23 | def __init__(self, option='s'):
24 | self.tm = 0
25 | self.option = option
26 | if option == 's':
27 | self.divider = 1
28 | elif option == 'm':
29 | self.divider = 60
30 | else:
31 | self.divider = 3600
32 |
33 | def tic(self):
34 | self.tm = time.time()
35 |
36 | def toc(self):
37 | return (time.time() - self.tm) / self.divider
38 |
39 |
40 | def check_lr(optimizer):
41 | for i, param_group in enumerate(optimizer.param_groups):
42 | lr = param_group['lr']
43 | return lr
44 |
--------------------------------------------------------------------------------
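A tiny usage demo of the two helpers above: `Adder` keeps a running mean between `reset()` calls, and `Timer('m')` reports the time since the last `tic()` in minutes (a sketch, assuming utils.py is importable):

~~~
import time
from utils import Adder, Timer

adder = Adder()
for v in (1.0, 2.0, 3.0):
    adder(v)                      # accumulate values
print(adder.average())            # 2.0

timer = Timer('m')
timer.tic()
time.sleep(0.12)
print('%.4f min' % timer.toc())   # roughly 0.0020
~~~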
/Image_desnowing/valid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import functional as F
3 | from data import valid_dataloader
4 | from utils import Adder
5 | import os
6 | from skimage.metrics import peak_signal_noise_ratio
7 | import torch.nn.functional as f
8 |
9 |
10 | def _valid(model, args, ep):
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 | snow_data = valid_dataloader(args.data_dir, args.data, batch_size=1, num_workers=0)
13 | model.eval()
14 | psnr_adder = Adder()
15 |
16 | with torch.no_grad():
17 | print('Start Desnowing Evaluation')
18 | factor = 32
19 | for idx, data in enumerate(snow_data):
20 | input_img, label_img = data
21 | input_img = input_img.to(device)
22 |
23 | h, w = input_img.shape[2], input_img.shape[3]
24 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
25 | padh = H-h if h%factor!=0 else 0
26 | padw = W-w if w%factor!=0 else 0
27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
28 |
29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))):
30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep)))
31 |
32 | pred = model(input_img)[2]
33 | pred = pred[:,:,:h,:w]
34 |
35 | pred_clip = torch.clamp(pred, 0, 1)
36 | p_numpy = pred_clip.squeeze(0).cpu().numpy()
37 | label_numpy = label_img.squeeze(0).cpu().numpy()
38 |
39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)
40 |
41 | psnr_adder(psnr)
42 | print('\r%03d'%idx, end=' ')
43 |
44 | print('\n')
45 | model.train()
46 | return psnr_adder.average()
47 |
--------------------------------------------------------------------------------
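Since the network downsamples by 4 and pools by up to 8 internally, `_valid` reflect-pads each side up to the next multiple of 32 and crops the prediction back to the original size. The same arithmetic in isolation (input resolution chosen arbitrarily):

~~~
import torch
import torch.nn.functional as f

factor = 32
x = torch.rand(1, 3, 481, 321)
h, w = x.shape[2], x.shape[3]
H, W = ((h + factor) // factor) * factor, ((w + factor) // factor) * factor
padh = H - h if h % factor != 0 else 0
padw = W - w if w % factor != 0 else 0
x_pad = f.pad(x, (0, padw, 0, padh), 'reflect')
print(x_pad.shape)                          # torch.Size([1, 3, 512, 352])
print(torch.equal(x_pad[:, :, :h, :w], x))  # True: cropping back recovers the input
~~~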
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Yuning Cui
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Motion_Deblurring/README.md:
--------------------------------------------------------------------------------
1 | ### Download the Datasets
2 | - Gopro [[gdrive](https://drive.google.com/file/d/1y_wQ5G5B65HS_mdIjxKYTcnRys_AGh5v/view?usp=sharing), [百度网盘](https://pan.baidu.com/s/1eNCvqewdUp15-0dD2MfJbg?pwd=ea0r)]
3 |
4 | ### Training on GoPro
5 | ~~~
6 | python main.py --data_dir your_path/GOPRO
7 | ~~~
8 | ### Evaluation
9 | Download model: [gdrive](https://drive.google.com/drive/folders/1_5fO2p5xoWO5cUEVoXJ7x3Uhg1AP18FQ?usp=sharing), [百度网盘](https://pan.baidu.com/s/1oYzdxs3FvLJMWx7S5GW0rA?pwd=dvta)
10 | #### Testing on GoPro
11 | ~~~
12 | python main.py --mode test --data_dir your_path/GOPRO --test_model path_to_gopro_model --save_image True
13 | ~~~
14 |
--------------------------------------------------------------------------------
/Motion_Deblurring/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFlip, PairToTensor
2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader
3 |
--------------------------------------------------------------------------------
/Motion_Deblurring/data/data_augment.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torchvision.transforms as transforms
3 | import torchvision.transforms.functional as F
4 |
5 |
6 | class PairRandomCrop(transforms.RandomCrop):
7 |
8 | def __call__(self, image, label):
9 |
10 | if self.padding is not None:
11 | image = F.pad(image, self.padding, self.fill, self.padding_mode)
12 | label = F.pad(label, self.padding, self.fill, self.padding_mode)
13 |
14 | # pad the width if needed
15 | if self.pad_if_needed and image.size[0] < self.size[1]:
16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode)
17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
18 | # pad the height if needed
19 | if self.pad_if_needed and image.size[1] < self.size[0]:
20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode)
21 | label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)
22 |
23 | i, j, h, w = self.get_params(image, self.size)
24 |
25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w)
26 |
27 |
28 | class PairCompose(transforms.Compose):
29 | def __call__(self, image, label):
30 | for t in self.transforms:
31 | image, label = t(image, label)
32 | return image, label
33 |
34 |
35 | class PairRandomHorizontalFlip(transforms.RandomHorizontalFlip):
36 | def __call__(self, img, label):
37 | """
38 | Args:
39 | img (PIL Image): Image to be flipped.
40 |
41 | Returns:
42 | PIL Image: Randomly flipped image.
43 | """
44 | if random.random() < self.p:
45 | return F.hflip(img), F.hflip(label)
46 | return img, label
47 |
48 |
49 | # class PairRandomVerticalFlip(transforms.RandomVerticalFlip):
50 | #     def __call__(self, img, label):
51 | #         """
52 | #         Args:
53 | #             img (PIL Image): Image to be flipped.
54 | #
55 | #         Returns:
56 | #             PIL Image: Randomly flipped image.
57 | #         """
58 | #         if random.random() < self.p:
59 | #             return F.vflip(img), F.vflip(label)
60 | #         return img, label
61 |
62 |
63 | class PairToTensor(transforms.ToTensor):
64 | def __call__(self, pic, label):
65 | """
66 | Args:
67 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
68 |
69 | Returns:
70 | Tensor: Converted image.
71 | """
72 | return F.to_tensor(pic), F.to_tensor(label)
73 |
--------------------------------------------------------------------------------
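A short usage sketch of the paired transforms above: composing them guarantees the blurry input and its sharp label receive the identical random crop and flip (blank PIL images stand in for real data here):

~~~
from PIL import Image
from data import PairCompose, PairRandomCrop, PairRandomHorizontalFlip, PairToTensor

transform = PairCompose([
    PairRandomCrop(256),           # same crop window for both images
    PairRandomHorizontalFlip(),    # same flip decision for both images
    PairToTensor(),
])

image = Image.new('RGB', (640, 480))
label = Image.new('RGB', (640, 480))
img_t, lbl_t = transform(image, label)
print(img_t.shape, lbl_t.shape)    # torch.Size([3, 256, 256]) for both
~~~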
/Motion_Deblurring/data/data_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image as Image
5 | from data import PairCompose, PairRandomCrop, PairRandomHorizontalFlip, PairToTensor
6 | from torchvision.transforms import functional as F
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 |
10 | def train_dataloader(path, batch_size=64, num_workers=0, use_transform=True):
11 | image_dir = os.path.join(path, 'train')
12 |
13 | transform = None
14 | if use_transform:
15 | transform = PairCompose(
16 | [
17 | PairRandomCrop(256),
18 | PairRandomHorizontalFlip(),
19 | PairToTensor()
20 | ]
21 | )
22 | dataloader = DataLoader(
23 | DeblurDataset(image_dir, transform=transform),
24 | batch_size=batch_size,
25 | shuffle=True,
26 | num_workers=num_workers,
27 | pin_memory=True
28 | )
29 | return dataloader
30 |
31 |
32 | def test_dataloader(path, batch_size=1, num_workers=0):
33 | image_dir = os.path.join(path, 'valid')
34 | dataloader = DataLoader(
35 | DeblurDataset(image_dir, is_test=True),
36 | batch_size=batch_size,
37 | shuffle=False,
38 | num_workers=num_workers,
39 | pin_memory=True
40 | )
41 |
42 | return dataloader
43 |
44 |
45 | def valid_dataloader(path, batch_size=1, num_workers=0):
46 | dataloader = DataLoader(
47 | DeblurDataset(os.path.join(path, 'valid')),
48 | batch_size=batch_size,
49 | shuffle=False,
50 | num_workers=num_workers
51 | )
52 |
53 | return dataloader
54 |
55 |
56 | class DeblurDataset(Dataset):
57 | def __init__(self, image_dir, transform=None, is_test=False):
58 | self.image_dir = image_dir
59 | self.image_list = os.listdir(os.path.join(image_dir, 'blur/'))
60 | self._check_image(self.image_list)
61 | self.image_list.sort()
62 | self.transform = transform
63 | self.is_test = is_test
64 |
65 | def __len__(self):
66 | return len(self.image_list)
67 |
68 | def __getitem__(self, idx):
69 | image = Image.open(os.path.join(self.image_dir, 'blur', self.image_list[idx]))
70 | label = Image.open(os.path.join(self.image_dir, 'sharp', self.image_list[idx].replace('blur', 'gt')))
71 |
72 | if self.transform:
73 | image, label = self.transform(image, label)
74 | else:
75 | image = F.to_tensor(image)
76 | label = F.to_tensor(label)
77 | if self.is_test:
78 | name = self.image_list[idx]
79 | return image, label, name
80 | return image, label
81 |
82 | @staticmethod
83 | def _check_image(lst):
84 | for x in lst:
85 | splits = x.split('.')
86 | if splits[-1] not in ['png', 'jpg', 'jpeg']:
87 | raise ValueError
88 |
--------------------------------------------------------------------------------
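Wiring the loaders above together (a sketch; the root is assumed to follow the README's `your_path/GOPRO` layout with `train/blur`, `train/sharp`, `valid/blur`, and `valid/sharp` subfolders):

~~~
from data import train_dataloader, valid_dataloader

train_loader = train_dataloader('your_path/GOPRO', batch_size=4, num_workers=8)
for blur, sharp in train_loader:
    print(blur.shape, sharp.shape)   # torch.Size([4, 3, 256, 256]) each, after PairRandomCrop(256)
    break

valid_loader = valid_dataloader('your_path/GOPRO')   # full-size pairs, batch_size=1
~~~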
/Motion_Deblurring/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torchvision.transforms import functional as F
4 | from utils import Adder
5 | from data import test_dataloader
6 | from skimage.metrics import peak_signal_noise_ratio
7 | import time
8 | import torch.nn.functional as f
9 |
10 | factor = 32
11 |
12 | def _eval(model, args):
13 | state_dict = torch.load(args.test_model)
14 | model.load_state_dict(state_dict['model'])
15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16 | dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=0)
17 | adder = Adder()
18 | model.eval()
19 |
20 | with torch.no_grad():
21 | psnr_adder = Adder()
22 | for iter_idx, data in enumerate(dataloader):
23 | input_img, label_img, name = data
24 |
25 | input_img = input_img.to(device)
26 | h, w = input_img.shape[2], input_img.shape[3]
27 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
28 | padh = H-h if h%factor!=0 else 0
29 | padw = W-w if w%factor!=0 else 0
30 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
31 | tm = time.time()
32 |
33 | pred = model(input_img)[2]
34 | pred = pred[:,:,:h,:w]
35 | elapsed = time.time() - tm
36 | adder(elapsed)
37 |
38 | pred_clip = torch.clamp(pred, 0, 1)
39 | pred_numpy = pred_clip.squeeze(0).cpu().numpy()
40 | label_numpy = label_img.squeeze(0).cpu().numpy()
41 |
42 | if args.save_image:
43 | save_name = os.path.join(args.result_dir, name[0])
44 | pred_clip += 0.5 / 255  # offset by half a quantization step so to_pil_image rounds instead of truncating
45 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB')
46 | pred.save(save_name)
47 |
48 | psnr = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1)
49 | psnr_adder(psnr)
50 | print('%d iter PSNR: %.4f time: %f' % (iter_idx + 1, psnr, elapsed))
51 |
52 | print('==========================================================')
53 | print('The average PSNR is %.4f dB' % (psnr_adder.average()))
54 | print("Average time: %f" % adder.average())
55 |
--------------------------------------------------------------------------------
/Motion_Deblurring/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from torch.backends import cudnn
5 | from models.ConvIR import build_net
6 | from train import _train
7 | from eval import _eval
8 |
9 | def main(args):
10 | cudnn.benchmark = True
11 |
12 | if not os.path.exists('results/'):
13 | os.makedirs('results/')
14 | if not os.path.exists('results/' + args.model_name + '/'):
15 | os.makedirs('results/' + args.model_name + '/')
16 | if not os.path.exists(args.model_save_dir):
17 | os.makedirs(args.model_save_dir)
18 | if not os.path.exists(args.result_dir):
19 | os.makedirs(args.result_dir)
20 |
21 | model = build_net()
22 | # print(model)
23 |
24 | if torch.cuda.is_available():
25 | model.cuda()
26 |
27 | if args.mode == 'train':
28 | _train(model, args)
29 |
30 | elif args.mode == 'test':
31 | _eval(model, args)
32 |
33 |
34 | if __name__ == '__main__':
35 | parser = argparse.ArgumentParser()
36 |
37 | # Directories
38 | parser.add_argument('--model_name', default='ConvIR', type=str)
39 | parser.add_argument('--data_dir', type=str, default='')
40 |
41 | parser.add_argument('--mode', default='train', choices=['train', 'test'], type=str)
42 |
43 | # Train
44 | parser.add_argument('--batch_size', type=int, default=4)
45 | parser.add_argument('--learning_rate', type=float, default=1e-4)
46 | parser.add_argument('--weight_decay', type=float, default=0)
47 | parser.add_argument('--num_epoch', type=int, default=3000)
48 | parser.add_argument('--print_freq', type=int, default=100)
49 | parser.add_argument('--num_worker', type=int, default=8)
50 | parser.add_argument('--save_freq', type=int, default=100)
51 | parser.add_argument('--valid_freq', type=int, default=100)
52 | parser.add_argument('--resume', type=str, default='')
53 |
54 | # Test
55 | parser.add_argument('--test_model', type=str, default='gopro.pkl')
56 | parser.add_argument('--save_image', type=bool, default=False, choices=[True, False])
57 |
58 | args = parser.parse_args()
59 | args.model_save_dir = os.path.join('results/', args.model_name, 'test')
60 | args.result_dir = os.path.join('results/', args.model_name, 'GOPRO')
61 | if not os.path.exists(args.model_save_dir):
62 | os.makedirs(args.model_save_dir)
63 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir
64 | os.system(command)
65 | command = 'cp ' + 'models/ConvIR.py ' + args.model_save_dir
66 | os.system(command)
67 | command = 'cp ' + 'train.py ' + args.model_save_dir
68 | os.system(command)
69 | command = 'cp ' + 'main.py ' + args.model_save_dir
70 | os.system(command)
71 | print(args)
72 | main(args)
73 |
--------------------------------------------------------------------------------
/Motion_Deblurring/models/ConvIR.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .layers import *
6 |
7 |
8 | class EBlock(nn.Module):
9 | def __init__(self, out_channel, num_res=8):
10 | super(EBlock, self).__init__()
11 |
12 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res-1)]
13 | layers.append(ResBlock(out_channel, out_channel, filter=True))
14 |
15 | self.layers = nn.Sequential(*layers)
16 |
17 | def forward(self, x):
18 | return self.layers(x)
19 |
20 |
21 | class DBlock(nn.Module):
22 | def __init__(self, channel, num_res=8):
23 | super(DBlock, self).__init__()
24 |
25 | layers = [ResBlock(channel, channel) for _ in range(num_res-1)]
26 | layers.append(ResBlock(channel, channel, filter=True))
27 | self.layers = nn.Sequential(*layers)
28 |
29 | def forward(self, x):
30 | return self.layers(x)
31 |
32 |
33 | class SCM(nn.Module):
34 | def __init__(self, out_plane):
35 | super(SCM, self).__init__()
36 | self.main = nn.Sequential(
37 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True),
38 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
39 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
40 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False),
41 | nn.InstanceNorm2d(out_plane, affine=True)
42 | )
43 |
44 | def forward(self, x):
45 | x = self.main(x)
46 | return x
47 |
48 | class FAM(nn.Module):
49 | def __init__(self, channel):
50 | super(FAM, self).__init__()
51 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False)
52 |
53 | def forward(self, x1, x2):
54 | return self.merge(torch.cat([x1, x2], dim=1))
55 |
56 | class ConvIR(nn.Module):
57 | def __init__(self, num_res=16):
58 | super(ConvIR, self).__init__()
59 |
60 | base_channel = 32
61 |
62 | self.Encoder = nn.ModuleList([
63 | EBlock(base_channel, num_res),
64 | EBlock(base_channel*2, num_res),
65 | EBlock(base_channel*4, num_res),
66 | ])
67 |
68 | self.feat_extract = nn.ModuleList([
69 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
70 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
71 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
72 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
73 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
74 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
75 | ])
76 |
77 | self.Decoder = nn.ModuleList([
78 | DBlock(base_channel * 4, num_res),
79 | DBlock(base_channel * 2, num_res),
80 | DBlock(base_channel, num_res)
81 | ])
82 |
83 | self.Convs = nn.ModuleList([
84 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
85 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
86 | ])
87 |
88 | self.ConvsOut = nn.ModuleList(
89 | [
90 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
91 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
92 | ]
93 | )
94 |
95 | self.FAM1 = FAM(base_channel * 4)
96 | self.SCM1 = SCM(base_channel * 4)
97 | self.FAM2 = FAM(base_channel * 2)
98 | self.SCM2 = SCM(base_channel * 2)
99 |
100 | def forward(self, x):
101 | x_2 = F.interpolate(x, scale_factor=0.5)
102 | x_4 = F.interpolate(x_2, scale_factor=0.5)
103 | z2 = self.SCM2(x_2)
104 | z4 = self.SCM1(x_4)
105 |
106 | outputs = list()
107 | # 256
108 | x_ = self.feat_extract[0](x)
109 | res1 = self.Encoder[0](x_)
110 | # 128
111 | z = self.feat_extract[1](res1)
112 | z = self.FAM2(z, z2)
113 | res2 = self.Encoder[1](z)
114 | # 64
115 | z = self.feat_extract[2](res2)
116 | z = self.FAM1(z, z4)
117 | z = self.Encoder[2](z)
118 |
119 | z = self.Decoder[0](z)
120 | z_ = self.ConvsOut[0](z)
121 | # 128
122 | z = self.feat_extract[3](z)
123 | outputs.append(z_+x_4)
124 |
125 | z = torch.cat([z, res2], dim=1)
126 | z = self.Convs[0](z)
127 | z = self.Decoder[1](z)
128 | z_ = self.ConvsOut[1](z)
129 | # 256
130 | z = self.feat_extract[4](z)
131 | outputs.append(z_+x_2)
132 |
133 | z = torch.cat([z, res1], dim=1)
134 | z = self.Convs[1](z)
135 | z = self.Decoder[2](z)
136 | z = self.feat_extract[5](z)
137 | outputs.append(z+x)
138 |
139 | return outputs
140 |
141 |
142 | def build_net():
143 | return ConvIR()
144 |
--------------------------------------------------------------------------------
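The forward pass above returns three residual predictions, coarse to fine, each added to the input at the matching scale. A quick shape check (a sketch; the input resolution is assumed divisible by 32, as guaranteed by the padding in eval.py and valid.py):

~~~
import torch
from models.ConvIR import build_net

model = build_net().eval()
with torch.no_grad():
    outs = model(torch.rand(1, 3, 256, 256))
for o in outs:
    print(o.shape)
# torch.Size([1, 3, 64, 64])    quarter-scale prediction
# torch.Size([1, 3, 128, 128])  half-scale prediction
# torch.Size([1, 3, 256, 256])  full-scale prediction used at test time
~~~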
/Motion_Deblurring/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class BasicConv(nn.Module):
6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False):
7 | super(BasicConv, self).__init__()
8 | if bias and norm:
9 | bias = False
10 |
11 | padding = kernel_size // 2
12 | layers = list()
13 | if transpose:
14 | padding = kernel_size // 2 - 1
15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
16 | else:
17 | layers.append(
18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
19 | if norm:
20 | layers.append(nn.BatchNorm2d(out_channel))
21 | if relu:
22 | layers.append(nn.GELU())
23 | self.main = nn.Sequential(*layers)
24 |
25 | def forward(self, x):
26 | return self.main(x)
27 |
28 |
29 | class ResBlock(nn.Module):
30 | def __init__(self, in_channel, out_channel, filter=False):
31 | super(ResBlock, self).__init__()
32 | self.main = nn.Sequential(
33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
34 | DeepPoolLayer(in_channel, out_channel) if filter else nn.Identity(),
35 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
36 | )
37 |
38 | def forward(self, x):
39 | return self.main(x) + x
40 |
41 |
42 | class DeepPoolLayer(nn.Module):
43 | def __init__(self, k, k_out):
44 | super(DeepPoolLayer, self).__init__()
45 | self.pools_sizes = [8,4,2]
46 | dilation = [7,9,11]
47 | pools, convs, dynas = [],[],[]
48 | for j, i in enumerate(self.pools_sizes):
49 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i))
50 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False))
51 | dynas.append(MultiShapeKernel(dim=k, kernel_size=3, dilation=dilation[j]))
52 | self.pools = nn.ModuleList(pools)
53 | self.convs = nn.ModuleList(convs)
54 | self.dynas = nn.ModuleList(dynas)
55 | self.relu = nn.GELU()
56 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False)
57 |
58 | def forward(self, x):
59 | x_size = x.size()
60 | resl = x
61 | for i in range(len(self.pools_sizes)):
62 | if i == 0:
63 | y = self.dynas[i](self.convs[i](self.pools[i](x)))
64 | else:
65 | y = self.dynas[i](self.convs[i](self.pools[i](x)+y_up))
66 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True))
67 | if i != len(self.pools_sizes)-1:
68 | y_up = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True)
69 | resl = self.relu(resl)
70 | resl = self.conv_sum(resl)
71 |
72 | return resl
73 |
74 | class dynamic_filter(nn.Module):
75 | def __init__(self, inchannels, kernel_size=3, dilation=1, stride=1, group=8):
76 | super(dynamic_filter, self).__init__()
77 | self.stride = stride
78 | self.kernel_size = kernel_size
79 | self.group = group
80 | self.dilation = dilation
81 |
82 | self.conv = nn.Conv2d(inchannels, group*kernel_size**2, kernel_size=1, stride=1, bias=False)
83 | self.bn = nn.BatchNorm2d(group*kernel_size**2)
84 | self.act = nn.Tanh()
85 |
86 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')
87 | self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True)
88 | self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True)
89 | self.pad = nn.ReflectionPad2d(self.dilation*(kernel_size-1)//2)
90 |
91 | self.ap = nn.AdaptiveAvgPool2d((1, 1))
92 | self.gap = nn.AdaptiveAvgPool2d(1)
93 |
94 | self.inside_all = nn.Parameter(torch.zeros(inchannels,1,1), requires_grad=True)
95 |
96 | def forward(self, x):
97 | identity_input = x
98 | low_filter = self.ap(x)
99 | low_filter = self.conv(low_filter)
100 | low_filter = self.bn(low_filter)
101 |
102 | n, c, h, w = x.shape
103 | x = F.unfold(self.pad(x), kernel_size=self.kernel_size, dilation=self.dilation).reshape(n, self.group, c//self.group, self.kernel_size**2, h*w)
104 |
105 | n,c1,p,q = low_filter.shape
106 | low_filter = low_filter.reshape(n, c1//self.kernel_size**2, self.kernel_size**2, p*q).unsqueeze(2)
107 |
108 | low_filter = self.act(low_filter)
109 |
110 | low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w)
111 |
112 | out_low = low_part * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input)
113 |
114 | out_low = out_low * self.lamb_l[None,:,None,None]
115 |
116 | out_high = (identity_input) * (self.lamb_h[None,:,None,None] + 1.)
117 |
118 | return out_low + out_high
119 |
120 |
121 | class cubic_attention(nn.Module):
122 | def __init__(self, dim, group, dilation, kernel) -> None:
123 | super().__init__()
124 |
125 | self.H_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel)
126 | self.W_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel, H=False)
127 | self.gamma = nn.Parameter(torch.zeros(dim,1,1))
128 | self.beta = nn.Parameter(torch.ones(dim,1,1))
129 |
130 | def forward(self, x):
131 | out = self.H_spatial_att(x)
132 | out = self.W_spatial_att(out)
133 | return self.gamma * out + x * self.beta
134 |
135 |
136 | class spatial_strip_att(nn.Module):
137 | def __init__(self, dim, kernel=3, dilation=1, group=2, H=True) -> None:
138 | super().__init__()
139 |
140 | self.k = kernel
141 | pad = dilation*(kernel-1) // 2
142 | self.kernel = (1, kernel) if H else (kernel, 1)
143 | self.padding = (kernel//2, 1) if H else (1, kernel//2)
144 | self.dilation = dilation
145 | self.group = group
146 | self.pad = nn.ReflectionPad2d((pad, pad, 0, 0)) if H else nn.ReflectionPad2d((0, 0, pad, pad))
147 | self.conv = nn.Conv2d(dim, group*kernel, kernel_size=1, stride=1, bias=False)
148 | self.ap = nn.AdaptiveAvgPool2d((1, 1))
149 | self.filter_act = nn.Tanh()
150 | self.inside_all = nn.Parameter(torch.zeros(dim,1,1), requires_grad=True)
151 | self.lamb_l = nn.Parameter(torch.zeros(dim), requires_grad=True)
152 | self.lamb_h = nn.Parameter(torch.zeros(dim), requires_grad=True)
153 | gap_kernel = (None,1) if H else (1, None)
154 | self.gap = nn.AdaptiveAvgPool2d(gap_kernel)
155 |
156 | def forward(self, x):
157 | identity_input = x.clone()
158 | filter = self.ap(x)
159 | filter = self.conv(filter)
160 | n, c, h, w = x.shape
161 | x = F.unfold(self.pad(x), kernel_size=self.kernel, dilation=self.dilation).reshape(n, self.group, c//self.group, self.k, h*w)
162 | n, c1, p, q = filter.shape
163 | filter = filter.reshape(n, c1//self.k, self.k, p*q).unsqueeze(2)
164 | filter = self.filter_act(filter)
165 | out = torch.sum(x * filter, dim=3).reshape(n, c, h, w)
166 |
167 | out_low = out * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input)
168 | out_low = out_low * self.lamb_l[None,:,None,None]
169 | out_high = identity_input * (self.lamb_h[None,:,None,None]+1.)
170 |
171 | return out_low + out_high
172 |
173 |
174 | class MultiShapeKernel(nn.Module):
175 | def __init__(self, dim, kernel_size=3, dilation=1, group=8):
176 | super().__init__()
177 |
178 | self.square_att = dynamic_filter(inchannels=dim, dilation=dilation, group=group, kernel_size=kernel_size)
179 | self.strip_att = cubic_attention(dim, group=group, dilation=dilation, kernel=kernel_size)
180 |
181 | def forward(self, x):
182 |
183 | x1 = self.strip_att(x)
184 | x2 = self.square_att(x)
185 |
186 | return x1+x2
187 |
188 |
189 |
--------------------------------------------------------------------------------
/Motion_Deblurring/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from data import train_dataloader
4 | from utils import Adder, Timer, check_lr
5 | from torch.utils.tensorboard import SummaryWriter
6 | from valid import _valid
7 | import torch.nn.functional as F
8 |
9 | from warmup_scheduler import GradualWarmupScheduler
10 |
11 | def _train(model, args):
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 | criterion = torch.nn.L1Loss()
14 |
15 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8)
16 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker)
17 | max_iter = len(dataloader)
18 | warmup_epochs=3
19 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6)
20 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
21 | scheduler.step()
22 | epoch = 1
23 | if args.resume:
24 | state = torch.load(args.resume)
25 | epoch = state['epoch']
26 | optimizer.load_state_dict(state['optimizer'])
27 | model.load_state_dict(state['model'])
28 | print('Resume from %d'%epoch)
29 | epoch += 1
30 |
31 | writer = SummaryWriter()
32 | epoch_pixel_adder = Adder()
33 | epoch_fft_adder = Adder()
34 | iter_pixel_adder = Adder()
35 | iter_fft_adder = Adder()
36 | epoch_timer = Timer('m')
37 | iter_timer = Timer('m')
38 | best_psnr=-1
39 |
40 | for epoch_idx in range(epoch, args.num_epoch + 1):
41 |
42 | epoch_timer.tic()
43 | iter_timer.tic()
44 | for iter_idx, batch_data in enumerate(dataloader):
45 |
46 | input_img, label_img = batch_data
47 | input_img = input_img.to(device)
48 | label_img = label_img.to(device)
49 |
50 | optimizer.zero_grad()
51 | pred_img = model(input_img)
52 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear')
53 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear')
54 | l1 = criterion(pred_img[0], label_img4)
55 | l2 = criterion(pred_img[1], label_img2)
56 | l3 = criterion(pred_img[2], label_img)
57 | loss_content = l1+l2+l3
58 |
59 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1))
60 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1)
61 |
62 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1))
63 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1)
64 |
65 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1))
66 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1)
67 |
68 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1))
69 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1)
70 |
71 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1))
72 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1)
73 |
74 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1))
75 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1)
76 |
77 | f1 = criterion(pred_fft1, label_fft1)
78 | f2 = criterion(pred_fft2, label_fft2)
79 | f3 = criterion(pred_fft3, label_fft3)
80 | loss_fft = f1+f2+f3
81 |
82 | loss = loss_content + 0.1 * loss_fft
83 | loss.backward()
84 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001)
85 | optimizer.step()
86 |
87 | iter_pixel_adder(loss_content.item())
88 | iter_fft_adder(loss_fft.item())
89 |
90 | epoch_pixel_adder(loss_content.item())
91 | epoch_fft_adder(loss_fft.item())
92 |
93 | if (iter_idx + 1) % args.print_freq == 0:
94 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % (
95 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(),
96 | iter_fft_adder.average()))
97 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter)
98 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter)
99 | iter_timer.tic()
100 | iter_pixel_adder.reset()
101 | iter_fft_adder.reset()
102 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl')
103 | torch.save({'model': model.state_dict(),
104 | 'optimizer': optimizer.state_dict(),
105 | 'epoch': epoch_idx}, overwrite_name)
106 |
107 | if epoch_idx % args.save_freq == 0:
108 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx)
109 | torch.save({'model': model.state_dict()}, save_name)
110 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % (
111 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average()))
112 | epoch_fft_adder.reset()
113 | epoch_pixel_adder.reset()
114 | scheduler.step()
115 |
116 | if epoch_idx % args.valid_freq == 0:
117 | val_gopro = _valid(model, args, epoch_idx)
118 | print('%03d epoch \n Average GOPRO PSNR %.2f dB' % (epoch_idx, val_gopro))
119 | writer.add_scalar('PSNR_GOPRO', val_gopro, epoch_idx)
120 | if val_gopro >= best_psnr:
121 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl'))
122 |
123 | save_name = os.path.join(args.model_save_dir, 'Final.pkl')
124 | torch.save({'model': model.state_dict()}, save_name)
125 |
--------------------------------------------------------------------------------
/Motion_Deblurring/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 |
4 |
5 | class Adder(object):
6 | def __init__(self):
7 | self.count = 0
8 | self.num = float(0)
9 |
10 | def reset(self):
11 | self.count = 0
12 | self.num = float(0)
13 |
14 | def __call__(self, num):
15 | self.count += 1
16 | self.num += num
17 |
18 | def average(self):
19 | return self.num / self.count
20 |
21 |
22 | class Timer(object):
23 | def __init__(self, option='s'):
24 | self.tm = 0
25 | self.option = option
26 | if option == 's':
27 | self.divider = 1
28 | elif option == 'm':
29 | self.divider = 60
30 | else:
31 | self.divider = 3600
32 |
33 | def tic(self):
34 | self.tm = time.time()
35 |
36 | def toc(self):
37 | return (time.time() - self.tm) / self.divider
38 |
39 |
40 | def check_lr(optimizer):
41 | for i, param_group in enumerate(optimizer.param_groups):
42 | lr = param_group['lr']
43 | return lr
44 |
--------------------------------------------------------------------------------
/Motion_Deblurring/valid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import functional as F
3 | from data import valid_dataloader
4 | from utils import Adder
5 | import os
6 | from skimage.metrics import peak_signal_noise_ratio
7 | import torch.nn.functional as f
8 |
9 |
10 | def _valid(model, args, ep):
11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 | gopro = valid_dataloader(args.data_dir, batch_size=1, num_workers=0)
13 | model.eval()
14 | psnr_adder = Adder()
15 |
16 | with torch.no_grad():
17 | print('Start GoPro Evaluation')
18 | factor = 32
19 | for idx, data in enumerate(gopro):
20 | input_img, label_img = data
21 | input_img = input_img.to(device)
22 |
23 | h, w = input_img.shape[2], input_img.shape[3]
24 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
25 | padh = H-h if h%factor!=0 else 0
26 | padw = W-w if w%factor!=0 else 0
27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
28 |
29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))):
30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep)))
31 |
32 | pred = model(input_img)[2]
33 | pred = pred[:,:,:h,:w]
34 |
35 | pred_clip = torch.clamp(pred, 0, 1)
36 | p_numpy = pred_clip.squeeze(0).cpu().numpy()
37 | label_numpy = label_img.squeeze(0).cpu().numpy()
38 |
39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)
40 |
41 | psnr_adder(psnr)
42 | print('\r%03d'%idx, end=' ')
43 |
44 | print('\n')
45 | model.train()
46 | return psnr_adder.average()
47 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [](https://paperswithcode.com/sota/image-dehazing-on-sots-indoor)
2 | [](https://paperswithcode.com/sota/image-dehazing-on-sots-outdoor)
3 | [](https://paperswithcode.com/sota/image-dehazing-on-haze4k)
4 | [](https://paperswithcode.com/sota/image-dehazing-on-i-haze)
5 | [](https://paperswithcode.com/sota/image-dehazing-on-o-haze)
6 | [](https://paperswithcode.com/sota/snow-removal-on-snow100k)
7 | [](https://paperswithcode.com/sota/snow-removal-on-srrs)
8 |
9 |
10 | ## Revitalizing Convolutional Network for Image Restoration
11 |
12 | The official pytorch implementation of the paper **[Revitalizing Convolutional Network for Image Restoration](https://ieeexplore.ieee.org/abstract/document/10571568)**
13 |
14 | #### Yuning Cui, Wenqi Ren, Xiaochun Cao, Alois Knoll
15 |
16 | ## News
17 | All result images and pre-trained models are available via the links provided.
18 |
19 | **11/26/2024** Code for real haze and Haze4K is released.
20 |
21 | **07/22/2024** We release the code for dehazing (ITS/OTS), desnowing, deraining, and motion deblurring.
22 |
23 | ## Pretrained models
24 | [gdrive](https://drive.google.com/drive/folders/1_5fO2p5xoWO5cUEVoXJ7x3Uhg1AP18FQ?usp=sharing), [百度网盘](https://pan.baidu.com/s/1oYzdxs3FvLJMWx7S5GW0rA?pwd=dvta)
25 |
26 |
27 | ## Installation
28 | The project is built with Python 3.8, PyTorch 1.8.1, CUDA 10.2, and cuDNN 7.6.5.
29 | To install, follow these instructions:
30 | ~~~
31 | conda install pytorch=1.8.1 torchvision=0.9.1 -c pytorch
32 | pip install tensorboard einops scikit-image pytorch_msssim opencv-python
33 | ~~~
34 |
35 | *Please use the Pillow package installed by conda rather than pip.*
36 |
37 |
38 |
39 | Install warmup scheduler:
40 | ~~~
41 | cd pytorch-gradual-warmup-lr/
42 | python setup.py install
43 | cd ..
44 | ~~~
45 | ## Training and Evaluation
46 | Please refer to respective directories.
47 | ## Results
48 | ### Visualization Results: [gdrive](https://drive.google.com/drive/folders/1YiuiYG36zqgHsoUhbk6UJAAywGc0avnj?usp=sharing), [百度网盘](https://pan.baidu.com/s/1mDlRfEoMSi8vpCLRUxk2tQ?pwd=y2gv)
49 | |Model|Parameters|FLOPs|
50 | |------|-----|-----|
51 | |*ConvIR-S (small)*|5.53M|42.1G|
52 | |**ConvIR-B (base)**| 8.63M|71.22G|
53 | |ConvIR-L (large)| 14.83M |129.34G|
54 |
55 | |Task|Dataset|PSNR|SSIM|
56 | |----|------|-----|----|
57 | |**Image Dehazing**|SOTS-Indoor|*41.53*/**42.72**|*0.996*/**0.997**|
58 | ||SOTS-Outdoor|*37.95*/**39.42**|*0.994*/**0.996**|
59 | ||Haze4K|*33.36*/**34.15**/34.50|*0.99*/**0.99**/0.99|
60 | ||Dense-Haze|*17.45*/**16.86**|*0.648*/**0.621**|
61 | ||NH-HAZE|*20.65*/**20.66**|*0.807*/**0.802**|
62 | ||O-HAZE|*25.25*/**25.36**|*0.784*/**0.780**|
63 | ||I-HAZE|*21.95*/**22.44**|*0.888*/**0.887**|
64 | ||SateHaze-1k-Thin/Moderate/Thick|*25.11*/*26.79*/*22.65*|*0.978*/*0.978*/*0.950*|
65 | ||NHR|*28.85*/**29.49**|*0.981*/**0.983**|
66 | ||GTA5|*31.68*/**31.83**|*0.917*/**0.921**|
67 | |**Image Desnowing**|CSD|*38.43*/**39.10**|*0.99*/**0.99**|
68 | ||SRRS|*32.25*/**32.39**|*0.98*/**0.98**|
69 | ||Snow100K|*33.79*/**33.92**|*0.95*/**0.96**|
70 | |**Image Deraining**|Test100|31.40|0.919|
71 | ||Test2800|33.73|0.937|
72 | |**Defocus Deblurring**|DPDD|*26.06*/**26.16**/26.36|*0.810*/**0.814**/0.820|
73 | |**Motion Deblurring**|GoPro|33.28|0.963|
74 | ||RSBlur|34.06|0.868|
75 |
76 |
77 | ## Citation
78 | ~~~
79 | @article{cui2024revitalizing,
80 | title={Revitalizing Convolutional Network for Image Restoration},
81 | author={Cui, Yuning and Ren, Wenqi and Cao, Xiaochun and Knoll, Alois},
82 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
83 | year={2024},
84 | publisher={IEEE}
85 | }
86 |
87 | @inproceedings{cui2023irnext,
88 | title={IRNeXt: Rethinking Convolutional Network Design for Image Restoration},
89 | author={Cui, Yuning and Ren, Wenqi and Yang, Sining and Cao, Xiaochun and Knoll, Alois},
90 | booktitle={International Conference on Machine Learning},
91 | pages={6545--6564},
92 | year={2023},
93 | organization={PMLR}
94 | }
95 | ~~~
96 |
97 | ## Contact
98 | Should you have any problem, please contact Yuning Cui.
99 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import setuptools
6 |
7 | _VERSION = '0.3'
8 |
9 | REQUIRED_PACKAGES = [
10 | ]
11 |
12 | DEPENDENCY_LINKS = [
13 | ]
14 |
15 | setuptools.setup(
16 | name='warmup_scheduler',
17 | version=_VERSION,
18 | description='Gradually Warm-up LR Scheduler for Pytorch',
19 | install_requires=REQUIRED_PACKAGES,
20 | dependency_links=DEPENDENCY_LINKS,
21 | url='https://github.com/ildoonet/pytorch-gradual-warmup-lr',
22 | license='MIT License',
23 | package_dir={},
24 | packages=setuptools.find_packages(exclude=['tests']),
25 | )
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from warmup_scheduler.scheduler import GradualWarmupScheduler
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR
3 | from torch.optim.sgd import SGD
4 |
5 | from warmup_scheduler import GradualWarmupScheduler
6 |
7 |
8 | if __name__ == '__main__':
9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 | optim = SGD(model, 0.1)
11 |
12 | # scheduler_warmup is chained with scheduler_steplr
13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
15 |
16 | # this zero gradient update is needed to avoid a warning message, issue #8.
17 | optim.zero_grad()
18 | optim.step()
19 |
20 | for epoch in range(1, 20):
21 | scheduler_warmup.step(epoch)
22 | print(epoch, optim.param_groups[0]['lr'])
23 |
24 | optim.step() # backward pass (update network)
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 | """ Gradually warm-up(increasing) learning rate in optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 | Args:
10 | optimizer (Optimizer): Wrapped optimizer.
11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0; if multiplier = 1.0, the lr starts from 0 and ramps up to base_lr.
12 | total_epoch: the target learning rate is reached gradually at total_epoch.
13 | after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau).
14 | """
15 |
16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 | self.multiplier = multiplier
18 | if self.multiplier < 1.:
19 | raise ValueError('multiplier should be greater than or equal to 1.')
20 | self.total_epoch = total_epoch
21 | self.after_scheduler = after_scheduler
22 | self.finished = False
23 | super(GradualWarmupScheduler, self).__init__(optimizer)
24 |
25 | def get_lr(self):
26 | if self.last_epoch > self.total_epoch:
27 | if self.after_scheduler:
28 | if not self.finished:
29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 | self.finished = True
31 | return self.after_scheduler.get_lr()
32 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 |
34 | if self.multiplier == 1.0:
35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 | else:
37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 |
39 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 | if epoch is None:
41 | epoch = self.last_epoch + 1
42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43 | if self.last_epoch <= self.total_epoch:
44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46 | param_group['lr'] = lr
47 | else:
48 | if epoch is None:
49 | self.after_scheduler.step(metrics, None)
50 | else:
51 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
52 |
53 | def step(self, epoch=None, metrics=None):
54 | if type(self.after_scheduler) != ReduceLROnPlateau:
55 | if self.finished and self.after_scheduler:
56 | if epoch is None:
57 | self.after_scheduler.step(None)
58 | else:
59 | self.after_scheduler.step(epoch - self.total_epoch)
60 | else:
61 | return super(GradualWarmupScheduler, self).step(epoch)
62 | else:
63 | self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------
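A minimal demo mirroring how the training scripts drive this scheduler: with multiplier=1 the learning rate ramps linearly from 0 to the base LR over total_epoch steps, then hands off to the wrapped cosine schedule (the optimizer and epoch counts below are arbitrary):

~~~
import torch
from warmup_scheduler import GradualWarmupScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=7, eta_min=1e-6)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=3, after_scheduler=cosine)

for epoch in range(1, 11):
    scheduler.step()   # warmup for epochs 1-3, cosine decay afterwards
    print(epoch, optimizer.param_groups[0]['lr'])
~~~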