├── DeepRFT_MIMO.py
├── README.md
├── data_RGB.py
├── dataset_RGB.py
├── doconv_pytorch.py
├── evaluate_GOPRO.m
├── evaluate_RealBlur.py
├── get_parameter_number.py
├── images
│   ├── framework.png
│   └── psnr_params_flops.png
├── layers.py
├── license.md
├── losses.py
├── pytorch-gradual-warmup-lr
│   ├── setup.py
│   └── warmup_scheduler
│       ├── __init__.py
│       ├── run.py
│       └── scheduler.py
├── test.py
├── test_speed.py
├── train.py
├── train_wo_warmup.py
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   ├── dataset_utils.cpython-38.pyc
    │   ├── dir_utils.cpython-38.pyc
    │   ├── image_utils.cpython-38.pyc
    │   └── model_utils.cpython-38.pyc
    ├── dataset_utils.py
    ├── dir_utils.py
    ├── image_utils.py
    └── model_utils.py
/DeepRFT_MIMO.py:
--------------------------------------------------------------------------------
1 | from layers import *
2 |
3 |
4 |
5 | class EBlock(nn.Module):
6 | def __init__(self, out_channel, num_res=8, ResBlock=ResBlock):
7 | super(EBlock, self).__init__()
8 |
9 | layers = [ResBlock(out_channel) for _ in range(num_res)]
10 |
11 | self.layers = nn.Sequential(*layers)
12 |
13 | def forward(self, x):
14 | return self.layers(x)
15 |
16 | class DBlock(nn.Module):
17 | def __init__(self, channel, num_res=8, ResBlock=ResBlock):
18 | super(DBlock, self).__init__()
19 |
20 | layers = [ResBlock(channel) for _ in range(num_res)]
21 | self.layers = nn.Sequential(*layers)
22 |
23 | def forward(self, x):
24 | return self.layers(x)
25 |
26 | class AFF(nn.Module):
27 | def __init__(self, in_channel, out_channel, BasicConv=BasicConv):
28 | super(AFF, self).__init__()
29 | self.conv = nn.Sequential(
30 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True),
31 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
32 | )
33 |
34 | def forward(self, x1, x2, x4):
35 | x = torch.cat([x1, x2, x4], dim=1)
36 | return self.conv(x)
37 |
38 | class SCM(nn.Module):
39 | def __init__(self, out_plane, BasicConv=BasicConv, inchannel=3):
40 | super(SCM, self).__init__()
41 | self.main = nn.Sequential(
42 | BasicConv(inchannel, out_plane//4, kernel_size=3, stride=1, relu=True),
43 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
44 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
45 | BasicConv(out_plane // 2, out_plane-inchannel, kernel_size=1, stride=1, relu=True)
46 | )
47 |
48 | self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False)
49 |
50 | def forward(self, x):
51 | x = torch.cat([x, self.main(x)], dim=1)
52 | return self.conv(x)
53 |
54 | class FAM(nn.Module):
55 | def __init__(self, channel, BasicConv=BasicConv):
56 | super(FAM, self).__init__()
57 | self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False)
58 |
59 | def forward(self, x1, x2):
60 | x = x1 * x2
61 | out = x1 + self.merge(x)
62 | return out
63 |
64 | class DeepRFT_Small(nn.Module):
65 | def __init__(self, num_res=4, inference=False):
66 | super(DeepRFT_Small, self).__init__()
67 | self.inference = inference
68 |
69 | if not inference:
70 | BasicConv = BasicConv_do
71 | ResBlock = ResBlock_do_fft_bench
72 | else:
73 | BasicConv = BasicConv_do_eval
74 | ResBlock = ResBlock_do_fft_bench_eval
75 |
76 | base_channel = 32
77 |
78 | self.Encoder = nn.ModuleList([
79 | EBlock(base_channel, num_res, ResBlock=ResBlock),
80 | EBlock(base_channel*2, num_res, ResBlock=ResBlock),
81 | EBlock(base_channel*4, num_res, ResBlock=ResBlock),
82 | ])
83 |
84 | self.feat_extract = nn.ModuleList([
85 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
86 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
87 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
88 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
89 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
90 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
91 | ])
92 |
93 | self.Decoder = nn.ModuleList([
94 | DBlock(base_channel * 4, num_res, ResBlock=ResBlock),
95 | DBlock(base_channel * 2, num_res, ResBlock=ResBlock),
96 | DBlock(base_channel, num_res, ResBlock=ResBlock)
97 | ])
98 |
99 | self.Convs = nn.ModuleList([
100 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
101 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
102 | ])
103 |
104 | self.ConvsOut = nn.ModuleList(
105 | [
106 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
107 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
108 | ]
109 | )
110 |
111 | self.AFFs = nn.ModuleList([
112 | AFF(base_channel * 7, base_channel*1, BasicConv=BasicConv),
113 | AFF(base_channel * 7, base_channel*2, BasicConv=BasicConv)
114 | ])
115 |
116 | self.FAM1 = FAM(base_channel * 4, BasicConv=BasicConv)
117 | self.SCM1 = SCM(base_channel * 4, BasicConv=BasicConv)
118 | self.FAM2 = FAM(base_channel * 2, BasicConv=BasicConv)
119 | self.SCM2 = SCM(base_channel * 2, BasicConv=BasicConv)
120 |
121 | def forward(self, x):
122 | x_2 = F.interpolate(x, scale_factor=0.5)
123 | x_4 = F.interpolate(x_2, scale_factor=0.5)
124 | z2 = self.SCM2(x_2)
125 | z4 = self.SCM1(x_4)
126 |
127 | outputs = list()
128 |
129 | x_ = self.feat_extract[0](x)
130 | res1 = self.Encoder[0](x_)
131 |
132 | z = self.feat_extract[1](res1)
133 | z = self.FAM2(z, z2)
134 | res2 = self.Encoder[1](z)
135 |
136 | z = self.feat_extract[2](res2)
137 | z = self.FAM1(z, z4)
138 | z = self.Encoder[2](z)
139 |
140 | z12 = F.interpolate(res1, scale_factor=0.5)
141 | z21 = F.interpolate(res2, scale_factor=2)
142 | z42 = F.interpolate(z, scale_factor=2)
143 | z41 = F.interpolate(z42, scale_factor=2)
144 |
145 | res2 = self.AFFs[1](z12, res2, z42)
146 | res1 = self.AFFs[0](res1, z21, z41)
147 |
148 | z = self.Decoder[0](z)
149 | z_ = self.ConvsOut[0](z)
150 | z = self.feat_extract[3](z)
151 | if not self.inference:
152 | outputs.append(z_+x_4)
153 |
154 | z = torch.cat([z, res2], dim=1)
155 | z = self.Convs[0](z)
156 | z = self.Decoder[1](z)
157 | z_ = self.ConvsOut[1](z)
158 | z = self.feat_extract[4](z)
159 | if not self.inference:
160 | outputs.append(z_+x_2)
161 |
162 | z = torch.cat([z, res1], dim=1)
163 | z = self.Convs[1](z)
164 | z = self.Decoder[2](z)
165 | z = self.feat_extract[5](z)
166 | if not self.inference:
167 | outputs.append(z + x)
168 | return outputs[::-1]
169 | else:
170 | return z + x
171 | class DeepRFT_flops(nn.Module):
172 | def __init__(self, num_res=8, inference=True):
173 | super(DeepRFT_flops, self).__init__()
174 | self.inference = inference
175 | ResBlock = ResBlock_fft_bench
176 | base_channel = 32
177 |
178 | self.Encoder = nn.ModuleList([
179 | EBlock(base_channel, num_res, ResBlock=ResBlock),
180 | EBlock(base_channel*2, num_res, ResBlock=ResBlock),
181 | EBlock(base_channel*4, num_res, ResBlock=ResBlock),
182 | ])
183 |
184 | self.feat_extract = nn.ModuleList([
185 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
186 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
187 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
188 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
189 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
190 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
191 | ])
192 |
193 | self.Decoder = nn.ModuleList([
194 | DBlock(base_channel * 4, num_res, ResBlock=ResBlock),
195 | DBlock(base_channel * 2, num_res, ResBlock=ResBlock),
196 | DBlock(base_channel, num_res, ResBlock=ResBlock)
197 | ])
198 |
199 | self.Convs = nn.ModuleList([
200 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
201 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
202 | ])
203 |
204 | self.ConvsOut = nn.ModuleList(
205 | [
206 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
207 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
208 | ]
209 | )
210 |
211 | self.AFFs = nn.ModuleList([
212 | AFF(base_channel * 7, base_channel*1, BasicConv=BasicConv),
213 | AFF(base_channel * 7, base_channel*2, BasicConv=BasicConv)
214 | ])
215 |
216 | self.FAM1 = FAM(base_channel * 4, BasicConv=BasicConv)
217 | self.SCM1 = SCM(base_channel * 4, BasicConv=BasicConv)
218 | self.FAM2 = FAM(base_channel * 2, BasicConv=BasicConv)
219 | self.SCM2 = SCM(base_channel * 2, BasicConv=BasicConv)
220 |
221 | def forward(self, x):
222 | x_2 = F.interpolate(x, scale_factor=0.5)
223 | x_4 = F.interpolate(x_2, scale_factor=0.5)
224 | z2 = self.SCM2(x_2)
225 | z4 = self.SCM1(x_4)
226 |
227 | outputs = list()
228 |
229 | x_ = self.feat_extract[0](x)
230 | res1 = self.Encoder[0](x_)
231 |
232 | z = self.feat_extract[1](res1)
233 | z = self.FAM2(z, z2)
234 | res2 = self.Encoder[1](z)
235 |
236 | z = self.feat_extract[2](res2)
237 | z = self.FAM1(z, z4)
238 | z = self.Encoder[2](z)
239 |
240 | z12 = F.interpolate(res1, scale_factor=0.5)
241 | z21 = F.interpolate(res2, scale_factor=2)
242 | z42 = F.interpolate(z, scale_factor=2)
243 | z41 = F.interpolate(z42, scale_factor=2)
244 |
245 | res2 = self.AFFs[1](z12, res2, z42)
246 | res1 = self.AFFs[0](res1, z21, z41)
247 |
248 | z = self.Decoder[0](z)
249 | z_ = self.ConvsOut[0](z)
250 | z = self.feat_extract[3](z)
251 | if not self.inference:
252 | outputs.append(z_+x_4)
253 |
254 | z = torch.cat([z, res2], dim=1)
255 | z = self.Convs[0](z)
256 | z = self.Decoder[1](z)
257 | z_ = self.ConvsOut[1](z)
258 | z = self.feat_extract[4](z)
259 | if not self.inference:
260 | outputs.append(z_+x_2)
261 |
262 | z = torch.cat([z, res1], dim=1)
263 | z = self.Convs[1](z)
264 | z = self.Decoder[2](z)
265 | z = self.feat_extract[5](z)
266 | if not self.inference:
267 | outputs.append(z+x)
268 | # print(outputs)
269 | return outputs[::-1]
270 | else:
271 | return z+x
272 | class DeepRFT(nn.Module):
273 | def __init__(self, num_res=8, inference=False):
274 | super(DeepRFT, self).__init__()
275 | self.inference = inference
276 | if not inference:
277 | BasicConv = BasicConv_do
278 | ResBlock = ResBlock_do_fft_bench
279 | else:
280 | BasicConv = BasicConv_do_eval
281 | ResBlock = ResBlock_do_fft_bench_eval
282 | base_channel = 32
283 |
284 | self.Encoder = nn.ModuleList([
285 | EBlock(base_channel, num_res, ResBlock=ResBlock),
286 | EBlock(base_channel*2, num_res, ResBlock=ResBlock),
287 | EBlock(base_channel*4, num_res, ResBlock=ResBlock),
288 | ])
289 |
290 | self.feat_extract = nn.ModuleList([
291 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
292 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
293 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
294 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
295 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
296 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
297 | ])
298 |
299 | self.Decoder = nn.ModuleList([
300 | DBlock(base_channel * 4, num_res, ResBlock=ResBlock),
301 | DBlock(base_channel * 2, num_res, ResBlock=ResBlock),
302 | DBlock(base_channel, num_res, ResBlock=ResBlock)
303 | ])
304 |
305 | self.Convs = nn.ModuleList([
306 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
307 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
308 | ])
309 |
310 | self.ConvsOut = nn.ModuleList(
311 | [
312 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
313 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
314 | ]
315 | )
316 |
317 | self.AFFs = nn.ModuleList([
318 | AFF(base_channel * 7, base_channel*1, BasicConv=BasicConv),
319 | AFF(base_channel * 7, base_channel*2, BasicConv=BasicConv)
320 | ])
321 |
322 | self.FAM1 = FAM(base_channel * 4, BasicConv=BasicConv)
323 | self.SCM1 = SCM(base_channel * 4, BasicConv=BasicConv)
324 | self.FAM2 = FAM(base_channel * 2, BasicConv=BasicConv)
325 | self.SCM2 = SCM(base_channel * 2, BasicConv=BasicConv)
326 |
327 | def forward(self, x):
328 | x_2 = F.interpolate(x, scale_factor=0.5)
329 | x_4 = F.interpolate(x_2, scale_factor=0.5)
330 | z2 = self.SCM2(x_2)
331 | z4 = self.SCM1(x_4)
332 |
333 | outputs = list()
334 |
335 | x_ = self.feat_extract[0](x)
336 | res1 = self.Encoder[0](x_)
337 |
338 | z = self.feat_extract[1](res1)
339 | z = self.FAM2(z, z2)
340 | res2 = self.Encoder[1](z)
341 |
342 | z = self.feat_extract[2](res2)
343 | z = self.FAM1(z, z4)
344 | z = self.Encoder[2](z)
345 |
346 | z12 = F.interpolate(res1, scale_factor=0.5)
347 | z21 = F.interpolate(res2, scale_factor=2)
348 | z42 = F.interpolate(z, scale_factor=2)
349 | z41 = F.interpolate(z42, scale_factor=2)
350 |
351 | res2 = self.AFFs[1](z12, res2, z42)
352 | res1 = self.AFFs[0](res1, z21, z41)
353 |
354 | z = self.Decoder[0](z)
355 | z_ = self.ConvsOut[0](z)
356 | z = self.feat_extract[3](z)
357 | if not self.inference:
358 | outputs.append(z_+x_4)
359 |
360 | z = torch.cat([z, res2], dim=1)
361 | z = self.Convs[0](z)
362 | z = self.Decoder[1](z)
363 | z_ = self.ConvsOut[1](z)
364 | z = self.feat_extract[4](z)
365 | if not self.inference:
366 | outputs.append(z_+x_2)
367 |
368 | z = torch.cat([z, res1], dim=1)
369 | z = self.Convs[1](z)
370 | z = self.Decoder[2](z)
371 | z = self.feat_extract[5](z)
372 | if not self.inference:
373 | outputs.append(z+x)
374 | # print(outputs)
375 | return outputs[::-1]
376 | else:
377 | return z+x
378 | class DeepRFTPLUS(nn.Module):
379 | def __init__(self, num_res=20, inference=False):
380 | super(DeepRFTPLUS, self).__init__()
381 | # ResBlock = ResBlock_fft_bench
382 | self.inference = inference
383 | if not inference:
384 | BasicConv = BasicConv_do
385 | ResBlock = ResBlock_do_fft_bench
386 | else:
387 | BasicConv = BasicConv_do_eval
388 | ResBlock = ResBlock_do_fft_bench_eval
389 | base_channel = 32
390 |
391 | self.Encoder = nn.ModuleList([
392 | EBlock(base_channel, num_res, ResBlock=ResBlock),
393 | EBlock(base_channel*2, num_res, ResBlock=ResBlock),
394 | EBlock(base_channel*4, num_res, ResBlock=ResBlock),
395 | ])
396 |
397 | self.feat_extract = nn.ModuleList([
398 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
399 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2),
400 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2),
401 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True),
402 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
403 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
404 | ])
405 |
406 | self.Decoder = nn.ModuleList([
407 | DBlock(base_channel * 4, num_res, ResBlock=ResBlock),
408 | DBlock(base_channel * 2, num_res, ResBlock=ResBlock),
409 | DBlock(base_channel, num_res, ResBlock=ResBlock)
410 | ])
411 |
412 | self.Convs = nn.ModuleList([
413 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
414 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
415 | ])
416 |
417 | self.ConvsOut = nn.ModuleList(
418 | [
419 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
420 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
421 | ]
422 | )
423 |
424 | self.AFFs = nn.ModuleList([
425 | AFF(base_channel * 7, base_channel*1, BasicConv=BasicConv),
426 | AFF(base_channel * 7, base_channel*2, BasicConv=BasicConv)
427 | ])
428 |
429 | self.FAM1 = FAM(base_channel * 4, BasicConv=BasicConv)
430 | self.SCM1 = SCM(base_channel * 4, BasicConv=BasicConv)
431 | self.FAM2 = FAM(base_channel * 2, BasicConv=BasicConv)
432 | self.SCM2 = SCM(base_channel * 2, BasicConv=BasicConv)
433 |
434 | def forward(self, x):
435 | x_2 = F.interpolate(x, scale_factor=0.5)
436 | x_4 = F.interpolate(x_2, scale_factor=0.5)
437 | z2 = self.SCM2(x_2)
438 | z4 = self.SCM1(x_4)
439 |
440 | outputs = list()
441 |
442 | x_ = self.feat_extract[0](x)
443 | res1 = self.Encoder[0](x_)
444 |
445 | z = self.feat_extract[1](res1)
446 | z = self.FAM2(z, z2)
447 | res2 = self.Encoder[1](z)
448 |
449 | z = self.feat_extract[2](res2)
450 | z = self.FAM1(z, z4)
451 | z = self.Encoder[2](z)
452 |
453 | z12 = F.interpolate(res1, scale_factor=0.5)
454 | z21 = F.interpolate(res2, scale_factor=2)
455 | z42 = F.interpolate(z, scale_factor=2)
456 | z41 = F.interpolate(z42, scale_factor=2)
457 |
458 | res2 = self.AFFs[1](z12, res2, z42)
459 | res1 = self.AFFs[0](res1, z21, z41)
460 |
461 | z = self.Decoder[0](z)
462 | z_ = self.ConvsOut[0](z)
463 | z = self.feat_extract[3](z)
464 | if not self.inference:
465 | outputs.append(z_+x_4)
466 |
467 | z = torch.cat([z, res2], dim=1)
468 | z = self.Convs[0](z)
469 | z = self.Decoder[1](z)
470 | z_ = self.ConvsOut[1](z)
471 | z = self.feat_extract[4](z)
472 | if not self.inference:
473 | outputs.append(z_+x_2)
474 |
475 | z = torch.cat([z, res1], dim=1)
476 | z = self.Convs[1](z)
477 | z = self.Decoder[2](z)
478 | z = self.feat_extract[5](z)
479 | if not self.inference:
480 | outputs.append(z+x)
481 | # print(outputs)
482 | return outputs[::-1]
483 | else:
484 | return z+x
485 |
486 |
487 |
488 |
489 |
--------------------------------------------------------------------------------
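
A quick usage sketch (not part of the repository): the variants above share one MIMO-UNet-style encoder-decoder and differ only in depth (`num_res` = 4 for Small, 8 for DeepRFT, 20 for PLUS) and in whether the DO-Conv blocks take their training or folded evaluation form. With `inference=False` the forward pass returns three predictions, finest scale first (`outputs[::-1]`); with `inference=True` it returns a single full-resolution tensor. Keeping H and W divisible by 4 aligns the two stride-2 stages with the 0.5x interpolations.

```python
import torch
from DeepRFT_MIMO import DeepRFT

x = torch.randn(1, 3, 256, 256)  # H, W divisible by 4

# Training mode: list of multi-scale predictions, finest first.
model = DeepRFT(num_res=8, inference=False)
outs = model(x)
print([tuple(o.shape) for o in outs])  # [(1, 3, 256, 256), (1, 3, 128, 128), (1, 3, 64, 64)]

# Inference mode: folded DO-Conv blocks, single output.
model_eval = DeepRFT(num_res=8, inference=True).eval()
with torch.no_grad():
    y = model_eval(x)
print(tuple(y.shape))  # (1, 3, 256, 256)
```
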
/README.md:
--------------------------------------------------------------------------------
  1 | [Deblurring on GoPro](https://paperswithcode.com/sota/deblurring-on-gopro?p=deep-residual-fourier-transformation-for)
  2 | [Deblurring on HIDE (trained on GoPro)](https://paperswithcode.com/sota/deblurring-on-hide-trained-on-gopro?p=deep-residual-fourier-transformation-for)
  3 | [Deblurring on RealBlur-J](https://paperswithcode.com/sota/deblurring-on-realblur-j-1?p=deep-residual-fourier-transformation-for)
  4 | [Deblurring on RealBlur-J (trained on GoPro)](https://paperswithcode.com/sota/deblurring-on-realblur-j-trained-on-gopro?p=deep-residual-fourier-transformation-for)
  5 | [Deblurring on RealBlur-R](https://paperswithcode.com/sota/deblurring-on-realblur-r?p=deep-residual-fourier-transformation-for)
  6 | [Deblurring on RealBlur-R (trained on GoPro)](https://paperswithcode.com/sota/deblurring-on-realblur-r-trained-on-gopro?p=deep-residual-fourier-transformation-for)
7 |
8 |
9 | # Intriguing Findings of Frequency Selection for Image Deblurring (AAAI 2023)
10 | Xintian Mao, Yiming Liu, Fengze Liu, Qingli Li, Wei Shen and Yan Wang
11 |
12 | **Paper**: xxx
13 |
 14 | **Code**: https://github.com/INVOKERer/DeepRFT/tree/AAAI2023
15 |
16 | # Deep Residual Fourier Transformation for Single Image Deblurring
17 | Xintian Mao, Yiming Liu, Wei Shen, Qingli Li and Yan Wang
18 |
19 |
20 | **Paper**: https://arxiv.org/abs/2111.11745
21 |
22 |
23 | ## Network Architecture
 24 |
 25 | ![Overall Framework of DeepRFT](images/framework.png)
 26 |
33 | ## Installation
 34 | The model is built in PyTorch 1.8.0 and tested in an Ubuntu 18.04 environment (Python 3.8, CUDA 11.1).
 35 |
 36 | For installation, follow these instructions:
37 | ```
38 | conda create -n pytorch python=3.8
39 | conda activate pytorch
40 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
41 | pip install matplotlib scikit-image opencv-python yacs joblib natsort h5py tqdm kornia tensorboard ptflops
42 | ```
43 |
44 | Install warmup scheduler
45 |
46 | ```
47 | cd pytorch-gradual-warmup-lr; python setup.py install; cd ..
48 | ```
49 |
50 | ## Quick Run
51 |
 52 | To test the pre-trained deblurring and defocus models ([Google Drive](https://drive.google.com/file/d/1FoQZrbcYPGzU9xzOPI1Q1NybNUGR-ZUg/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/10DuQZiXC-Dc6jtLc9YJGbg), access code: phws) on your own images, run
53 | ```
 54 | python test.py --weights ckpt_path_here --input_dir path_to_images --output_dir save_images_here --win_size 256 --num_res 8  # deblur (num_res: 4 for Small, 8 for base, 20 for Plus)
 55 | python test.py --weights ckpt_path_here --input_dir path_to_images --output_dir save_images_here --win_size 512 --num_res 8  # defocus
56 | ```
57 | Here is an example to train:
58 | ```
59 | python train.py
60 | ```
61 |
62 |
63 | ## Results
64 | Experiment for image deblurring.
 65 |
 66 | ![Deblurring on GoPro Datasets](images/psnr_params_flops.png)
 67 |
74 | ## Reference Code:
75 | - https://github.com/yangyanli/DO-Conv
76 | - https://github.com/swz30/MPRNet
77 | - https://github.com/chosj95/MIMO-UNet
78 | - https://github.com/codeslake/IFAN
79 |
80 | ## Citation
81 | If you use DeepRFT, please consider citing:
82 | ```
 83 | @inproceedings{xint2023freqsel,
 84 |   title     = {Intriguing Findings of Frequency Selection for Image Deblurring},
 85 |   author    = {Mao, Xintian and Liu, Yiming and Liu, Fengze and Li, Qingli and Shen, Wei and Wang, Yan},
 86 |   booktitle = {Proceedings of the 37th AAAI Conference on Artificial Intelligence},
 87 |   year      = {2023}
 88 | }
 89 | or
 90 | @article{mao2021deep,
 91 |   title   = {Deep Residual Fourier Transformation for Single Image Deblurring},
 92 |   author  = {Mao, Xintian and Liu, Yiming and Shen, Wei and Li, Qingli and Wang, Yan},
 93 |   journal = {arXiv preprint arXiv:2111.11745},
 94 |   year    = {2021}
 95 | }
96 | ```
97 | ## Contact
 98 | If you have any questions, please contact mxt_invoker1997@163.com.
99 |
--------------------------------------------------------------------------------
/data_RGB.py:
--------------------------------------------------------------------------------
1 | from dataset_RGB import *
2 |
3 |
4 | def get_training_data(rgb_dir, img_options):
5 | assert os.path.exists(rgb_dir)
6 | return DataLoaderTrain(rgb_dir, img_options)
7 |
8 | def get_validation_data(rgb_dir, img_options):
9 | assert os.path.exists(rgb_dir)
10 | return DataLoaderVal(rgb_dir, img_options)
11 |
12 | def get_test_data(rgb_dir, img_options):
13 | assert os.path.exists(rgb_dir)
14 | return DataLoaderTest(rgb_dir, img_options)
15 |
--------------------------------------------------------------------------------
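
A minimal loader sketch, assuming the `blur`/`sharp` subfolder layout that `dataset_RGB.py` reads and the `patch_size` key that `DataLoaderTrain` expects; the paths are illustrative:

```python
from torch.utils.data import DataLoader
from data_RGB import get_training_data, get_test_data

# Training patches are cropped to img_options['patch_size'].
train_set = get_training_data('./Datasets/GoPro/train', img_options={'patch_size': 256})
train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)

# test.py drives the test loader the same way, pointing at the blur folder.
test_set = get_test_data('./Datasets/GoPro/test/blur', img_options={})
```
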
/dataset_RGB.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import torch
5 | from PIL import Image
6 | import torchvision.transforms.functional as TF
7 | import random
8 |
9 |
10 | def is_image_file(filename):
11 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
12 |
13 | class DataLoaderTrain(Dataset):
14 | def __init__(self, rgb_dir, img_options=None):
15 | super(DataLoaderTrain, self).__init__()
16 |
17 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'blur')))
18 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'sharp')))
19 |
20 | self.inp_filenames = [os.path.join(rgb_dir, 'blur', x) for x in inp_files if is_image_file(x)]
21 | self.tar_filenames = [os.path.join(rgb_dir, 'sharp', x) for x in tar_files if is_image_file(x)]
22 |
23 | self.img_options = img_options
24 | self.sizex = len(self.tar_filenames) # get the size of target
25 |
26 | self.ps = self.img_options['patch_size']
27 |
28 | def __len__(self):
29 | return self.sizex
30 |
31 | def __getitem__(self, index):
32 | index_ = index % self.sizex
33 | ps = self.ps
34 |
35 | inp_path = self.inp_filenames[index_]
36 | tar_path = self.tar_filenames[index_]
37 |
38 | inp_img = Image.open(inp_path)
39 | tar_img = Image.open(tar_path)
40 |
41 | w,h = tar_img.size
 42 |         padw = ps-w if w<ps else 0
--------------------------------------------------------------------------------
/doconv_pytorch.py:
--------------------------------------------------------------------------------
 60 |         if M * N > 1:
61 | self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul))
62 | init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32)
63 | self.D.data = torch.from_numpy(init_zero)
64 |
65 | eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N))
66 | D_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N)))
67 | if self.D_mul % (M * N) != 0: # the cases when D_mul > M * N
68 | zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)])
69 | self.D_diag = Parameter(torch.cat([D_diag, zeros], dim=2), requires_grad=False)
70 | else: # the case when D_mul = M * N
71 | self.D_diag = Parameter(D_diag, requires_grad=False)
72 | ##################################################################################################
73 | if simam:
74 | self.simam_block = simam_module()
75 | if bias:
76 | self.bias = Parameter(torch.Tensor(out_channels))
77 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
78 | bound = 1 / math.sqrt(fan_in)
79 | init.uniform_(self.bias, -bound, bound)
80 | else:
81 | self.register_parameter('bias', None)
82 |
83 | def extra_repr(self):
84 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
85 | ', stride={stride}')
86 | if self.padding != (0,) * len(self.padding):
87 | s += ', padding={padding}'
88 | if self.dilation != (1,) * len(self.dilation):
89 | s += ', dilation={dilation}'
90 | if self.groups != 1:
91 | s += ', groups={groups}'
92 | if self.bias is None:
93 | s += ', bias=False'
94 | if self.padding_mode != 'zeros':
95 | s += ', padding_mode={padding_mode}'
96 | return s.format(**self.__dict__)
97 |
98 | def __setstate__(self, state):
99 | super(DOConv2d, self).__setstate__(state)
100 | if not hasattr(self, 'padding_mode'):
101 | self.padding_mode = 'zeros'
102 |
103 | def _conv_forward(self, input, weight):
104 | if self.padding_mode != 'zeros':
105 | return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
106 | weight, self.bias, self.stride,
107 | (0, 0), self.dilation, self.groups)
108 | return F.conv2d(input, weight, self.bias, self.stride,
109 | self.padding, self.dilation, self.groups)
110 |
111 | def forward(self, input):
112 | M = self.kernel_size[0]
113 | N = self.kernel_size[1]
114 | DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
115 | if M * N > 1:
116 | ######################### Compute DoW #################
117 | # (input_channels, D_mul, M * N)
118 | D = self.D + self.D_diag
119 | W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))
120 |
121 | # einsum outputs (out_channels // groups, in_channels, M * N),
122 | # which is reshaped to
123 | # (out_channels, in_channels // groups, M, N)
124 | DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
125 | #######################################################
126 | else:
127 | DoW = torch.reshape(self.W, DoW_shape)
128 | if self.simam:
129 | DoW_h1, DoW_h2 = torch.chunk(DoW, 2, dim=2)
130 | DoW = torch.cat([self.simam_block(DoW_h1), DoW_h2], dim=2)
131 |
132 | return self._conv_forward(input, DoW)
133 | class DOConv2d_eval(Module):
134 | """
135 | DOConv2d can be used as an alternative for torch.nn.Conv2d.
136 | The interface is similar to that of Conv2d, with one exception:
137 | 1. D_mul: the depth multiplier for the over-parameterization.
138 | Note that the groups parameter switchs between DO-Conv (groups=1),
139 | DO-DConv (groups=in_channels), DO-GConv (otherwise).
140 | """
141 | __constants__ = ['stride', 'padding', 'dilation', 'groups',
142 | 'padding_mode', 'output_padding', 'in_channels',
143 | 'out_channels', 'kernel_size', 'D_mul']
144 | __annotations__ = {'bias': Optional[torch.Tensor]}
145 |
146 | def __init__(self, in_channels, out_channels, kernel_size=3, D_mul=None, stride=1,
147 | padding=1, dilation=1, groups=1, bias=False, padding_mode='zeros', simam=False):
148 | super(DOConv2d_eval, self).__init__()
149 |
150 | kernel_size = (kernel_size, kernel_size)
151 | stride = (stride, stride)
152 | padding = (padding, padding)
153 | dilation = (dilation, dilation)
154 |
155 | if in_channels % groups != 0:
156 | raise ValueError('in_channels must be divisible by groups')
157 | if out_channels % groups != 0:
158 | raise ValueError('out_channels must be divisible by groups')
159 | valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
160 | if padding_mode not in valid_padding_modes:
161 | raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
162 | valid_padding_modes, padding_mode))
163 | self.in_channels = in_channels
164 | self.out_channels = out_channels
165 | self.kernel_size = kernel_size
166 | self.stride = stride
167 | self.padding = padding
168 | self.dilation = dilation
169 | self.groups = groups
170 | self.padding_mode = padding_mode
171 | self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))
172 | self.simam = simam
173 |         #################################### Initialization of D & W ###################################
174 | M = self.kernel_size[0]
175 | N = self.kernel_size[1]
176 | self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, M, N))
177 | init.kaiming_uniform_(self.W, a=math.sqrt(5))
178 |
179 | self.register_parameter('bias', None)
180 | def extra_repr(self):
181 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
182 | ', stride={stride}')
183 | if self.padding != (0,) * len(self.padding):
184 | s += ', padding={padding}'
185 | if self.dilation != (1,) * len(self.dilation):
186 | s += ', dilation={dilation}'
187 | if self.groups != 1:
188 | s += ', groups={groups}'
189 | if self.bias is None:
190 | s += ', bias=False'
191 | if self.padding_mode != 'zeros':
192 | s += ', padding_mode={padding_mode}'
193 | return s.format(**self.__dict__)
194 |
195 | def __setstate__(self, state):
196 |         super(DOConv2d_eval, self).__setstate__(state)
197 | if not hasattr(self, 'padding_mode'):
198 | self.padding_mode = 'zeros'
199 |
200 | def _conv_forward(self, input, weight):
201 | if self.padding_mode != 'zeros':
202 | return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
203 | weight, self.bias, self.stride,
204 | (0, 0), self.dilation, self.groups)
205 | return F.conv2d(input, weight, self.bias, self.stride,
206 | self.padding, self.dilation, self.groups)
207 |
208 | def forward(self, input):
209 | return self._conv_forward(input, self.W)
210 |
211 | class simam_module(torch.nn.Module):
212 | def __init__(self, e_lambda=1e-4):
213 | super(simam_module, self).__init__()
214 | self.activaton = nn.Sigmoid()
215 | self.e_lambda = e_lambda
216 |
217 | def forward(self, x):
218 | b, c, h, w = x.size()
219 | n = w * h - 1
220 | x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
221 | y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
222 | return x * self.activaton(y)
223 |
224 |
225 |
--------------------------------------------------------------------------------
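
`test.py` loads weights through `utils.load_checkpoint_compress_doconv`, which is outside this section; the sketch below is an assumption of what that compression amounts to, reusing the exact einsum from `DOConv2d.forward` to fold a trained `D`/`W` pair into the single kernel a `DOConv2d_eval` layer carries (the `simam` branch is omitted):

```python
import torch

def fold_doconv(do_conv):
    # Collapse DOConv2d's over-parameterized D and W into one Conv2d-shaped kernel.
    M, N = do_conv.kernel_size
    out_c, in_c, groups = do_conv.out_channels, do_conv.in_channels, do_conv.groups
    if M * N > 1:
        D = do_conv.D + do_conv.D_diag                            # (in_c, M*N, D_mul)
        W = torch.reshape(do_conv.W, (out_c // groups, in_c, do_conv.D_mul))
        DoW = torch.einsum('ims,ois->oim', D, W)                  # (out_c//groups, in_c, M*N)
    else:
        DoW = do_conv.W
    return torch.reshape(DoW, (out_c, in_c // groups, M, N))      # assignable to DOConv2d_eval.W
```
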
/evaluate_GOPRO.m:
--------------------------------------------------------------------------------
1 |
2 | close all;clear all;
3 |
4 | datasets = {'GoPro'};
5 | % datasets = {'GoPro', 'HIDE'};
6 | num_set = length(datasets);
7 |
8 | for idx_set = 1:num_set
  9 |     file_path = strcat('./results/DeepRFT/', datasets{idx_set}, '/');
 10 |     gt_path = strcat('./Datasets/', datasets{idx_set}, '/test/sharp/');
11 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))];
12 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
13 | img_num = length(path_list);
14 |
15 | total_psnr = 0;
16 | total_ssim = 0;
17 | if img_num > 0
18 | for j = 1:img_num
19 | image_name = path_list(j).name;
20 | gt_name = gt_list(j).name;
21 | input = imread(strcat(file_path,image_name));
22 | gt = imread(strcat(gt_path, gt_name));
23 | ssim_val = ssim(input, gt);
24 | psnr_val = psnr(input, gt);
25 | total_ssim = total_ssim + ssim_val;
26 | total_psnr = total_psnr + psnr_val;
27 | end
28 | end
29 | qm_psnr = total_psnr / img_num;
30 | qm_ssim = total_ssim / img_num;
31 |
32 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);
33 |
34 | end
35 |
--------------------------------------------------------------------------------
/evaluate_RealBlur.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from glob import glob
4 | from natsort import natsorted
5 | from skimage import io
6 | import cv2
7 | from skimage.metrics import structural_similarity
8 | from tqdm import tqdm
9 | import concurrent.futures
10 |
11 | def image_align(deblurred, gt):
12 | # this function is based on kohler evaluation code
13 | z = deblurred
14 | c = np.ones_like(z)
15 | x = gt
16 |
17 | zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching
18 |
19 | warp_mode = cv2.MOTION_HOMOGRAPHY
20 | warp_matrix = np.eye(3, 3, dtype=np.float32)
21 |
22 | # Specify the number of iterations.
23 | number_of_iterations = 100
24 |
25 | termination_eps = 0
26 |
27 | criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
28 | number_of_iterations, termination_eps)
29 |
30 | # Run the ECC algorithm. The results are stored in warp_matrix.
31 | (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY), warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5)
32 |
33 | target_shape = x.shape
34 | shift = warp_matrix
35 |
36 | zr = cv2.warpPerspective(
37 | zs,
38 | warp_matrix,
39 | (target_shape[1], target_shape[0]),
40 | flags=cv2.INTER_CUBIC+ cv2.WARP_INVERSE_MAP,
41 | borderMode=cv2.BORDER_REFLECT)
42 |
43 | cr = cv2.warpPerspective(
44 | np.ones_like(zs, dtype='float32'),
45 | warp_matrix,
46 | (target_shape[1], target_shape[0]),
47 | flags=cv2.INTER_NEAREST+ cv2.WARP_INVERSE_MAP,
48 | borderMode=cv2.BORDER_CONSTANT,
49 | borderValue=0)
50 |
51 | zr = zr * cr
52 | xr = x * cr
53 |
54 | return zr, xr, cr, shift
55 |
56 | def compute_psnr(image_true, image_test, image_mask, data_range=None):
57 | # this function is based on skimage.metrics.peak_signal_noise_ratio
58 | err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
59 | return 10 * np.log10((data_range ** 2) / err)
60 |
61 |
62 | def compute_ssim(tar_img, prd_img, cr1):
63 | ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True, use_sample_covariance=False, data_range = 1.0, full=True)
64 | ssim_map = ssim_map * cr1
65 | r = int(3.5 * 1.5 + 0.5) # radius as in ndimage
66 | win_size = 2 * r + 1
67 | pad = (win_size - 1) // 2
68 | ssim = ssim_map[pad:-pad,pad:-pad,:]
69 | crop_cr1 = cr1[pad:-pad,pad:-pad,:]
70 | ssim = ssim.sum(axis=0).sum(axis=0)/crop_cr1.sum(axis=0).sum(axis=0)
71 | ssim = np.mean(ssim)
72 | return ssim
73 |
74 | def proc(filename):
75 | tar,prd = filename
76 | tar_img = io.imread(tar)
77 | prd_img = io.imread(prd)
78 |
79 | tar_img = tar_img.astype(np.float32)/255.0
80 | prd_img = prd_img.astype(np.float32)/255.0
81 |
82 | prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img)
83 |
84 | PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1)
85 | SSIM = compute_ssim(tar_img, prd_img, cr1)
86 | return (PSNR,SSIM)
87 |
88 | datasets = ['RealBlur_J', 'RealBlur_R']
89 | # datasets = ['RealBlur_J']
90 | # datasets = ['RealBlur_R']
91 | for dataset in datasets:
92 |
93 | file_path = os.path.join('./results/DeepRFT', dataset)
94 | gt_path = os.path.join('./Datasets/RealBlur', dataset, 'test/sharp')
95 |
96 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
97 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))
98 |
99 | assert len(path_list) != 0, "Predicted files not found"
100 | assert len(gt_list) != 0, "Target files not found"
101 |
102 | psnr, ssim = [], []
103 |
104 | img_files = [(i, j) for i,j in zip(gt_list,path_list)]
105 | # print(img_files)
106 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
107 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)):
108 |
109 | psnr.append(PSNR_SSIM[0])
110 | ssim.append(PSNR_SSIM[1])
111 | # print(filename, PSNR_SSIM[0])
112 |
113 | avg_psnr = sum(psnr)/len(psnr)
114 | avg_ssim = sum(ssim)/len(ssim)
115 |
116 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
117 |
--------------------------------------------------------------------------------
/get_parameter_number.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def get_parameter_number(net):
4 | total_num = sum(np.prod(p.size()) for p in net.parameters())
5 | trainable_num = sum(np.prod(p.size()) for p in net.parameters() if p.requires_grad)
6 | print('Total: ', total_num)
7 | print('Trainable: ', trainable_num)
8 |
9 |
10 | if __name__=='__main__':
11 | from DeepRFT_MIMO import DeepRFT_flops as Net
12 | import torch
13 | from ptflops import get_model_complexity_info
14 | with torch.cuda.device(0):
15 | net = Net()
16 | macs, params = get_model_complexity_info(net, (3, 256, 256), as_strings=True,
17 | print_per_layer_stat=True, verbose=True)
18 | print('{:<30} {:<8}'.format('Computational complexity: ', macs))
19 | print('{:<30} {:<8}'.format('Number of parameters: ', params))
20 |
--------------------------------------------------------------------------------
/images/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepMed-Lab/DeepRFT-AAAI2023/de71c43694b75f9f664d36b8a1fa205b1466d20b/images/framework.png
--------------------------------------------------------------------------------
/images/psnr_params_flops.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DeepMed-Lab/DeepRFT-AAAI2023/de71c43694b75f9f664d36b8a1fa205b1466d20b/images/psnr_params_flops.png
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | from doconv_pytorch import *
2 |
3 |
4 | class BasicConv(nn.Module):
5 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=False, norm=False, relu=True, transpose=False,
6 | channel_shuffle_g=0, norm_method=nn.BatchNorm2d, groups=1):
7 | super(BasicConv, self).__init__()
8 | self.channel_shuffle_g = channel_shuffle_g
9 | self.norm = norm
10 | if bias and norm:
11 | bias = False
12 |
13 | padding = kernel_size // 2
14 | layers = list()
15 | if transpose:
16 | padding = kernel_size // 2 - 1
17 | layers.append(
18 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
19 | else:
20 | layers.append(
21 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
22 | if norm:
23 | layers.append(norm_method(out_channel))
 24 |         if relu:  # match BasicConv_do: activation applies even when norm is used
25 | layers.append(nn.ReLU(inplace=True))
26 |
27 | self.main = nn.Sequential(*layers)
28 |
29 | def forward(self, x):
30 | return self.main(x)
31 |
32 | class BasicConv_do(nn.Module):
33 | def __init__(self, in_channel, out_channel, kernel_size, stride=1, bias=False, norm=False, relu=True, transpose=False,
34 | relu_method=nn.ReLU, groups=1, norm_method=nn.BatchNorm2d):
35 | super(BasicConv_do, self).__init__()
36 | if bias and norm:
37 | bias = False
38 |
39 | padding = kernel_size // 2
40 | layers = list()
41 | if transpose:
42 | padding = kernel_size // 2 - 1
43 | layers.append(
44 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
45 | else:
46 | layers.append(
47 | DOConv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
48 | if norm:
49 | layers.append(norm_method(out_channel))
50 | if relu:
51 | if relu_method == nn.ReLU:
52 | layers.append(nn.ReLU(inplace=True))
53 | elif relu_method == nn.LeakyReLU:
54 | layers.append(nn.LeakyReLU(inplace=True))
55 | else:
56 | layers.append(relu_method())
57 | self.main = nn.Sequential(*layers)
58 |
59 | def forward(self, x):
60 | return self.main(x)
61 |
62 | class BasicConv_do_eval(nn.Module):
63 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=False, norm=False, relu=True, transpose=False,
64 | relu_method=nn.ReLU, groups=1, norm_method=nn.BatchNorm2d):
65 | super(BasicConv_do_eval, self).__init__()
66 | if bias and norm:
67 | bias = False
68 |
69 | padding = kernel_size // 2
70 | layers = list()
71 | if transpose:
72 | padding = kernel_size // 2 - 1
73 | layers.append(
74 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias))
75 | else:
76 | layers.append(
77 | DOConv2d_eval(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
78 | if norm:
79 | layers.append(norm_method(out_channel))
80 | if relu:
81 | if relu_method == nn.ReLU:
82 | layers.append(nn.ReLU(inplace=True))
83 | elif relu_method == nn.LeakyReLU:
84 | layers.append(nn.LeakyReLU(inplace=True))
85 | else:
86 | layers.append(relu_method())
87 | self.main = nn.Sequential(*layers)
88 |
89 | def forward(self, x):
90 | return self.main(x)
91 |
92 | class ResBlock(nn.Module):
93 | def __init__(self, out_channel):
94 | super(ResBlock, self).__init__()
95 | self.main = nn.Sequential(
96 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=True, norm=False),
97 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False, norm=False)
98 | )
99 |
100 | def forward(self, x):
101 | return self.main(x) + x
102 |
103 | class ResBlock_do(nn.Module):
104 | def __init__(self, out_channel):
105 | super(ResBlock_do, self).__init__()
106 | self.main = nn.Sequential(
107 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=True),
108 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
109 | )
110 |
111 | def forward(self, x):
112 | return self.main(x) + x
113 |
114 | class ResBlock_do_eval(nn.Module):
115 | def __init__(self, out_channel):
116 | super(ResBlock_do_eval, self).__init__()
117 | self.main = nn.Sequential(
118 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=True),
119 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
120 | )
121 |
122 | def forward(self, x):
123 | return self.main(x) + x
124 |
125 |
126 | class ResBlock_do_fft_bench(nn.Module):
127 | def __init__(self, out_channel, norm='backward'):
128 | super(ResBlock_do_fft_bench, self).__init__()
129 | self.main = nn.Sequential(
130 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=True),
131 | BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
132 | )
133 | self.main_fft = nn.Sequential(
134 | BasicConv_do(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=True),
135 | BasicConv_do(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=False)
136 | )
137 | self.dim = out_channel
138 | self.norm = norm
139 | def forward(self, x):
140 | _, _, H, W = x.shape
141 | dim = 1
142 | y = torch.fft.rfft2(x, norm=self.norm)
143 | y_imag = y.imag
144 | y_real = y.real
145 | y_f = torch.cat([y_real, y_imag], dim=dim)
146 | y = self.main_fft(y_f)
147 | y_real, y_imag = torch.chunk(y, 2, dim=dim)
148 | y = torch.complex(y_real, y_imag)
149 | y = torch.fft.irfft2(y, s=(H, W), norm=self.norm)
150 | return self.main(x) + x + y
151 |
152 | class ResBlock_fft_bench(nn.Module):
153 | def __init__(self, n_feat, norm='backward'): # 'ortho'
154 | super(ResBlock_fft_bench, self).__init__()
155 | self.main = nn.Sequential(
156 | BasicConv(n_feat, n_feat, kernel_size=3, stride=1, relu=True),
157 | BasicConv(n_feat, n_feat, kernel_size=3, stride=1, relu=False)
158 | )
159 | self.main_fft = nn.Sequential(
160 | BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=True),
161 | BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=False)
162 | )
163 | self.dim = n_feat
164 | self.norm = norm
165 | def forward(self, x):
166 | _, _, H, W = x.shape
167 | dim = 1
168 | y = torch.fft.rfft2(x, norm=self.norm)
169 | y_imag = y.imag
170 | y_real = y.real
171 | y_f = torch.cat([y_real, y_imag], dim=dim)
172 | y = self.main_fft(y_f)
173 | y_real, y_imag = torch.chunk(y, 2, dim=dim)
174 | y = torch.complex(y_real, y_imag)
175 | y = torch.fft.irfft2(y, s=(H, W), norm=self.norm)
176 | return self.main(x) + x + y
177 | class ResBlock_do_fft_bench_eval(nn.Module):
178 | def __init__(self, out_channel, norm='backward'):
179 | super(ResBlock_do_fft_bench_eval, self).__init__()
180 | self.main = nn.Sequential(
181 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=True),
182 | BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
183 | )
184 | self.main_fft = nn.Sequential(
185 | BasicConv_do_eval(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=True),
186 | BasicConv_do_eval(out_channel*2, out_channel*2, kernel_size=1, stride=1, relu=False)
187 | )
188 | self.dim = out_channel
189 | self.norm = norm
190 | def forward(self, x):
191 | _, _, H, W = x.shape
192 | dim = 1
193 | y = torch.fft.rfft2(x, norm=self.norm)
194 | y_imag = y.imag
195 | y_real = y.real
196 | y_f = torch.cat([y_real, y_imag], dim=dim)
197 | y = self.main_fft(y_f)
198 | y_real, y_imag = torch.chunk(y, 2, dim=dim)
199 | y = torch.complex(y_real, y_imag)
200 | y = torch.fft.irfft2(y, s=(H, W), norm=self.norm)
201 | return self.main(x) + x + y
202 |
203 | def window_partitions(x, window_size):
204 | """
205 | Args:
206 | x: (B, C, H, W)
207 | window_size (int): window size
208 |
209 | Returns:
210 | windows: (num_windows*B, C, window_size, window_size)
211 | """
212 | B, C, H, W = x.shape
213 | x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
214 | windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size)
215 | return windows
216 |
217 |
218 | def window_reverses(windows, window_size, H, W):
219 | """
220 | Args:
221 | windows: (num_windows*B, C, window_size, window_size)
222 | window_size (int): Window size
223 | H (int): Height of image
224 | W (int): Width of image
225 |
226 | Returns:
227 | x: (B, C, H, W)
228 | """
229 | # B = int(windows.shape[0] / (H * W / window_size / window_size))
230 | # print('B: ', B)
231 | # print(H // window_size)
232 | # print(W // window_size)
233 | C = windows.shape[1]
234 | # print('C: ', C)
235 | x = windows.view(-1, H // window_size, W // window_size, C, window_size, window_size)
236 | x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W)
237 | return x
238 |
239 | def window_partitionx(x, window_size):
240 | _, _, H, W = x.shape
241 | h, w = window_size * (H // window_size), window_size * (W // window_size)
242 | x_main = window_partitions(x[:, :, :h, :w], window_size)
243 | b_main = x_main.shape[0]
244 | if h == H and w == W:
245 | return x_main, [b_main]
246 | if h != H and w != W:
247 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size)
248 | b_r = x_r.shape[0] + b_main
249 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size)
250 | b_d = x_d.shape[0] + b_r
251 | x_dd = x[:, :, -window_size:, -window_size:]
252 | b_dd = x_dd.shape[0] + b_d
253 | # batch_list = [b_main, b_r, b_d, b_dd]
254 | return torch.cat([x_main, x_r, x_d, x_dd], dim=0), [b_main, b_r, b_d, b_dd]
255 | if h == H and w != W:
256 | x_r = window_partitions(x[:, :, :h, -window_size:], window_size)
257 | b_r = x_r.shape[0] + b_main
258 | return torch.cat([x_main, x_r], dim=0), [b_main, b_r]
259 | if h != H and w == W:
260 | x_d = window_partitions(x[:, :, -window_size:, :w], window_size)
261 | b_d = x_d.shape[0] + b_main
262 | return torch.cat([x_main, x_d], dim=0), [b_main, b_d]
263 |
264 | def window_reversex(windows, window_size, H, W, batch_list):
265 | h, w = window_size * (H // window_size), window_size * (W // window_size)
266 | x_main = window_reverses(windows[:batch_list[0], ...], window_size, h, w)
267 | B, C, _, _ = x_main.shape
268 | # print('windows: ', windows.shape)
269 | # print('batch_list: ', batch_list)
270 | res = torch.zeros([B, C, H, W],device=windows.device)
271 | res[:, :, :h, :w] = x_main
272 | if h == H and w == W:
273 | return res
274 | if h != H and w != W and len(batch_list) == 4:
275 | x_dd = window_reverses(windows[batch_list[2]:, ...], window_size, window_size, window_size)
276 | res[:, :, h:, w:] = x_dd[:, :, h - H:, w - W:]
277 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
278 | res[:, :, :h, w:] = x_r[:, :, :, w - W:]
279 | x_d = window_reverses(windows[batch_list[1]:batch_list[2], ...], window_size, window_size, w)
280 | res[:, :, h:, :w] = x_d[:, :, h - H:, :]
281 | return res
282 | if w != W and len(batch_list) == 2:
283 | x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
284 | res[:, :, :h, w:] = x_r[:, :, :, w - W:]
285 | if h != H and len(batch_list) == 2:
286 | x_d = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, window_size, w)
287 | res[:, :, h:, :w] = x_d[:, :, h - H:, :]
288 | return res
289 |
--------------------------------------------------------------------------------
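
A small round-trip check (not from the repository) for the window helpers that `test.py` applies via `--win_size`: `window_partitionx` adds extra edge windows when H or W is not a multiple of the window size, and `window_reversex` stitches everything back losslessly:

```python
import torch
from layers import window_partitionx, window_reversex

x = torch.randn(1, 3, 300, 420)           # neither side is a multiple of 256
windows, batch_list = window_partitionx(x, 256)
print(windows.shape, batch_list)          # torch.Size([4, 3, 256, 256]) [1, 2, 3, 4]
y = window_reversex(windows, 256, 300, 420, batch_list)
print(torch.allclose(x, y))               # True
```
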
/license.md:
--------------------------------------------------------------------------------
1 | ## ACADEMIC PUBLIC LICENSE
2 |
3 | ### Permissions
4 | :heavy_check_mark: Non-Commercial use
5 | :heavy_check_mark: Modification
6 | :heavy_check_mark: Distribution
7 | :heavy_check_mark: Private use
8 |
9 | ### Limitations
10 | :x: Commercial Use
11 | :x: Liability
12 | :x: Warranty
13 |
14 | ### Conditions
15 | :information_source: License and copyright notice
16 | :information_source: Same License
17 |
18 | DeepRFT is free for use in noncommercial settings: at academic institutions for teaching and research use, and at non-profit research organizations.
19 | You can use DeepRFT in your research, academic work, non-commercial work, projects and personal work. We only ask you to credit us appropriately.
20 |
21 | You have the right to use the software, to distribute copies, to receive source code, to change the software and distribute your modifications or the modified software.
22 | If you distribute verbatim or modified copies of this software, they must be distributed under this license.
23 | This license guarantees that you're safe when using DeepRFT in your work, for teaching or research.
24 | This license guarantees that DeepRFT will remain available free of charge for nonprofit use.
25 | You can modify DeepRFT to your purposes, and you can also share your modifications.
26 |
27 | If you would like to use DeepRFT in commercial settings, contact us so we can discuss options. Send an email to mxt_invoker1997@163.com
28 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class CharbonnierLoss(nn.Module):
6 | """Charbonnier Loss (L1)"""
7 |
8 | def __init__(self, eps=1e-3):
9 | super(CharbonnierLoss, self).__init__()
10 | self.eps = eps
11 |
12 | def forward(self, x, y):
13 | diff = x.to('cuda:0') - y.to('cuda:0')
14 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
15 | return loss
16 |
17 | class EdgeLoss(nn.Module):
18 | def __init__(self):
19 | super(EdgeLoss, self).__init__()
20 | k = torch.Tensor([[.05, .25, .4, .25, .05]])
21 | self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
22 | if torch.cuda.is_available():
23 | self.kernel = self.kernel.to('cuda:0')
24 | self.loss = CharbonnierLoss()
25 |
26 | def conv_gauss(self, img):
27 | n_channels, _, kw, kh = self.kernel.shape
28 | img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
29 | return F.conv2d(img, self.kernel, groups=n_channels)
30 |
31 | def laplacian_kernel(self, current):
32 | filtered = self.conv_gauss(current)
33 | down = filtered[:,:,::2,::2]
34 | new_filter = torch.zeros_like(filtered)
35 | new_filter[:,:,::2,::2] = down*4
36 | filtered = self.conv_gauss(new_filter)
37 | diff = current - filtered
38 | return diff
39 |
40 | def forward(self, x, y):
41 | loss = self.loss(self.laplacian_kernel(x.to('cuda:0')), self.laplacian_kernel(y.to('cuda:0')))
42 | return loss
43 |
44 | class fftLoss(nn.Module):
45 | def __init__(self):
46 | super(fftLoss, self).__init__()
47 |
48 | def forward(self, x, y):
49 | diff = torch.fft.fft2(x.to('cuda:0')) - torch.fft.fft2(y.to('cuda:0'))
50 | loss = torch.mean(abs(diff))
51 | return loss
52 |
--------------------------------------------------------------------------------
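
`train.py` is not shown in this section, so the recipe below is a hypothetical sketch of how these losses could be combined over the three training-mode outputs of DeepRFT; the 0.05 and 0.01 weights are placeholders, not the repository's values. All three losses move tensors to `'cuda:0'` internally, so a CUDA device is assumed.

```python
import torch.nn.functional as F
from losses import CharbonnierLoss, EdgeLoss, fftLoss

char_loss, edge_loss, fft_loss = CharbonnierLoss(), EdgeLoss(), fftLoss()

def multi_scale_loss(outputs, target):
    # outputs: [full, half, quarter] scale, as returned by DeepRFT with inference=False
    total = 0.0
    for out in outputs:
        gt = F.interpolate(target, size=out.shape[-2:], mode='bilinear', align_corners=False)
        total = total + char_loss(out, gt) + 0.05 * edge_loss(out, gt) + 0.01 * fft_loss(out, gt)
    return total
```
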
/pytorch-gradual-warmup-lr/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import setuptools
6 |
7 | _VERSION = '0.3'
8 |
9 | REQUIRED_PACKAGES = [
10 | ]
11 |
12 | DEPENDENCY_LINKS = [
13 | ]
14 |
15 | setuptools.setup(
16 | name='warmup_scheduler',
17 | version=_VERSION,
18 | description='Gradually Warm-up LR Scheduler for Pytorch',
19 | install_requires=REQUIRED_PACKAGES,
20 | dependency_links=DEPENDENCY_LINKS,
21 | url='https://github.com/ildoonet/pytorch-gradual-warmup-lr',
22 | license='MIT License',
23 | package_dir={},
24 | packages=setuptools.find_packages(exclude=['tests']),
25 | )
26 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR
3 | from torch.optim.sgd import SGD
4 |
5 | from warmup_scheduler import GradualWarmupScheduler
6 |
7 |
8 | if __name__ == '__main__':
9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 | optim = SGD(model, 0.1)
11 |
12 | # scheduler_warmup is chained with schduler_steplr
13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
15 |
16 | # this zero gradient update is needed to avoid a warning message, issue #8.
17 | optim.zero_grad()
18 | optim.step()
19 |
20 | for epoch in range(1, 20):
21 | scheduler_warmup.step(epoch)
22 | print(epoch, optim.param_groups[0]['lr'])
23 |
 24 |     optim.step()  # parameter update (a real loop would call loss.backward() first)
25 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 | """ Gradually warm-up(increasing) learning rate in optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 | Args:
10 | optimizer (Optimizer): Wrapped optimizer.
11 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
12 | total_epoch: target learning rate is reached at total_epoch, gradually
 13 |         after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
14 | """
15 |
16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 | self.multiplier = multiplier
18 | if self.multiplier < 1.:
 19 |             raise ValueError('multiplier should be greater than or equal to 1.')
20 | self.total_epoch = total_epoch
21 | self.after_scheduler = after_scheduler
22 | self.finished = False
23 | super(GradualWarmupScheduler, self).__init__(optimizer)
24 |
25 | def get_lr(self):
26 | if self.last_epoch > self.total_epoch:
27 | if self.after_scheduler:
28 | if not self.finished:
29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 | self.finished = True
31 | return self.after_scheduler.get_lr()
32 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 |
34 | if self.multiplier == 1.0:
35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 | else:
37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 |
39 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 | if epoch is None:
41 | epoch = self.last_epoch + 1
42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43 | if self.last_epoch <= self.total_epoch:
44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46 | param_group['lr'] = lr
47 | else:
48 |             # epoch was filled in above, so it is never None here
49 |             self.after_scheduler.step(metrics, epoch - self.total_epoch)
52 |
53 | def step(self, epoch=None, metrics=None):
54 | if type(self.after_scheduler) != ReduceLROnPlateau:
55 | if self.finished and self.after_scheduler:
56 | if epoch is None:
57 | self.after_scheduler.step(None)
58 | else:
59 | self.after_scheduler.step(epoch - self.total_epoch)
60 | else:
61 | return super(GradualWarmupScheduler, self).step(epoch)
62 | else:
63 | self.step_ReduceLROnPlateau(metrics, epoch)
64 |
--------------------------------------------------------------------------------
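For multiplier > 1 the warmup branch of get_lr() interpolates linearly from base_lr up to base_lr * multiplier. A minimal sketch of that ramp (values are illustrative, not from the repo's configs):

    base_lr, multiplier, total_epoch = 1e-4, 2.0, 5
    for e in range(total_epoch + 1):
        # mirrors get_lr(): base_lr * ((multiplier - 1) * e / total_epoch + 1)
        print(e, base_lr * ((multiplier - 1.) * e / total_epoch + 1.))  # 1e-4 -> 2e-4
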
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch.nn as nn
4 | import torch
5 | from torch.utils.data import DataLoader
6 | import utils
7 | from data_RGB import get_test_data
8 | from DeepRFT_MIMO import DeepRFT as mynet
9 | from skimage import img_as_ubyte
10 | from get_parameter_number import get_parameter_number
11 | from tqdm import tqdm
12 | from layers import *
13 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss
14 | import cv2
15 |
16 |
17 | parser = argparse.ArgumentParser(description='Image Deblurring')
18 | parser.add_argument('--input_dir', default='./Datasets/GoPro/test/blur', type=str, help='Directory of input blurred images')
19 | parser.add_argument('--target_dir', default='./Datasets/GoPro/test/sharp', type=str, help='Directory of ground-truth sharp images')
20 | parser.add_argument('--output_dir', default='./results/DeepRFT/GoPro', type=str, help='Directory for restored results')
21 | parser.add_argument('--weights', default='./checkpoints/DeepRFT/model_GoPro.pth', type=str, help='Path to weights')
22 | parser.add_argument('--get_psnr', action='store_true', help='compute PSNR against --target_dir')
23 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')
24 | parser.add_argument('--save_result', action='store_true', help='save restored images to --output_dir')
25 | parser.add_argument('--win_size', default=256, type=int, help='window size, [GoPro, HIDE, RealBlur]=256, [DPDD]=512')
26 | parser.add_argument('--num_res', default=8, type=int, help='num of resblocks, [Small, Med, Plus]=[4, 8, 20]')
27 | args = parser.parse_args()
28 | result_dir = args.output_dir
29 | win = args.win_size
30 | get_psnr = args.get_psnr
31 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
32 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
33 | # model_restoration = mynet()
34 | model_restoration = mynet(num_res=args.num_res, inference=True)
35 | # print the number of model parameters
36 | get_parameter_number(model_restoration)
37 | # utils.load_checkpoint(model_restoration, args.weights)
38 | utils.load_checkpoint_compress_doconv(model_restoration, args.weights)
39 | print("===>Testing using weights: ", args.weights)
40 | model_restoration.cuda()
41 | model_restoration = nn.DataParallel(model_restoration)
42 | model_restoration.eval()
43 |
44 | # dataset = args.dataset
45 | rgb_dir_test = args.input_dir
46 | test_dataset = get_test_data(rgb_dir_test, img_options={})
47 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)
48 | psnr_val_rgb = []
49 | psnr = 0
50 |
51 | utils.mkdir(result_dir)
52 |
53 | with torch.no_grad():
56 | for ii, data_test in enumerate(tqdm(test_loader), 0):
57 |
58 | torch.cuda.ipc_collect()
59 | torch.cuda.empty_cache()
60 | input_ = data_test[0].cuda()
61 | filenames = data_test[1]
62 | _, _, Hx, Wx = input_.shape
64 | input_re, batch_list = window_partitionx(input_, win)
65 | restored = model_restoration(input_re)
66 | restored = window_reversex(restored, win, Hx, Wx, batch_list)
67 |
68 | restored = torch.clamp(restored, 0, 1)
69 | restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
70 | for batch in range(len(restored)):
71 |             restored_img = img_as_ubyte(restored[batch])
73 | if get_psnr:
74 | rgb_gt = cv2.imread(os.path.join(args.target_dir, filenames[batch]+'.png'))
75 | rgb_gt = cv2.cvtColor(rgb_gt, cv2.COLOR_BGR2RGB)
76 | psnr_val_rgb.append(psnr_loss(restored_img, rgb_gt))
77 | if args.save_result:
78 | utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img)
79 |
80 | if get_psnr:
81 | psnr = sum(psnr_val_rgb) / len(test_dataset)
82 | print("PSNR: %f" % psnr)
83 |
--------------------------------------------------------------------------------
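test.py restores full frames by tiling them into win x win crops with window_partitionx and stitching the outputs back with window_reversex (both defined in layers.py and not reproduced here). A simplified round-trip illustration of the idea, assuming H and W are exact multiples of win; the real helpers also cover edge remainders with extra shifted windows:

    import torch

    def simple_partition(x, win):
        B, C, H, W = x.shape
        x = x.view(B, C, H // win, win, W // win, win)
        return x.permute(0, 2, 4, 1, 3, 5).reshape(-1, C, win, win)

    def simple_reverse(windows, win, H, W):
        B = windows.shape[0] // ((H // win) * (W // win))
        x = windows.view(B, H // win, W // win, -1, win, win)
        return x.permute(0, 3, 1, 4, 2, 5).reshape(B, -1, H, W)

    x = torch.randn(1, 3, 512, 512)
    assert torch.equal(simple_reverse(simple_partition(x, 256), 256, 512, 512), x)
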
/test_speed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch.nn as nn
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | from data_RGB import get_test_data
8 | from DeepRFT_MIMO import DeepRFT as mynet
9 |
10 | from get_parameter_number import get_parameter_number
11 | from tqdm import tqdm
12 | from layers import *
13 | import time
14 |
15 |
16 | parser = argparse.ArgumentParser(description='Image Deblurring')
17 | parser.add_argument('--input_dir', default='./Datasets/GoPro/test/blur', type=str, help='Directory of validation images')
18 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')
19 |
20 | args = parser.parse_args()
21 |
22 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
23 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
24 |
25 | # model_restoration = mynet(inference=True)
26 | model_restoration = mynet()
27 | # print the number of model parameters
28 | get_parameter_number(model_restoration)
29 |
30 | # utils.load_checkpoint_compress_doconv(model_restoration, args.weights)
31 | # print("===>Testing using weights: ", args.weights)
32 | model_restoration.cuda()
33 | model_restoration = nn.DataParallel(model_restoration)
34 | model_restoration.eval()
35 |
36 | # dataset = args.dataset
37 | rgb_dir_test = args.input_dir
38 | test_dataset = get_test_data(rgb_dir_test, img_options={})
39 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)
40 |
41 | win = 256
42 | all_time = 0.
43 | with torch.no_grad():
46 | for ii, data_test in enumerate(tqdm(test_loader), 0):
47 |
48 | torch.cuda.ipc_collect()
49 | torch.cuda.empty_cache()
50 |
51 | input_ = data_test[0].cuda()
52 | filenames = data_test[1]
53 | _, _, Hx, Wx = input_.shape
55 |
56 | torch.cuda.synchronize()
57 | start = time.time()
58 | input_re, batch_list = window_partitionx(input_, win)
59 | restored = model_restoration(input_re)
60 | # print(restored[0].shape)
61 | restored = window_reversex(restored[0], win, Hx, Wx, batch_list)
62 | restored = torch.clamp(restored, 0, 1)
63 | # print(restored.shape)
64 | torch.cuda.synchronize()
65 | end = time.time()
66 | all_time += end - start
67 | print('average_time: ', all_time / len(test_dataset))
68 |
--------------------------------------------------------------------------------
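The torch.cuda.synchronize() calls around the timed region in test_speed.py matter: CUDA kernels launch asynchronously, so time.time() without a synchronize would mostly measure kernel-launch overhead rather than the forward pass. An equivalent sketch using CUDA events, which timestamp directly on the GPU stream (it assumes the script's model_restoration and input_re and would drop in place of the time.time() pair above):

    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)

    start_evt.record()
    restored = model_restoration(input_re)   # the same forward pass timed above
    end_evt.record()

    torch.cuda.synchronize()                 # wait until both events are recorded
    elapsed_s = start_evt.elapsed_time(end_evt) / 1000.0  # elapsed_time() returns ms
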
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
4 |
5 | import torch
6 | torch.backends.cudnn.benchmark = True
7 |
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | from torch.utils.data import DataLoader
11 |
12 | import random
13 | import time
14 | import numpy as np
15 |
16 | import utils
17 | from data_RGB import get_training_data, get_validation_data
18 | from DeepRFT_MIMO import DeepRFT as myNet
19 | import losses
20 | from warmup_scheduler import GradualWarmupScheduler
21 | from tqdm import tqdm
22 | from get_parameter_number import get_parameter_number
23 | import kornia
24 | from torch.utils.tensorboard import SummaryWriter
25 | import argparse
26 |
27 | ######### Set Seeds ###########
28 | random.seed(1234)
29 | np.random.seed(1234)
30 | torch.manual_seed(1234)
31 | torch.cuda.manual_seed_all(1234)
32 |
33 | start_epoch = 1
34 |
35 | parser = argparse.ArgumentParser(description='Image Deblurring')
36 |
37 | parser.add_argument('--train_dir', default='./Datasets/GoPro/train', type=str, help='Directory of train images')
38 | parser.add_argument('--val_dir', default='./Datasets/GoPro/val', type=str, help='Directory of validation images')
39 | parser.add_argument('--model_save_dir', default='./checkpoints', type=str, help='Path to save weights')
40 | parser.add_argument('--pretrain_weights', default='./checkpoints/model_best.pth', type=str, help='Path to pretrain-weights')
41 | parser.add_argument('--mode', default='Deblurring', type=str)
42 | parser.add_argument('--session', default='DeepRFT_gopro', type=str, help='session')
43 | parser.add_argument('--patch_size', default=256, type=int, help='patch size, for paper: [GoPro, HIDE, RealBlur]=256, [DPDD]=512')
44 | parser.add_argument('--num_epochs', default=3000, type=int, help='num_epochs')
45 | parser.add_argument('--batch_size', default=16, type=int, help='batch_size')
46 | parser.add_argument('--val_epochs', default=20, type=int, help='val_epochs')
47 | args = parser.parse_args()
48 |
49 | mode = args.mode
50 | session = args.session
51 | patch_size = args.patch_size
52 |
53 | model_dir = os.path.join(args.model_save_dir, mode, 'models', session)
54 | utils.mkdir(model_dir)
55 |
56 | train_dir = args.train_dir
57 | val_dir = args.val_dir
58 |
59 | num_epochs = args.num_epochs
60 | batch_size = args.batch_size
61 | val_epochs = args.val_epochs
62 |
63 | start_lr = 2e-4
64 | end_lr = 1e-6
65 |
66 | ######### Model ###########
67 | model_restoration = myNet()
68 |
69 | # print the number of model parameters
70 | get_parameter_number(model_restoration)
71 |
72 | model_restoration.cuda()
73 |
74 | device_ids = [i for i in range(torch.cuda.device_count())]
75 | if torch.cuda.device_count() > 1:
76 | print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
77 |
78 | optimizer = optim.Adam(model_restoration.parameters(), lr=start_lr, betas=(0.9, 0.999), eps=1e-8)
79 |
80 | ######### Scheduler ###########
81 | warmup_epochs = 3
82 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs-warmup_epochs, eta_min=end_lr)
83 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
84 |
85 | RESUME = False
86 | Pretrain = False
87 | model_pre_dir = ''
88 | ######### Pretrain ###########
89 | if Pretrain:
90 | utils.load_checkpoint(model_restoration, model_pre_dir)
91 |
92 | print('------------------------------------------------------------------------------')
93 | print("==> Retrain Training with: " + model_pre_dir)
94 | print('------------------------------------------------------------------------------')
95 |
96 | ######### Resume ###########
97 | if RESUME:
98 | path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
99 |     utils.load_checkpoint(model_restoration, path_chk_rest)
100 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1
101 | utils.load_optim(optimizer, path_chk_rest)
102 |
103 | for i in range(1, start_epoch):
104 | scheduler.step()
105 | new_lr = scheduler.get_lr()[0]
106 | print('------------------------------------------------------------------------------')
107 | print("==> Resuming Training with learning rate:", new_lr)
108 | print('------------------------------------------------------------------------------')
109 |
110 | if len(device_ids) > 1:
111 | model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)
112 |
113 | ######### Loss ###########
114 | criterion_char = losses.CharbonnierLoss()
115 | criterion_edge = losses.EdgeLoss()
116 | criterion_fft = losses.fftLoss()
117 | ######### DataLoaders ###########
118 | train_dataset = get_training_data(train_dir, {'patch_size':patch_size})
119 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=False, pin_memory=True)
120 |
121 | val_dataset = get_validation_data(val_dir, {'patch_size':patch_size})
122 | val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
123 |
124 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch, num_epochs))
125 | print('===> Datasets loaded')
126 |
127 | best_psnr = 0
128 | best_epoch = 0
129 | writer = SummaryWriter(model_dir)
130 | global_step = 0
131 |
132 | for epoch in range(start_epoch, num_epochs + 1):
133 | epoch_start_time = time.time()
134 | epoch_loss = 0
136 |
137 | model_restoration.train()
138 | for i, data in enumerate(tqdm(train_loader), 0):
139 |
140 | # zero_grad
141 | for param in model_restoration.parameters():
142 | param.grad = None
143 |
144 | target_ = data[0].cuda()
145 | input_ = data[1].cuda()
146 | target = kornia.geometry.transform.build_pyramid(target_, 3)
147 | restored = model_restoration(input_)
148 |
149 | loss_fft = criterion_fft(restored[0], target[0]) + criterion_fft(restored[1], target[1]) + criterion_fft(
150 | restored[2], target[2])
151 | loss_char = criterion_char(restored[0], target[0]) + criterion_char(restored[1], target[1]) + criterion_char(restored[2], target[2])
152 | loss_edge = criterion_edge(restored[0], target[0]) + criterion_edge(restored[1], target[1]) + criterion_edge(restored[2], target[2])
153 | loss = loss_char + 0.01 * loss_fft + 0.05 * loss_edge
154 | loss.backward()
155 | optimizer.step()
156 |         epoch_loss += loss.item()
157 |         global_step += 1
158 |         writer.add_scalar('loss/fft_loss', loss_fft, global_step)
159 |         writer.add_scalar('loss/char_loss', loss_char, global_step)
160 |         writer.add_scalar('loss/edge_loss', loss_edge, global_step)
161 |         writer.add_scalar('loss/iter_loss', loss, global_step)
162 |     writer.add_scalar('loss/epoch_loss', epoch_loss, epoch)
163 | #### Evaluation ####
164 | if epoch % val_epochs == 0:
165 | model_restoration.eval()
166 | psnr_val_rgb = []
167 | for ii, data_val in enumerate((val_loader), 0):
168 | target = data_val[0].cuda()
169 | input_ = data_val[1].cuda()
170 |
171 | with torch.no_grad():
172 | restored = model_restoration(input_)
173 |
174 |             for res, tar in zip(restored[0], target):
175 | psnr_val_rgb.append(utils.torchPSNR(res, tar))
176 |
177 | psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
178 | writer.add_scalar('val/psnr', psnr_val_rgb, epoch)
179 | if psnr_val_rgb > best_psnr:
180 | best_psnr = psnr_val_rgb
181 | best_epoch = epoch
182 | torch.save({'epoch': epoch,
183 | 'state_dict': model_restoration.state_dict(),
184 | 'optimizer' : optimizer.state_dict()
185 | }, os.path.join(model_dir,"model_best.pth"))
186 |
187 | print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
188 |
189 | torch.save({'epoch': epoch,
190 | 'state_dict': model_restoration.state_dict(),
191 | 'optimizer' : optimizer.state_dict()
192 | }, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))
193 |
194 | scheduler.step()
195 |
196 | print("------------------------------------------------------------------")
197 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
198 | print("------------------------------------------------------------------")
199 |
200 | torch.save({'epoch': epoch,
201 | 'state_dict': model_restoration.state_dict(),
202 | 'optimizer' : optimizer.state_dict()
203 | }, os.path.join(model_dir,"model_latest.pth"))
204 |
205 | writer.close()
206 |
--------------------------------------------------------------------------------
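train.py supervises the network at three scales: kornia's build_pyramid(target_, 3) returns [full resolution, 1/2, 1/4], which line up with the three MIMO outputs, and each level gets a Charbonnier term plus an FFT term (weight 0.01) and an edge term (weight 0.05). The actual fftLoss lives in losses.py; a hedged sketch of a typical frequency-domain L1 formulation it likely resembles (the repo's version may differ in normalization or FFT variant):

    import torch
    import torch.nn as nn

    class FFTL1Loss(nn.Module):
        """Sketch of an L1 loss in the frequency domain (assumed form of losses.fftLoss)."""
        def forward(self, pred, target):
            pred_f = torch.fft.rfft2(pred)      # complex spectrum of the prediction
            target_f = torch.fft.rfft2(target)  # complex spectrum of the target
            return torch.mean(torch.abs(pred_f - target_f))
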
/train_wo_warmup.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 | os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
4 |
5 | import torch
6 | torch.backends.cudnn.benchmark = True
7 |
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | from torch.utils.data import DataLoader
11 |
12 | import random
13 | import time
14 | import numpy as np
15 |
16 | import utils
17 | from data_RGB import get_training_data, get_validation_data
18 | from DeepRFT_MIMO import DeepRFT as myNet
19 | import losses
20 | from tqdm import tqdm
21 | from get_parameter_number import get_parameter_number
22 | import kornia
23 | from torch.utils.tensorboard import SummaryWriter
24 | import argparse
25 |
26 | ######### Set Seeds ###########
27 | random.seed(1234)
28 | np.random.seed(1234)
29 | torch.manual_seed(1234)
30 | torch.cuda.manual_seed_all(1234)
31 |
32 | start_epoch = 1
33 |
34 | parser = argparse.ArgumentParser(description='Image Deblurring')
35 |
36 | parser.add_argument('--train_dir', default='./Datasets/GoPro/train', type=str, help='Directory of train images')
37 | parser.add_argument('--val_dir', default='./Datasets/GoPro/val', type=str, help='Directory of validation images')
38 | parser.add_argument('--model_save_dir', default='./checkpoints', type=str, help='Path to save weights')
39 | parser.add_argument('--pretrain_weights', default='./checkpoints/model_best.pth', type=str, help='Path to pretrain-weights')
40 | parser.add_argument('--mode', default='Deblurring', type=str)
41 | parser.add_argument('--session', default='DeepRFT_gopro', type=str, help='session')
42 | parser.add_argument('--patch_size', default=256, type=int, help='patch size, for paper: [GoPro, HIDE, RealBlur]=256, [DPDD]=512')
43 | parser.add_argument('--num_epochs', default=3000, type=int, help='num_epochs')
44 | parser.add_argument('--batch_size', default=16, type=int, help='batch_size')
45 | parser.add_argument('--val_epochs', default=20, type=int, help='val_epochs')
46 | args = parser.parse_args()
47 |
48 | mode = args.mode
49 | session = args.session
50 | patch_size = args.patch_size
51 |
52 | model_dir = os.path.join(args.model_save_dir, mode, 'models', session)
53 | utils.mkdir(model_dir)
54 |
55 | train_dir = args.train_dir
56 | val_dir = args.val_dir
57 |
58 | num_epochs = args.num_epochs
59 | batch_size = args.batch_size
60 | val_epochs = args.val_epochs
61 |
62 | start_lr = 2e-4
63 | end_lr = 1e-6
64 |
65 | ######### Model ###########
66 | model_restoration = myNet()
67 | # print the number of model parameters
68 | get_parameter_number(model_restoration)
69 |
70 | model_restoration.cuda()
71 |
72 | device_ids = [i for i in range(torch.cuda.device_count())]
73 | if torch.cuda.device_count() > 1:
74 | print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
75 |
76 | optimizer = optim.Adam(model_restoration.parameters(), lr=start_lr, betas=(0.9, 0.999), eps=1e-8)
77 |
78 | ######### Scheduler ###########
79 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min=end_lr)
80 |
81 | RESUME = False
82 | Pretrain = False
83 | model_pre_dir = ''
84 | ######### Pretrain ###########
85 | if Pretrain:
86 | utils.load_checkpoint(model_restoration, model_pre_dir)
87 |
88 | print('------------------------------------------------------------------------------')
89 | print("==> Retrain Training with: " + model_pre_dir)
90 | print('------------------------------------------------------------------------------')
91 |
92 | ######### Resume ###########
93 | if RESUME:
94 | path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
95 |     utils.load_checkpoint(model_restoration, path_chk_rest)
96 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1
97 | utils.load_optim(optimizer, path_chk_rest)
98 |
99 | for i in range(1, start_epoch):
100 | scheduler.step()
101 | new_lr = scheduler.get_lr()[0]
102 | print('------------------------------------------------------------------------------')
103 | print("==> Resuming Training with learning rate:", new_lr)
104 | print('------------------------------------------------------------------------------')
105 |
106 | if len(device_ids) > 1:
107 | model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)
108 |
109 | ######### Loss ###########
110 | criterion_char = losses.CharbonnierLoss()
111 | criterion_edge = losses.EdgeLoss()
112 | criterion_fft = losses.fftLoss()
113 | ######### DataLoaders ###########
114 | train_dataset = get_training_data(train_dir, {'patch_size':patch_size})
115 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=False, pin_memory=True)
116 |
117 | val_dataset = get_validation_data(val_dir, {'patch_size':patch_size})
118 | val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
119 |
120 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch, num_epochs))
121 | print('===> Datasets loaded')
122 |
123 | best_psnr = 0
124 | best_epoch = 0
125 | writer = SummaryWriter(model_dir)
126 | global_step = 0
127 | for epoch in range(start_epoch, num_epochs + 1):
128 | epoch_start_time = time.time()
129 | epoch_loss = 0
131 |
132 | model_restoration.train()
133 | for i, data in enumerate(tqdm(train_loader), 0):
134 |
135 | # zero_grad
136 | for param in model_restoration.parameters():
137 | param.grad = None
138 |
139 | target_ = data[0].cuda()
140 | input_ = data[1].cuda()
141 | target = kornia.geometry.transform.build_pyramid(target_, 3)
142 | restored = model_restoration(input_)
143 |
144 | loss_fft = criterion_fft(restored[0], target[0]) + criterion_fft(restored[1], target[1]) + criterion_fft(
145 | restored[2], target[2])
146 | loss_char = criterion_char(restored[0], target[0]) + criterion_char(restored[1], target[1]) + criterion_char(restored[2], target[2])
147 | loss_edge = criterion_edge(restored[0], target[0]) + criterion_edge(restored[1], target[1]) + criterion_edge(restored[2], target[2])
148 | loss = loss_char + 0.01 * loss_fft + 0.05 * loss_edge
149 | loss.backward()
150 | optimizer.step()
151 |         epoch_loss += loss.item()
152 |         global_step += 1
153 |         writer.add_scalar('loss/fft_loss', loss_fft, global_step)
154 |         writer.add_scalar('loss/char_loss', loss_char, global_step)
155 |         writer.add_scalar('loss/edge_loss', loss_edge, global_step)
156 |         writer.add_scalar('loss/iter_loss', loss, global_step)
157 |     writer.add_scalar('loss/epoch_loss', epoch_loss, epoch)
158 |
159 | #### Evaluation ####
160 | if epoch % val_epochs == 0:
161 | model_restoration.eval()
162 | psnr_val_rgb = []
163 | for ii, data_val in enumerate((val_loader), 0):
164 | target = data_val[0].cuda()
165 | input_ = data_val[1].cuda()
166 |
167 | with torch.no_grad():
168 | restored = model_restoration(input_)
169 |
170 |             for res, tar in zip(restored[0], target):
171 | psnr_val_rgb.append(utils.torchPSNR(res, tar))
172 |
173 | psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
174 | writer.add_scalar('val/psnr', psnr_val_rgb, epoch)
175 | if psnr_val_rgb > best_psnr:
176 | best_psnr = psnr_val_rgb
177 | best_epoch = epoch
178 | torch.save({'epoch': epoch,
179 | 'state_dict': model_restoration.state_dict(),
180 | 'optimizer' : optimizer.state_dict()
181 | }, os.path.join(model_dir,"model_best.pth"))
182 |
183 | print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
184 |
185 | torch.save({'epoch': epoch,
186 | 'state_dict': model_restoration.state_dict(),
187 | 'optimizer' : optimizer.state_dict()
188 | }, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))
189 |
190 | scheduler.step()
191 |
192 | print("------------------------------------------------------------------")
193 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
194 | print("------------------------------------------------------------------")
195 |
196 | torch.save({'epoch': epoch,
197 | 'state_dict': model_restoration.state_dict(),
198 | 'optimizer' : optimizer.state_dict()
199 | }, os.path.join(model_dir,"model_latest.pth"))
200 |
201 | writer.close()
202 |
--------------------------------------------------------------------------------
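train_wo_warmup.py is train.py minus the GradualWarmupScheduler wrapper: CosineAnnealingLR runs over all num_epochs from the first epoch. A worked sketch of the resulting schedule with the script's start_lr=2e-4 and end_lr=1e-6:

    import math

    start_lr, end_lr, T = 2e-4, 1e-6, 3000

    def cosine_lr(e):
        # CosineAnnealingLR: eta_min + (eta_max - eta_min) * (1 + cos(pi * e / T)) / 2
        return end_lr + (start_lr - end_lr) * (1 + math.cos(math.pi * e / T)) / 2

    print(cosine_lr(0), cosine_lr(T // 2), cosine_lr(T))  # 2e-4, ~1.0e-4, 1e-6
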
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dir_utils import *
2 | from .image_utils import *
3 | from .model_utils import *
4 | from .dataset_utils import *
5 |
--------------------------------------------------------------------------------
/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class MixUp_AUG:
4 | def __init__(self):
5 | self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))
6 |
7 | def aug(self, rgb_gt, rgb_noisy):
8 | bs = rgb_gt.size(0)
9 | indices = torch.randperm(bs)
10 | rgb_gt2 = rgb_gt[indices]
11 | rgb_noisy2 = rgb_noisy[indices]
12 |
13 |         lam = self.dist.rsample((bs, 1)).view(-1, 1, 1, 1).cuda()
14 |
15 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
16 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
17 |
18 | return rgb_gt, rgb_noisy
--------------------------------------------------------------------------------
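A minimal usage sketch for MixUp_AUG (hypothetical batch shapes; note that aug() moves the sampled mixing weights to CUDA, so a GPU is assumed):

    import torch

    aug = MixUp_AUG()
    rgb_gt = torch.rand(4, 3, 256, 256).cuda()     # sharp targets
    rgb_noisy = torch.rand(4, 3, 256, 256).cuda()  # blurred inputs
    # The same per-sample lam mixes both tensors, keeping (input, target) pairs aligned.
    gt_mix, noisy_mix = aug.aug(rgb_gt, rgb_noisy)
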
/utils/dir_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from natsort import natsorted
3 | from glob import glob
4 |
5 | def mkdirs(paths):
6 |     if isinstance(paths, list):
7 | for path in paths:
8 | mkdir(path)
9 | else:
10 | mkdir(paths)
11 |
12 | def mkdir(path):
13 |     os.makedirs(path, exist_ok=True)  # avoids the exists/makedirs race
15 |
16 | def get_last_path(path, session):
17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
18 | return x
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import cv2
4 |
5 | def torchPSNR(tar_img, prd_img):
6 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
7 | rmse = (imdff**2).mean().sqrt()
8 | ps = 20*torch.log10(1/rmse)
9 | return ps
10 |
11 | def save_img(filepath, img):
12 |     cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
13 |
14 | def numpyPSNR(tar_img, prd_img):
15 | imdff = np.float32(prd_img) - np.float32(tar_img)
16 | rmse = np.sqrt(np.mean(imdff**2))
17 | ps = 20*np.log10(255/rmse)
18 | return ps
19 |
--------------------------------------------------------------------------------
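torchPSNR assumes images in [0, 1] (peak value 1), while numpyPSNR assumes [0, 255]. A quick worked check of the torch variant: a uniform error of 0.1 gives RMSE 0.1, hence PSNR = 20*log10(1/0.1) = 20 dB.

    import torch

    a = torch.zeros(1, 3, 8, 8)
    b = torch.full_like(a, 0.1)
    print(torchPSNR(a, b))  # tensor(20.)
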
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from collections import OrderedDict
4 | import numpy as np
5 | def freeze(model):
6 | for p in model.parameters():
7 | p.requires_grad=False
8 |
9 | def unfreeze(model):
10 | for p in model.parameters():
11 | p.requires_grad=True
12 |
13 | def is_frozen(model):
14 | x = [p.requires_grad for p in model.parameters()]
15 | return not all(x)
16 |
17 | def save_checkpoint(model_dir, state, session):
18 | epoch = state['epoch']
19 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
20 | torch.save(state, model_out_path)
21 |
22 | def load_checkpoint(model, weights):
23 | checkpoint = torch.load(weights)
24 | # print(checkpoint)
25 | try:
26 | model.load_state_dict(checkpoint["state_dict"])
27 |     except Exception:
28 | state_dict = checkpoint["state_dict"]
29 | new_state_dict = OrderedDict()
30 | for k, v in state_dict.items():
31 | # print(k)
32 | name = k[7:] # remove `module.`
33 | new_state_dict[name] = v
34 |
35 | model.load_state_dict(new_state_dict)
36 |
37 |
38 | def load_checkpoint_compress_doconv(model, weights):
39 | checkpoint = torch.load(weights)
40 | # print(checkpoint)
46 | old_state_dict = checkpoint["state_dict"]
47 | state_dict = OrderedDict()
48 | for k, v in old_state_dict.items():
49 | # print(k)
50 | name = k
51 | if k[:7] == 'module.':
52 | name = k[7:] # remove `module.`
53 | state_dict[name] = v
55 | do_state_dict = OrderedDict()
56 | for k, v in state_dict.items():
57 | if k[-1] == 'W' and k[:-1] + 'D' in state_dict:
58 | k_D = k[:-1] + 'D'
59 | k_D_diag = k_D + '_diag'
60 | W = v
61 | D = state_dict[k_D]
62 | D_diag = state_dict[k_D_diag]
63 | D = D + D_diag
64 | # W = torch.reshape(W, (out_channels, in_channels, D_mul))
65 | out_channels, in_channels, MN = W.shape
66 | M = int(np.sqrt(MN))
67 | DoW_shape = (out_channels, in_channels, M, M)
68 | DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
69 | do_state_dict[k] = DoW
70 | elif k[-1] == 'D' or k[-6:] == 'D_diag':
71 | continue
72 | elif k[-1] == 'W':
73 | out_channels, in_channels, MN = v.shape
74 | M = int(np.sqrt(MN))
75 | W_shape = (out_channels, in_channels, M, M)
76 | do_state_dict[k] = torch.reshape(v, W_shape)
77 | else:
78 | do_state_dict[k] = v
79 | model.load_state_dict(do_state_dict)
80 | def load_checkpoint_hin(model, weights):
81 | checkpoint = torch.load(weights)
82 | # print(checkpoint)
83 | try:
84 | model.load_state_dict(checkpoint)
85 |     except Exception:
86 | state_dict = checkpoint
87 | new_state_dict = OrderedDict()
88 | for k, v in state_dict.items():
89 | name = k[7:] # remove `module.`
90 | new_state_dict[name] = v
91 | model.load_state_dict(new_state_dict)
92 | def load_checkpoint_multigpu(model, weights):
93 | checkpoint = torch.load(weights)
94 | state_dict = checkpoint["state_dict"]
95 | new_state_dict = OrderedDict()
96 | for k, v in state_dict.items():
97 | name = k[7:] # remove `module.`
98 | new_state_dict[name] = v
99 | model.load_state_dict(new_state_dict)
100 |
101 | def load_start_epoch(weights):
102 | checkpoint = torch.load(weights)
103 | epoch = checkpoint["epoch"]
104 | return epoch
105 |
106 | def load_optim(optimizer, weights):
107 | checkpoint = torch.load(weights)
108 | optimizer.load_state_dict(checkpoint['optimizer'])
109 | # for p in optimizer.param_groups: lr = p['lr']
110 | # return lr
111 |
--------------------------------------------------------------------------------
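load_checkpoint_compress_doconv folds a DO-Conv checkpoint into plain convolution weights at load time: for each (out_channel, in_channel) pair the einsum contracts the D_mul over-parameterization dimension of W against D + D_diag, leaving an ordinary M x M kernel that a standard conv can load. A shape-level sketch (using D_mul = M*M, the default DO-Conv setup in doconv_pytorch.py):

    import torch

    out_c, in_c, M, D_mul = 8, 4, 3, 9
    W = torch.randn(out_c, in_c, D_mul)  # over-parameterized weight (stored as '...W')
    D = torch.randn(in_c, M * M, D_mul)  # depthwise component ('...D' + '...D_diag' pre-summed)

    # Same contraction as in load_checkpoint_compress_doconv:
    DoW = torch.einsum('ims,ois->oim', D, W).reshape(out_c, in_c, M, M)
    print(DoW.shape)  # torch.Size([8, 4, 3, 3]) -- loadable by a standard nn.Conv2d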