├── DeepRFT_MIMO.py ├── README.md ├── data_RGB.py ├── dataset_RGB.py ├── doconv_pytorch.py ├── evaluate_GOPRO.m ├── evaluate_RealBlur.py ├── get_parameter_number.py ├── images ├── framework.png └── psnr_params_flops.png ├── layers.py ├── license.md ├── losses.py ├── pytorch-gradual-warmup-lr ├── setup.py └── warmup_scheduler │ ├── __init__.py │ ├── run.py │ └── scheduler.py ├── test.py ├── test_speed.py ├── train.py ├── train_wo_warmup.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── dataset_utils.cpython-38.pyc ├── dir_utils.cpython-38.pyc ├── image_utils.cpython-38.pyc └── model_utils.cpython-38.pyc ├── dataset_utils.py ├── dir_utils.py ├── image_utils.py └── model_utils.py /DeepRFT_MIMO.py: -------------------------------------------------------------------------------- 1 | from layers import * 2 | 3 | 4 | 5 | class EBlock(nn.Module): 6 | def __init__(self, out_channel, num_res=8, ResBlock=ResBlock): 7 | super(EBlock, self).__init__() 8 | 9 | layers = [ResBlock(out_channel) for _ in range(num_res)] 10 | 11 | self.layers = nn.Sequential(*layers) 12 | 13 | def forward(self, x): 14 | return self.layers(x) 15 | 16 | class DBlock(nn.Module): 17 | def __init__(self, channel, num_res=8, ResBlock=ResBlock): 18 | super(DBlock, self).__init__() 19 | 20 | layers = [ResBlock(channel) for _ in range(num_res)] 21 | self.layers = nn.Sequential(*layers) 22 | 23 | def forward(self, x): 24 | return self.layers(x) 25 | 26 | class AFF(nn.Module): 27 | def __init__(self, in_channel, out_channel, BasicConv=BasicConv): 28 | super(AFF, self).__init__() 29 | self.conv = nn.Sequential( 30 | BasicConv(in_channel, out_channel, kernel_size=1, stride=1, relu=True), 31 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 32 | ) 33 | 34 | def forward(self, x1, x2, x4): 35 | x = torch.cat([x1, x2, x4], dim=1) 36 | return self.conv(x) 37 | 38 | class SCM(nn.Module): 39 | def __init__(self, out_plane, BasicConv=BasicConv, inchannel=3): 40 | 
super(SCM, self).__init__() 41 | self.main = nn.Sequential( 42 | BasicConv(inchannel, out_plane//4, kernel_size=3, stride=1, relu=True), 43 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True), 44 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True), 45 | BasicConv(out_plane // 2, out_plane-inchannel, kernel_size=1, stride=1, relu=True) 46 | ) 47 | 48 | self.conv = BasicConv(out_plane, out_plane, kernel_size=1, stride=1, relu=False) 49 | 50 | def forward(self, x): 51 | x = torch.cat([x, self.main(x)], dim=1) 52 | return self.conv(x) 53 | 54 | class FAM(nn.Module): 55 | def __init__(self, channel, BasicConv=BasicConv): 56 | super(FAM, self).__init__() 57 | self.merge = BasicConv(channel, channel, kernel_size=3, stride=1, relu=False) 58 | 59 | def forward(self, x1, x2): 60 | x = x1 * x2 61 | out = x1 + self.merge(x) 62 | return out 63 | 64 | class DeepRFT_Small(nn.Module): 65 | def __init__(self, num_res=4, inference=False): 66 | super(DeepRFT_Small, self).__init__() 67 | self.inference = inference 68 | 69 | if not inference: 70 | BasicConv = BasicConv_do 71 | ResBlock = ResBlock_do_fft_bench 72 | else: 73 | BasicConv = BasicConv_do_eval 74 | ResBlock = ResBlock_do_fft_bench_eval 75 | 76 | base_channel = 32 77 | 78 | self.Encoder = nn.ModuleList([ 79 | EBlock(base_channel, num_res, ResBlock=ResBlock), 80 | EBlock(base_channel*2, num_res, ResBlock=ResBlock), 81 | EBlock(base_channel*4, num_res, ResBlock=ResBlock), 82 | ]) 83 | 84 | self.feat_extract = nn.ModuleList([ 85 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 86 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 87 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 88 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 89 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 90 | 
BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 91 | ]) 92 | 93 | self.Decoder = nn.ModuleList([ 94 | DBlock(base_channel * 4, num_res, ResBlock=ResBlock), 95 | DBlock(base_channel * 2, num_res, ResBlock=ResBlock), 96 | DBlock(base_channel, num_res, ResBlock=ResBlock) 97 | ]) 98 | 99 | self.Convs = nn.ModuleList([ 100 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 101 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 102 | ]) 103 | 104 | self.ConvsOut = nn.ModuleList( 105 | [ 106 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 107 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 108 | ] 109 | ) 110 | 111 | self.AFFs = nn.ModuleList([ 112 | AFF(base_channel * 7, base_channel*1, BasicConv=BasicConv), 113 | AFF(base_channel * 7, base_channel*2, BasicConv=BasicConv) 114 | ]) 115 | 116 | self.FAM1 = FAM(base_channel * 4, BasicConv=BasicConv) 117 | self.SCM1 = SCM(base_channel * 4, BasicConv=BasicConv) 118 | self.FAM2 = FAM(base_channel * 2, BasicConv=BasicConv) 119 | self.SCM2 = SCM(base_channel * 2, BasicConv=BasicConv) 120 | 121 | def forward(self, x): 122 | x_2 = F.interpolate(x, scale_factor=0.5) 123 | x_4 = F.interpolate(x_2, scale_factor=0.5) 124 | z2 = self.SCM2(x_2) 125 | z4 = self.SCM1(x_4) 126 | 127 | outputs = list() 128 | 129 | x_ = self.feat_extract[0](x) 130 | res1 = self.Encoder[0](x_) 131 | 132 | z = self.feat_extract[1](res1) 133 | z = self.FAM2(z, z2) 134 | res2 = self.Encoder[1](z) 135 | 136 | z = self.feat_extract[2](res2) 137 | z = self.FAM1(z, z4) 138 | z = self.Encoder[2](z) 139 | 140 | z12 = F.interpolate(res1, scale_factor=0.5) 141 | z21 = F.interpolate(res2, scale_factor=2) 142 | z42 = F.interpolate(z, scale_factor=2) 143 | z41 = F.interpolate(z42, scale_factor=2) 144 | 145 | res2 = self.AFFs[1](z12, res2, z42) 146 | res1 = self.AFFs[0](res1, z21, z41) 147 | 148 | z = self.Decoder[0](z) 149 | z_ = self.ConvsOut[0](z) 
150 | z = self.feat_extract[3](z) 151 | if not self.inference: 152 | outputs.append(z_+x_4) 153 | 154 | z = torch.cat([z, res2], dim=1) 155 | z = self.Convs[0](z) 156 | z = self.Decoder[1](z) 157 | z_ = self.ConvsOut[1](z) 158 | z = self.feat_extract[4](z) 159 | if not self.inference: 160 | outputs.append(z_+x_2) 161 | 162 | z = torch.cat([z, res1], dim=1) 163 | z = self.Convs[1](z) 164 | z = self.Decoder[2](z) 165 | z = self.feat_extract[5](z) 166 | if not self.inference: 167 | outputs.append(z + x) 168 | return outputs[::-1] 169 | else: 170 | return z + x 171 | class DeepRFT_flops(nn.Module): 172 | def __init__(self, num_res=8, inference=True): 173 | super(DeepRFT_flops, self).__init__() 174 | self.inference = inference 175 | ResBlock = ResBlock_fft_bench 176 | base_channel = 32 177 | 178 | self.Encoder = nn.ModuleList([ 179 | EBlock(base_channel, num_res, ResBlock=ResBlock), 180 | EBlock(base_channel*2, num_res, ResBlock=ResBlock), 181 | EBlock(base_channel*4, num_res, ResBlock=ResBlock), 182 | ]) 183 | 184 | self.feat_extract = nn.ModuleList([ 185 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 186 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 187 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 188 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 189 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 190 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 191 | ]) 192 | 193 | self.Decoder = nn.ModuleList([ 194 | DBlock(base_channel * 4, num_res, ResBlock=ResBlock), 195 | DBlock(base_channel * 2, num_res, ResBlock=ResBlock), 196 | DBlock(base_channel, num_res, ResBlock=ResBlock) 197 | ]) 198 | 199 | self.Convs = nn.ModuleList([ 200 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 201 | BasicConv(base_channel * 2, base_channel, 
kernel_size=1, relu=True, stride=1), 202 | ]) 203 | 204 | self.ConvsOut = nn.ModuleList( 205 | [ 206 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 207 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 208 | ] 209 | ) 210 | 211 | self.AFFs = nn.ModuleList([ 212 | AFF(base_channel * 7, base_channel*1, BasicConv=BasicConv), 213 | AFF(base_channel * 7, base_channel*2, BasicConv=BasicConv) 214 | ]) 215 | 216 | self.FAM1 = FAM(base_channel * 4, BasicConv=BasicConv) 217 | self.SCM1 = SCM(base_channel * 4, BasicConv=BasicConv) 218 | self.FAM2 = FAM(base_channel * 2, BasicConv=BasicConv) 219 | self.SCM2 = SCM(base_channel * 2, BasicConv=BasicConv) 220 | 221 | def forward(self, x): 222 | x_2 = F.interpolate(x, scale_factor=0.5) 223 | x_4 = F.interpolate(x_2, scale_factor=0.5) 224 | z2 = self.SCM2(x_2) 225 | z4 = self.SCM1(x_4) 226 | 227 | outputs = list() 228 | 229 | x_ = self.feat_extract[0](x) 230 | res1 = self.Encoder[0](x_) 231 | 232 | z = self.feat_extract[1](res1) 233 | z = self.FAM2(z, z2) 234 | res2 = self.Encoder[1](z) 235 | 236 | z = self.feat_extract[2](res2) 237 | z = self.FAM1(z, z4) 238 | z = self.Encoder[2](z) 239 | 240 | z12 = F.interpolate(res1, scale_factor=0.5) 241 | z21 = F.interpolate(res2, scale_factor=2) 242 | z42 = F.interpolate(z, scale_factor=2) 243 | z41 = F.interpolate(z42, scale_factor=2) 244 | 245 | res2 = self.AFFs[1](z12, res2, z42) 246 | res1 = self.AFFs[0](res1, z21, z41) 247 | 248 | z = self.Decoder[0](z) 249 | z_ = self.ConvsOut[0](z) 250 | z = self.feat_extract[3](z) 251 | if not self.inference: 252 | outputs.append(z_+x_4) 253 | 254 | z = torch.cat([z, res2], dim=1) 255 | z = self.Convs[0](z) 256 | z = self.Decoder[1](z) 257 | z_ = self.ConvsOut[1](z) 258 | z = self.feat_extract[4](z) 259 | if not self.inference: 260 | outputs.append(z_+x_2) 261 | 262 | z = torch.cat([z, res1], dim=1) 263 | z = self.Convs[1](z) 264 | z = self.Decoder[2](z) 265 | z = self.feat_extract[5](z) 266 | if not 
self.inference: 267 | outputs.append(z+x) 268 | # print(outputs) 269 | return outputs[::-1] 270 | else: 271 | return z+x 272 | class DeepRFT(nn.Module): 273 | def __init__(self, num_res=8, inference=False): 274 | super(DeepRFT, self).__init__() 275 | self.inference = inference 276 | if not inference: 277 | BasicConv = BasicConv_do 278 | ResBlock = ResBlock_do_fft_bench 279 | else: 280 | BasicConv = BasicConv_do_eval 281 | ResBlock = ResBlock_do_fft_bench_eval 282 | base_channel = 32 283 | 284 | self.Encoder = nn.ModuleList([ 285 | EBlock(base_channel, num_res, ResBlock=ResBlock), 286 | EBlock(base_channel*2, num_res, ResBlock=ResBlock), 287 | EBlock(base_channel*4, num_res, ResBlock=ResBlock), 288 | ]) 289 | 290 | self.feat_extract = nn.ModuleList([ 291 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 292 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 293 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 294 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 295 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 296 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 297 | ]) 298 | 299 | self.Decoder = nn.ModuleList([ 300 | DBlock(base_channel * 4, num_res, ResBlock=ResBlock), 301 | DBlock(base_channel * 2, num_res, ResBlock=ResBlock), 302 | DBlock(base_channel, num_res, ResBlock=ResBlock) 303 | ]) 304 | 305 | self.Convs = nn.ModuleList([ 306 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 307 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 308 | ]) 309 | 310 | self.ConvsOut = nn.ModuleList( 311 | [ 312 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 313 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 314 | ] 315 | ) 316 | 317 | self.AFFs = nn.ModuleList([ 318 | 
AFF(base_channel * 7, base_channel*1, BasicConv=BasicConv), 319 | AFF(base_channel * 7, base_channel*2, BasicConv=BasicConv) 320 | ]) 321 | 322 | self.FAM1 = FAM(base_channel * 4, BasicConv=BasicConv) 323 | self.SCM1 = SCM(base_channel * 4, BasicConv=BasicConv) 324 | self.FAM2 = FAM(base_channel * 2, BasicConv=BasicConv) 325 | self.SCM2 = SCM(base_channel * 2, BasicConv=BasicConv) 326 | 327 | def forward(self, x): 328 | x_2 = F.interpolate(x, scale_factor=0.5) 329 | x_4 = F.interpolate(x_2, scale_factor=0.5) 330 | z2 = self.SCM2(x_2) 331 | z4 = self.SCM1(x_4) 332 | 333 | outputs = list() 334 | 335 | x_ = self.feat_extract[0](x) 336 | res1 = self.Encoder[0](x_) 337 | 338 | z = self.feat_extract[1](res1) 339 | z = self.FAM2(z, z2) 340 | res2 = self.Encoder[1](z) 341 | 342 | z = self.feat_extract[2](res2) 343 | z = self.FAM1(z, z4) 344 | z = self.Encoder[2](z) 345 | 346 | z12 = F.interpolate(res1, scale_factor=0.5) 347 | z21 = F.interpolate(res2, scale_factor=2) 348 | z42 = F.interpolate(z, scale_factor=2) 349 | z41 = F.interpolate(z42, scale_factor=2) 350 | 351 | res2 = self.AFFs[1](z12, res2, z42) 352 | res1 = self.AFFs[0](res1, z21, z41) 353 | 354 | z = self.Decoder[0](z) 355 | z_ = self.ConvsOut[0](z) 356 | z = self.feat_extract[3](z) 357 | if not self.inference: 358 | outputs.append(z_+x_4) 359 | 360 | z = torch.cat([z, res2], dim=1) 361 | z = self.Convs[0](z) 362 | z = self.Decoder[1](z) 363 | z_ = self.ConvsOut[1](z) 364 | z = self.feat_extract[4](z) 365 | if not self.inference: 366 | outputs.append(z_+x_2) 367 | 368 | z = torch.cat([z, res1], dim=1) 369 | z = self.Convs[1](z) 370 | z = self.Decoder[2](z) 371 | z = self.feat_extract[5](z) 372 | if not self.inference: 373 | outputs.append(z+x) 374 | # print(outputs) 375 | return outputs[::-1] 376 | else: 377 | return z+x 378 | class DeepRFTPLUS(nn.Module): 379 | def __init__(self, num_res=20, inference=False): 380 | super(DeepRFTPLUS, self).__init__() 381 | # ResBlock = ResBlock_fft_bench 382 | self.inference = 
inference 383 | if not inference: 384 | BasicConv = BasicConv_do 385 | ResBlock = ResBlock_do_fft_bench 386 | else: 387 | BasicConv = BasicConv_do_eval 388 | ResBlock = ResBlock_do_fft_bench_eval 389 | base_channel = 32 390 | 391 | self.Encoder = nn.ModuleList([ 392 | EBlock(base_channel, num_res, ResBlock=ResBlock), 393 | EBlock(base_channel*2, num_res, ResBlock=ResBlock), 394 | EBlock(base_channel*4, num_res, ResBlock=ResBlock), 395 | ]) 396 | 397 | self.feat_extract = nn.ModuleList([ 398 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 399 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 400 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 401 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 402 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 403 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 404 | ]) 405 | 406 | self.Decoder = nn.ModuleList([ 407 | DBlock(base_channel * 4, num_res, ResBlock=ResBlock), 408 | DBlock(base_channel * 2, num_res, ResBlock=ResBlock), 409 | DBlock(base_channel, num_res, ResBlock=ResBlock) 410 | ]) 411 | 412 | self.Convs = nn.ModuleList([ 413 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 414 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 415 | ]) 416 | 417 | self.ConvsOut = nn.ModuleList( 418 | [ 419 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 420 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 421 | ] 422 | ) 423 | 424 | self.AFFs = nn.ModuleList([ 425 | AFF(base_channel * 7, base_channel*1, BasicConv=BasicConv), 426 | AFF(base_channel * 7, base_channel*2, BasicConv=BasicConv) 427 | ]) 428 | 429 | self.FAM1 = FAM(base_channel * 4, BasicConv=BasicConv) 430 | self.SCM1 = SCM(base_channel * 4, BasicConv=BasicConv) 431 | 
self.FAM2 = FAM(base_channel * 2, BasicConv=BasicConv) 432 | self.SCM2 = SCM(base_channel * 2, BasicConv=BasicConv) 433 | 434 | def forward(self, x): 435 | x_2 = F.interpolate(x, scale_factor=0.5) 436 | x_4 = F.interpolate(x_2, scale_factor=0.5) 437 | z2 = self.SCM2(x_2) 438 | z4 = self.SCM1(x_4) 439 | 440 | outputs = list() 441 | 442 | x_ = self.feat_extract[0](x) 443 | res1 = self.Encoder[0](x_) 444 | 445 | z = self.feat_extract[1](res1) 446 | z = self.FAM2(z, z2) 447 | res2 = self.Encoder[1](z) 448 | 449 | z = self.feat_extract[2](res2) 450 | z = self.FAM1(z, z4) 451 | z = self.Encoder[2](z) 452 | 453 | z12 = F.interpolate(res1, scale_factor=0.5) 454 | z21 = F.interpolate(res2, scale_factor=2) 455 | z42 = F.interpolate(z, scale_factor=2) 456 | z41 = F.interpolate(z42, scale_factor=2) 457 | 458 | res2 = self.AFFs[1](z12, res2, z42) 459 | res1 = self.AFFs[0](res1, z21, z41) 460 | 461 | z = self.Decoder[0](z) 462 | z_ = self.ConvsOut[0](z) 463 | z = self.feat_extract[3](z) 464 | if not self.inference: 465 | outputs.append(z_+x_4) 466 | 467 | z = torch.cat([z, res2], dim=1) 468 | z = self.Convs[0](z) 469 | z = self.Decoder[1](z) 470 | z_ = self.ConvsOut[1](z) 471 | z = self.feat_extract[4](z) 472 | if not self.inference: 473 | outputs.append(z_+x_2) 474 | 475 | z = torch.cat([z, res1], dim=1) 476 | z = self.Convs[1](z) 477 | z = self.Decoder[2](z) 478 | z = self.feat_extract[5](z) 479 | if not self.inference: 480 | outputs.append(z+x) 481 | # print(outputs) 482 | return outputs[::-1] 483 | else: 484 | return z+x 485 | 486 | 487 | 488 | 489 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-residual-fourier-transformation-for/deblurring-on-gopro)](https://paperswithcode.com/sota/deblurring-on-gopro?p=deep-residual-fourier-transformation-for) 2 | 
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-residual-fourier-transformation-for/deblurring-on-hide-trained-on-gopro)](https://paperswithcode.com/sota/deblurring-on-hide-trained-on-gopro?p=deep-residual-fourier-transformation-for) 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-residual-fourier-transformation-for/deblurring-on-realblur-j-1)](https://paperswithcode.com/sota/deblurring-on-realblur-j-1?p=deep-residual-fourier-transformation-for) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-residual-fourier-transformation-for/deblurring-on-realblur-j-trained-on-gopro)](https://paperswithcode.com/sota/deblurring-on-realblur-j-trained-on-gopro?p=deep-residual-fourier-transformation-for) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-residual-fourier-transformation-for/deblurring-on-realblur-r)](https://paperswithcode.com/sota/deblurring-on-realblur-r?p=deep-residual-fourier-transformation-for) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-residual-fourier-transformation-for/deblurring-on-realblur-r-trained-on-gopro)](https://paperswithcode.com/sota/deblurring-on-realblur-r-trained-on-gopro?p=deep-residual-fourier-transformation-for) 7 | 8 | 9 | # Intriguing Findings of Frequency Selection for Image Deblurring (AAAI 2023) 10 | Xintian Mao, Yiming Liu, Fengze Liu, Qingli Li, Wei Shen and Yan Wang 11 | 12 | **Paper**: xxx 13 | 14 | **code**: https://github.com/INVOKERer/DeepRFT/tree/AAAI2023 15 | 16 | # Deep Residual Fourier Transformation for Single Image Deblurring 17 | Xintian Mao, Yiming Liu, Wei Shen, Qingli Li and Yan Wang 18 | 19 | 20 | **Paper**: https://arxiv.org/abs/2111.11745 21 | 22 | 23 | ## Network Architecture 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 |

Overall Framework of DeepRFT

32 | 33 | ## Installation 34 | The model is built in PyTorch 1.8.0 and tested on Ubuntu 18.04 environment (Python3.8, CUDA11.1). 35 | 36 | To install, follow these instructions: 37 | ``` 38 | conda create -n pytorch python=3.8 39 | conda activate pytorch 40 | conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge 41 | pip install matplotlib scikit-image opencv-python yacs joblib natsort h5py tqdm kornia tensorboard ptflops 42 | ``` 43 | 44 | Install the warmup scheduler: 45 | 46 | ``` 47 | cd pytorch-gradual-warmup-lr; python setup.py install; cd .. 48 | ``` 49 | 50 | ## Quick Run 51 | 52 | To test the pre-trained models of Deblur and Defocus [Google Drive](https://drive.google.com/file/d/1FoQZrbcYPGzU9xzOPI1Q1NybNUGR-ZUg/view?usp=sharing) or [百度网盘](https://pan.baidu.com/s/10DuQZiXC-Dc6jtLc9YJGbg)(提取码:phws) on your own images, run 53 | ``` 54 | python test.py --weights ckpt_path_here --input_dir path_to_images --result_dir save_images_here --win_size 256 --num_res 8 [4:small, 20:plus]# deblur 55 | python test.py --weights ckpt_path_here --input_dir path_to_images --result_dir save_images_here --win_size 512 --num_res 8 # defocus 56 | ``` 57 | Here is an example to train: 58 | ``` 59 | python train.py 60 | ``` 61 | 62 | 63 | ## Results 64 | Experiment for image deblurring. 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 |

Deblurring on GoPro Datasets.

73 | 74 | ## Reference Code: 75 | - https://github.com/yangyanli/DO-Conv 76 | - https://github.com/swz30/MPRNet 77 | - https://github.com/chosj95/MIMO-UNet 78 | - https://github.com/codeslake/IFAN 79 | 80 | ## Citation 81 | If you use DeepRFT, please consider citing: 82 | ``` 83 | @inproceedings{xint2023freqsel, 84 | title = {Intriguing Findings of Frequency Selection for Image Deblurring}, 85 | author = {Xintian Mao, Yiming Liu, Fengze Liu, Qingli Li, Wei Shen and Yan Wang}, 86 | booktitle = {Proceedings of the 37th AAAI Conference on Artificial Intelligence}, 87 | year = {2023} 88 | } 89 | or 90 | @inproceedings{, 91 | title={Deep Residual Fourier Transformation for Single Image Deblurring}, 92 | author={Xintian Mao, Yiming Liu, Wei Shen, Qingli Li, Yan Wang}, 93 | booktitle={arXiv:2111.11745}, 94 | year={2021} 95 | } 96 | ``` 97 | ## Contact 98 | If you have any question, please contact mxt_invoker1997@163.com 99 | -------------------------------------------------------------------------------- /data_RGB.py: -------------------------------------------------------------------------------- 1 | from dataset_RGB import * 2 | 3 | 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options) 7 | 8 | def get_validation_data(rgb_dir, img_options): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, img_options) 11 | 12 | def get_test_data(rgb_dir, img_options): 13 | assert os.path.exists(rgb_dir) 14 | return DataLoaderTest(rgb_dir, img_options) 15 | -------------------------------------------------------------------------------- /dataset_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | import torch 5 | from PIL import Image 6 | import torchvision.transforms.functional as TF 7 | import random 8 | 9 | 10 | def is_image_file(filename): 11 | return 
any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 12 | 13 | class DataLoaderTrain(Dataset): 14 | def __init__(self, rgb_dir, img_options=None): 15 | super(DataLoaderTrain, self).__init__() 16 | 17 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'blur'))) 18 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'sharp'))) 19 | 20 | self.inp_filenames = [os.path.join(rgb_dir, 'blur', x) for x in inp_files if is_image_file(x)] 21 | self.tar_filenames = [os.path.join(rgb_dir, 'sharp', x) for x in tar_files if is_image_file(x)] 22 | 23 | self.img_options = img_options 24 | self.sizex = len(self.tar_filenames) # get the size of target 25 | 26 | self.ps = self.img_options['patch_size'] 27 | 28 | def __len__(self): 29 | return self.sizex 30 | 31 | def __getitem__(self, index): 32 | index_ = index % self.sizex 33 | ps = self.ps 34 | 35 | inp_path = self.inp_filenames[index_] 36 | tar_path = self.tar_filenames[index_] 37 | 38 | inp_img = Image.open(inp_path) 39 | tar_img = Image.open(tar_path) 40 | 41 | w,h = tar_img.size 42 | padw = ps-w if w 1: 61 | self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul)) 62 | init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32) 63 | self.D.data = torch.from_numpy(init_zero) 64 | 65 | eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N)) 66 | D_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N))) 67 | if self.D_mul % (M * N) != 0: # the cases when D_mul > M * N 68 | zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)]) 69 | self.D_diag = Parameter(torch.cat([D_diag, zeros], dim=2), requires_grad=False) 70 | else: # the case when D_mul = M * N 71 | self.D_diag = Parameter(D_diag, requires_grad=False) 72 | ################################################################################################## 73 | if simam: 74 | self.simam_block = simam_module() 75 | if bias: 76 | self.bias = 
Parameter(torch.Tensor(out_channels)) 77 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.W) 78 | bound = 1 / math.sqrt(fan_in) 79 | init.uniform_(self.bias, -bound, bound) 80 | else: 81 | self.register_parameter('bias', None) 82 | 83 | def extra_repr(self): 84 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 85 | ', stride={stride}') 86 | if self.padding != (0,) * len(self.padding): 87 | s += ', padding={padding}' 88 | if self.dilation != (1,) * len(self.dilation): 89 | s += ', dilation={dilation}' 90 | if self.groups != 1: 91 | s += ', groups={groups}' 92 | if self.bias is None: 93 | s += ', bias=False' 94 | if self.padding_mode != 'zeros': 95 | s += ', padding_mode={padding_mode}' 96 | return s.format(**self.__dict__) 97 | 98 | def __setstate__(self, state): 99 | super(DOConv2d, self).__setstate__(state) 100 | if not hasattr(self, 'padding_mode'): 101 | self.padding_mode = 'zeros' 102 | 103 | def _conv_forward(self, input, weight): 104 | if self.padding_mode != 'zeros': 105 | return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), 106 | weight, self.bias, self.stride, 107 | (0, 0), self.dilation, self.groups) 108 | return F.conv2d(input, weight, self.bias, self.stride, 109 | self.padding, self.dilation, self.groups) 110 | 111 | def forward(self, input): 112 | M = self.kernel_size[0] 113 | N = self.kernel_size[1] 114 | DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N) 115 | if M * N > 1: 116 | ######################### Compute DoW ################# 117 | # (input_channels, D_mul, M * N) 118 | D = self.D + self.D_diag 119 | W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul)) 120 | 121 | # einsum outputs (out_channels // groups, in_channels, M * N), 122 | # which is reshaped to 123 | # (out_channels, in_channels // groups, M, N) 124 | DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape) 125 | 
####################################################### 126 | else: 127 | DoW = torch.reshape(self.W, DoW_shape) 128 | if self.simam: 129 | DoW_h1, DoW_h2 = torch.chunk(DoW, 2, dim=2) 130 | DoW = torch.cat([self.simam_block(DoW_h1), DoW_h2], dim=2) 131 | 132 | return self._conv_forward(input, DoW) 133 | class DOConv2d_eval(Module): 134 | """ 135 | DOConv2d can be used as an alternative for torch.nn.Conv2d. 136 | The interface is similar to that of Conv2d, with one exception: 137 | 1. D_mul: the depth multiplier for the over-parameterization. 138 | Note that the groups parameter switchs between DO-Conv (groups=1), 139 | DO-DConv (groups=in_channels), DO-GConv (otherwise). 140 | """ 141 | __constants__ = ['stride', 'padding', 'dilation', 'groups', 142 | 'padding_mode', 'output_padding', 'in_channels', 143 | 'out_channels', 'kernel_size', 'D_mul'] 144 | __annotations__ = {'bias': Optional[torch.Tensor]} 145 | 146 | def __init__(self, in_channels, out_channels, kernel_size=3, D_mul=None, stride=1, 147 | padding=1, dilation=1, groups=1, bias=False, padding_mode='zeros', simam=False): 148 | super(DOConv2d_eval, self).__init__() 149 | 150 | kernel_size = (kernel_size, kernel_size) 151 | stride = (stride, stride) 152 | padding = (padding, padding) 153 | dilation = (dilation, dilation) 154 | 155 | if in_channels % groups != 0: 156 | raise ValueError('in_channels must be divisible by groups') 157 | if out_channels % groups != 0: 158 | raise ValueError('out_channels must be divisible by groups') 159 | valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'} 160 | if padding_mode not in valid_padding_modes: 161 | raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format( 162 | valid_padding_modes, padding_mode)) 163 | self.in_channels = in_channels 164 | self.out_channels = out_channels 165 | self.kernel_size = kernel_size 166 | self.stride = stride 167 | self.padding = padding 168 | self.dilation = dilation 169 | self.groups = groups 170 | 
self.padding_mode = padding_mode 171 | self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2)) 172 | self.simam = simam 173 | #################################### Initailization of D & W ################################### 174 | M = self.kernel_size[0] 175 | N = self.kernel_size[1] 176 | self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, M, N)) 177 | init.kaiming_uniform_(self.W, a=math.sqrt(5)) 178 | 179 | self.register_parameter('bias', None) 180 | def extra_repr(self): 181 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 182 | ', stride={stride}') 183 | if self.padding != (0,) * len(self.padding): 184 | s += ', padding={padding}' 185 | if self.dilation != (1,) * len(self.dilation): 186 | s += ', dilation={dilation}' 187 | if self.groups != 1: 188 | s += ', groups={groups}' 189 | if self.bias is None: 190 | s += ', bias=False' 191 | if self.padding_mode != 'zeros': 192 | s += ', padding_mode={padding_mode}' 193 | return s.format(**self.__dict__) 194 | 195 | def __setstate__(self, state): 196 | super(DOConv2d, self).__setstate__(state) 197 | if not hasattr(self, 'padding_mode'): 198 | self.padding_mode = 'zeros' 199 | 200 | def _conv_forward(self, input, weight): 201 | if self.padding_mode != 'zeros': 202 | return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode), 203 | weight, self.bias, self.stride, 204 | (0, 0), self.dilation, self.groups) 205 | return F.conv2d(input, weight, self.bias, self.stride, 206 | self.padding, self.dilation, self.groups) 207 | 208 | def forward(self, input): 209 | return self._conv_forward(input, self.W) 210 | 211 | class simam_module(torch.nn.Module): 212 | def __init__(self, e_lambda=1e-4): 213 | super(simam_module, self).__init__() 214 | self.activaton = nn.Sigmoid() 215 | self.e_lambda = e_lambda 216 | 217 | def forward(self, x): 218 | b, c, h, w = x.size() 219 | n = w * h - 1 220 | x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2) 
221 | y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5 222 | return x * self.activaton(y) 223 | 224 | 225 | -------------------------------------------------------------------------------- /evaluate_GOPRO.m: -------------------------------------------------------------------------------- 1 | 2 | close all;clear all; 3 | 4 | datasets = {'GoPro'}; 5 | % datasets = {'GoPro', 'HIDE'}; 6 | num_set = length(datasets); 7 | 8 | for idx_set = 1:num_set 9 | file_path = strcat('./results/DeepRFT/GoPro/'); 10 | gt_path = strcat('./Datasets/GoPro/test/sharp/'); 11 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))]; 12 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))]; 13 | img_num = length(path_list); 14 | 15 | total_psnr = 0; 16 | total_ssim = 0; 17 | if img_num > 0 18 | for j = 1:img_num 19 | image_name = path_list(j).name; 20 | gt_name = gt_list(j).name; 21 | input = imread(strcat(file_path,image_name)); 22 | gt = imread(strcat(gt_path, gt_name)); 23 | ssim_val = ssim(input, gt); 24 | psnr_val = psnr(input, gt); 25 | total_ssim = total_ssim + ssim_val; 26 | total_psnr = total_psnr + psnr_val; 27 | end 28 | end 29 | qm_psnr = total_psnr / img_num; 30 | qm_ssim = total_ssim / img_num; 31 | 32 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim); 33 | 34 | end 35 | -------------------------------------------------------------------------------- /evaluate_RealBlur.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from glob import glob 4 | from natsort import natsorted 5 | from skimage import io 6 | import cv2 7 | from skimage.metrics import structural_similarity 8 | from tqdm import tqdm 9 | import concurrent.futures 10 | 11 | def image_align(deblurred, gt): 12 | # this function is based on kohler evaluation code 13 | z = deblurred 14 | c = np.ones_like(z) 15 | x = 
def compute_psnr(image_true, image_test, image_mask, data_range=None):
    """Return the peak signal-to-noise ratio (dB) over the masked region.

    Based on ``skimage.metrics.peak_signal_noise_ratio``, except that the
    summed squared error is divided by ``np.sum(image_mask)`` rather than by
    the total number of elements.

    Args:
        image_true: Reference image (array-like).
        image_test: Image to score, same shape as ``image_true``.
        image_mask: Array whose sum is used as the number of valid elements.
        data_range: Distance between the minimum and maximum possible pixel
            values. When ``None`` it is inferred from ``image_true``'s dtype
            the way scikit-image does: the dtype maximum for non-negative
            data, the full dtype span otherwise (floats are treated as
            living in ``[-1, 1]``).

    Returns:
        PSNR in decibels as a NumPy float.
    """
    image_true = np.asarray(image_true)
    image_test = np.asarray(image_test)

    if data_range is None:
        # The previous default fell straight through to ``None ** 2`` and
        # raised TypeError; infer a sensible range instead.
        if image_true.dtype.kind in 'iu':
            info = np.iinfo(image_true.dtype)
            dmin, dmax = info.min, info.max
        else:
            dmin, dmax = -1.0, 1.0
        data_range = dmax if image_true.min() >= 0 else dmax - dmin

    # Integer (notably uint8) subtraction wraps around instead of going
    # negative, so promote before taking the difference. Float inputs are
    # left untouched to keep their results unchanged.
    if image_true.dtype.kind in 'iub' or image_test.dtype.kind in 'iub':
        image_true = image_true.astype(np.float64)
        image_test = image_test.astype(np.float64)

    err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask)
    return 10 * np.log10((data_range ** 2) / err)
datasets = ['RealBlur_J', 'RealBlur_R']
# To evaluate a single subset, use one of these instead:
# datasets = ['RealBlur_J']
# datasets = ['RealBlur_R']

# ProcessPoolExecutor workers may re-import this module (spawn start method),
# so the evaluation loop must only run when the file is executed as a script.
if __name__ == '__main__':
    for dataset in datasets:
        file_path = os.path.join('./results/DeepRFT', dataset)
        gt_path = os.path.join('./Datasets/RealBlur', dataset, 'test/sharp')

        path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg')))
        gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg')))

        # Explicit exceptions instead of ``assert`` so the checks survive ``python -O``.
        if not path_list:
            raise FileNotFoundError("Predicted files not found")
        if not gt_list:
            raise FileNotFoundError("Target files not found")
        # Files are paired purely by sorted position; ``zip`` would silently
        # drop the surplus of the longer list and score the wrong pairs.
        if len(path_list) != len(gt_list):
            raise ValueError(
                'Found {:d} predicted files but {:d} target files for {:s}'.format(
                    len(path_list), len(gt_list), dataset))

        # Each work item is (target path, prediction path), the order ``proc`` unpacks.
        img_files = list(zip(gt_list, path_list))

        psnr, ssim = [], []
        with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor:
            for psnr_val, ssim_val in executor.map(proc, img_files):
                psnr.append(psnr_val)
                ssim.append(ssim_val)

        avg_psnr = sum(psnr)/len(psnr)
        avg_ssim = sum(ssim)/len(ssim)

        print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim))
class BasicConv(nn.Module):
    """Convolution (or transposed convolution) with optional norm and ReLU.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Square kernel size; padding is ``kernel_size // 2``
            (``kernel_size // 2 - 1`` for the transposed variant).
        stride: Convolution stride.
        bias: Add a bias term; forced to ``False`` when ``norm`` is set.
        norm: Append ``norm_method(out_channel)`` after the convolution.
        relu: Append an in-place ReLU as the last layer.
        transpose: Use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
        channel_shuffle_g: Stored on the module; not used inside this class.
        norm_method: Factory called with ``out_channel`` to build the norm layer.
        groups: Passed through to the convolution.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, bias=False, norm=False, relu=True, transpose=False,
                 channel_shuffle_g=0, norm_method=nn.BatchNorm2d, groups=1):
        super(BasicConv, self).__init__()
        self.channel_shuffle_g = channel_shuffle_g
        self.norm = norm
        if bias and norm:
            bias = False

        padding = kernel_size // 2
        layers = list()
        if transpose:
            padding = kernel_size // 2 - 1
            layers.append(
                nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
        else:
            layers.append(
                nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias, groups=groups))
        if norm:
            layers.append(norm_method(out_channel))
        # Independent of ``norm`` (previously ``elif``, which silently dropped
        # the activation whenever a norm layer was requested). This matches
        # BasicConv_do / BasicConv_do_eval.
        if relu:
            layers.append(nn.ReLU(inplace=True))

        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.main(x)
class ResBlock(nn.Module):
    """Residual block: two 3x3 ``BasicConv`` layers plus an identity skip.

    Args:
        out_channel: Channel count used for the input and output of both layers.
    """

    def __init__(self, out_channel):
        super().__init__()
        conv_act = BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=True, norm=False)
        conv_lin = BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False, norm=False)
        self.main = nn.Sequential(conv_act, conv_lin)

    def forward(self, x):
        """Return ``self.main(x) + x``."""
        residual = self.main(x)
        return residual + x


class ResBlock_do(nn.Module):
    """Residual block built from two 3x3 ``BasicConv_do`` layers.

    Args:
        out_channel: Channel count used for the input and output of both layers.
    """

    def __init__(self, out_channel):
        super().__init__()
        conv_act = BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=True)
        conv_lin = BasicConv_do(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        self.main = nn.Sequential(conv_act, conv_lin)

    def forward(self, x):
        """Return ``self.main(x) + x``."""
        residual = self.main(x)
        return residual + x


class ResBlock_do_eval(nn.Module):
    """Residual block built from two 3x3 ``BasicConv_do_eval`` layers.

    Args:
        out_channel: Channel count used for the input and output of both layers.
    """

    def __init__(self, out_channel):
        super().__init__()
        conv_act = BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=True)
        conv_lin = BasicConv_do_eval(out_channel, out_channel, kernel_size=3, stride=1, relu=False)
        self.main = nn.Sequential(conv_act, conv_lin)

    def forward(self, x):
        """Return ``self.main(x) + x``."""
        residual = self.main(x)
        return residual + x
class ResBlock_fft_bench(nn.Module):
    """Residual block with a spatial branch and a frequency-domain branch.

    The spatial branch is two 3x3 ``BasicConv`` layers. The frequency branch
    takes the real 2-D FFT of the input, stacks real and imaginary parts
    along the channel axis, runs two 1x1 ``BasicConv`` layers over them and
    transforms the result back. Both branch outputs are added to the input.

    Args:
        n_feat: Number of input/output channels.
        norm: Normalisation mode handed to ``torch.fft.rfft2`` / ``irfft2``.
    """

    def __init__(self, n_feat, norm='backward'): # 'ortho'
        super().__init__()
        spatial_layers = [
            BasicConv(n_feat, n_feat, kernel_size=3, stride=1, relu=True),
            BasicConv(n_feat, n_feat, kernel_size=3, stride=1, relu=False),
        ]
        self.main = nn.Sequential(*spatial_layers)
        # Twice the channels: real and imaginary parts are concatenated.
        spectral_layers = [
            BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=True),
            BasicConv(n_feat*2, n_feat*2, kernel_size=1, stride=1, relu=False),
        ]
        self.main_fft = nn.Sequential(*spectral_layers)
        self.dim = n_feat
        self.norm = norm

    def forward(self, x):
        """Return ``self.main(x) + x + frequency_branch(x)``."""
        height, width = x.shape[2], x.shape[3]
        channel_axis = 1

        spectrum = torch.fft.rfft2(x, norm=self.norm)
        imag_part = spectrum.imag
        real_part = spectrum.real
        stacked = torch.cat([real_part, imag_part], dim=channel_axis)

        mixed = self.main_fft(stacked)
        real_part, imag_part = torch.chunk(mixed, 2, dim=channel_axis)
        spectrum = torch.complex(real_part, imag_part)
        # Passing the original size restores odd widths that rfft2 halves.
        freq_out = torch.fft.irfft2(spectrum, s=(height, width), norm=self.norm)

        return self.main(x) + x + freq_out
def window_partitions(x, window_size):
    """
    Cut a batch of feature maps into non-overlapping square windows.

    Args:
        x: (B, C, H, W); H and W are divided with ``//``, so they are
            expected to be multiples of window_size
        window_size (int): window size

    Returns:
        windows: (num_windows*B, C, window_size, window_size)
    """
    B, C, H, W = x.shape
    x = x.view(B, C, H // window_size, window_size, W // window_size, window_size)
    windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size, window_size)
    return windows


def window_reverses(windows, window_size, H, W):
    """
    Inverse of ``window_partitions``: stitch windows back into feature maps.

    Args:
        windows: (num_windows*B, C, window_size, window_size)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, C, H, W)
    """
    C = windows.shape[1]
    x = windows.view(-1, H // window_size, W // window_size, C, window_size, window_size)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W)
    return x


def window_partitionx(x, window_size):
    """
    Partition feature maps of arbitrary size into fixed-size windows.

    The largest top-left region whose sides are multiples of window_size is
    tiled normally. A leftover strip on the right and/or bottom is covered
    by extra windows anchored to the right/bottom edge (overlapping the
    main region), and a leftover in both directions adds one corner window
    per sample.

    Args:
        x: (B, C, H, W)
        window_size (int): window size

    Returns:
        windows: every group concatenated along dim 0, in the order
            main, right strip, bottom strip, corner
        batch_list: cumulative end index of each group inside ``windows``;
            pass it unchanged to ``window_reversex``
    """
    _, _, H, W = x.shape
    h, w = window_size * (H // window_size), window_size * (W // window_size)
    x_main = window_partitions(x[:, :, :h, :w], window_size)
    if h == H and w == W:
        return x_main, [x_main.shape[0]]

    groups = [x_main]
    if w != W:
        groups.append(window_partitions(x[:, :, :h, -window_size:], window_size))
    if h != H:
        groups.append(window_partitions(x[:, :, -window_size:, :w], window_size))
    if h != H and w != W:
        groups.append(x[:, :, -window_size:, -window_size:])

    batch_list = []
    end = 0
    for group in groups:
        end += group.shape[0]
        batch_list.append(end)
    return torch.cat(groups, dim=0), batch_list


def window_reversex(windows, window_size, H, W, batch_list):
    """
    Inverse of ``window_partitionx``: rebuild (B, C, H, W) feature maps.

    Args:
        windows: windows in the order produced by ``window_partitionx``
        window_size (int): window size
        H (int): Height of image
        W (int): Width of image
        batch_list: cumulative group boundaries returned by ``window_partitionx``

    Returns:
        res: (B, C, H, W), on the same device and with the same dtype as ``windows``
    """
    h, w = window_size * (H // window_size), window_size * (W // window_size)
    x_main = window_reverses(windows[:batch_list[0], ...], window_size, h, w)
    B, C, _, _ = x_main.shape
    # Follow the input dtype: the previous default-dtype canvas silently
    # turned float64 / half results into float32.
    res = torch.zeros([B, C, H, W], dtype=windows.dtype, device=windows.device)
    res[:, :, :h, :w] = x_main
    if h == H and w == W:
        return res
    if h != H and w != W and len(batch_list) == 4:
        # Edge windows overlap the main region; only their trailing
        # (H - h) rows / (W - w) columns carry new pixels.
        x_dd = window_reverses(windows[batch_list[2]:, ...], window_size, window_size, window_size)
        res[:, :, h:, w:] = x_dd[:, :, h - H:, w - W:]
        x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
        res[:, :, :h, w:] = x_r[:, :, :, w - W:]
        x_d = window_reverses(windows[batch_list[1]:batch_list[2], ...], window_size, window_size, w)
        res[:, :, h:, :w] = x_d[:, :, h - H:, :]
        return res
    if w != W and len(batch_list) == 2:
        x_r = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, h, window_size)
        res[:, :, :h, w:] = x_r[:, :, :, w - W:]
    if h != H and len(batch_list) == 2:
        x_d = window_reverses(windows[batch_list[0]:batch_list[1], ...], window_size, window_size, w)
        res[:, :, h:, :w] = x_d[:, :, h - H:, :]
    return res
def _loss_device(reference):
    """Return the device the losses compute on.

    ``cuda:0`` whenever CUDA is available (the device this module used to
    hard-code), otherwise the device of ``reference`` so the losses also run
    on CPU-only machines.
    """
    if torch.cuda.is_available():
        return torch.device('cuda:0')
    return reference.device


class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)

    Computes ``mean(sqrt((x - y)^2 + eps^2))``.

    Args:
        eps: Constant added (squared) under the square root.
    """

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        """Return the scalar loss between ``x`` and ``y``."""
        device = _loss_device(x)
        diff = x.to(device) - y.to(device)
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
        return loss


class EdgeLoss(nn.Module):
    """Charbonnier loss between band-pass filtered versions of two images.

    Each image is blurred with a fixed 5x5 kernel, decimated by two,
    zero-upsampled, blurred again and subtracted from the original; the two
    residuals are then compared. The kernel is replicated for exactly three
    channels, so inputs must be (N, 3, H, W).
    """

    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.tensor([[.05, .25, .4, .25, .05]])
        # Outer product -> 5x5 kernel, one copy per channel for a grouped conv.
        self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.to('cuda:0')
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        """Blur ``img`` with the fixed kernel, keeping its spatial size."""
        # Align the kernel with the input so the convolution never mixes
        # devices or dtypes.
        kernel = self.kernel.to(device=img.device, dtype=img.dtype)
        n_channels, _, kh, kw = kernel.shape
        # F.pad order for the last two dims is (left, right, top, bottom).
        img = F.pad(img, (kw//2, kw//2, kh//2, kh//2), mode='replicate')
        return F.conv2d(img, kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        """Return ``current`` minus its blur -> decimate -> upsample -> blur version."""
        filtered = self.conv_gauss(current)
        down = filtered[:,:,::2,::2]
        new_filter = torch.zeros_like(filtered)
        # The factor 4 compensates for the three out of four samples zeroed here.
        new_filter[:,:,::2,::2] = down*4
        filtered = self.conv_gauss(new_filter)
        diff = current - filtered
        return diff

    def forward(self, x, y):
        """Return the scalar edge loss between ``x`` and ``y``."""
        device = _loss_device(x)
        loss = self.loss(self.laplacian_kernel(x.to(device)), self.laplacian_kernel(y.to(device)))
        return loss


class fftLoss(nn.Module):
    """Mean magnitude of the difference between the 2-D FFTs of two tensors."""

    def __init__(self):
        super(fftLoss, self).__init__()

    def forward(self, x, y):
        """Return ``mean(|fft2(x) - fft2(y)|)``."""
        device = _loss_device(x)
        diff = torch.fft.fft2(x.to(device)) - torch.fft.fft2(y.to(device))
        loss = torch.mean(abs(diff))
        return loss
class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)

    Raises:
        ValueError: if ``multiplier`` is smaller than 1.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        # Becomes True the first time get_lr() hands over to after_scheduler.
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the learning rates for the current ``last_epoch``."""
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand-over: the follow-up scheduler starts from the warmed-up rates.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Ramp linearly from 0 up to base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Ramp linearly from base_lr up to base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        """Step variant used when ``after_scheduler`` is a ReduceLROnPlateau.

        Args:
            metrics: Value forwarded to ``after_scheduler.step`` after warm-up.
            epoch: Epoch index; defaults to ``last_epoch + 1``.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # ``epoch`` can no longer be None here (it was defaulted above),
            # so the follow-up scheduler always receives a shifted epoch.
            self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        """Advance the schedule by one epoch.

        Args:
            epoch: Optional explicit epoch index.
            metrics: Only used when ``after_scheduler`` is a ReduceLROnPlateau.
        """
        if not isinstance(self.after_scheduler, ReduceLROnPlateau):
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
                # Keep get_last_lr() of this wrapper in sync after hand-over;
                # otherwise it stays frozen at the last warm-up value.
                self._last_lr = self.after_scheduler.get_last_lr()
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
23 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES') 24 | parser.add_argument('--save_result', default=False, type=bool, help='save result') 25 | parser.add_argument('--win_size', default=256, type=int, help='window size, [GoPro, HIDE, RealBlur]=256, [DPDD]=512') 26 | parser.add_argument('--num_res', default=8, type=int, help='num of resblocks, [Small, Med, PLus]=[4, 8, 20]') 27 | args = parser.parse_args() 28 | result_dir = args.output_dir 29 | win = args.win_size 30 | get_psnr = args.get_psnr 31 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 32 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 33 | # model_restoration = mynet() 34 | model_restoration = mynet(num_res=args.num_res, inference=True) 35 | # print number of model 36 | get_parameter_number(model_restoration) 37 | # utils.load_checkpoint(model_restoration, args.weights) 38 | utils.load_checkpoint_compress_doconv(model_restoration, args.weights) 39 | print("===>Testing using weights: ", args.weights) 40 | model_restoration.cuda() 41 | model_restoration = nn.DataParallel(model_restoration) 42 | model_restoration.eval() 43 | 44 | # dataset = args.dataset 45 | rgb_dir_test = args.input_dir 46 | test_dataset = get_test_data(rgb_dir_test, img_options={}) 47 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True) 48 | psnr_val_rgb = [] 49 | psnr = 0 50 | 51 | utils.mkdir(result_dir) 52 | 53 | with torch.no_grad(): 54 | psnr_list = [] 55 | ssim_list = [] 56 | for ii, data_test in enumerate(tqdm(test_loader), 0): 57 | 58 | torch.cuda.ipc_collect() 59 | torch.cuda.empty_cache() 60 | input_ = data_test[0].cuda() 61 | filenames = data_test[1] 62 | _, _, Hx, Wx = input_.shape 63 | filenames = data_test[1] 64 | input_re, batch_list = window_partitionx(input_, win) 65 | restored = model_restoration(input_re) 66 | restored = window_reversex(restored, win, Hx, Wx, batch_list) 67 | 68 | restored = 
import os
import argparse
import torch.nn as nn
import torch
from torch.utils.data import DataLoader

from data_RGB import get_test_data
from DeepRFT_MIMO import DeepRFT as mynet

from get_parameter_number import get_parameter_number
from tqdm import tqdm
from layers import *
import time


def main():
    """Time window-partitioned inference over every image in ``--input_dir``.

    Prints the parameter count of the model and, at the end, the mean
    wall-clock time per image (partition + forward pass + stitching).
    """
    parser = argparse.ArgumentParser(description='Image Deblurring')
    parser.add_argument('--input_dir', default='./Datasets/GoPro/test/blur', type=str, help='Directory of validation images')
    parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')

    args = parser.parse_args()

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    # No checkpoint is loaded: the network is timed with the weights its
    # constructor produces.
    model_restoration = mynet()
    # print number of model
    get_parameter_number(model_restoration)

    model_restoration.cuda()
    model_restoration = nn.DataParallel(model_restoration)
    model_restoration.eval()

    rgb_dir_test = args.input_dir
    test_dataset = get_test_data(rgb_dir_test, img_options={})
    if len(test_dataset) == 0:
        # The average below would otherwise fail with a bare ZeroDivisionError.
        raise FileNotFoundError('No test images found in ' + rgb_dir_test)
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True)

    win = 256
    all_time = 0.
    with torch.no_grad():
        for data_test in tqdm(test_loader):

            torch.cuda.ipc_collect()
            torch.cuda.empty_cache()

            input_ = data_test[0].cuda()
            _, _, Hx, Wx = input_.shape

            # Synchronise on both sides so queued GPU work is not attributed
            # to the wrong iteration.
            torch.cuda.synchronize()
            start = time.time()
            input_re, batch_list = window_partitionx(input_, win)
            restored = model_restoration(input_re)
            # Only the first element of the model output is stitched back.
            restored = window_reversex(restored[0], win, Hx, Wx, batch_list)
            restored = torch.clamp(restored, 0, 1)
            torch.cuda.synchronize()
            end = time.time()
            all_time += end - start
    print('average_time: ', all_time / len(test_dataset))


# DataLoader worker processes may re-import this file (spawn start method);
# the guard keeps them from re-running the benchmark.
if __name__ == '__main__':
    main()
"""Train DeepRFT with a gradual warm-up followed by cosine annealing.

The script is configured through the command-line flags defined below plus
the ``RESUME`` / ``Pretrain`` switches. Checkpoints and TensorBoard logs are
written to ``<model_save_dir>/<mode>/models/<session>``.
"""
import os
# Set before torch is imported so the device selection is in place early.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'

import torch
torch.backends.cudnn.benchmark = True

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import random
import time
import numpy as np

import utils
from data_RGB import get_training_data, get_validation_data
from DeepRFT_MIMO import DeepRFT as myNet
import losses
from warmup_scheduler import GradualWarmupScheduler
from tqdm import tqdm
from get_parameter_number import get_parameter_number
import kornia
from torch.utils.tensorboard import SummaryWriter
import argparse

######### Set Seeds ###########
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)

start_epoch = 1

parser = argparse.ArgumentParser(description='Image Deblurring')

parser.add_argument('--train_dir', default='./Datasets/GoPro/train', type=str, help='Directory of train images')
parser.add_argument('--val_dir', default='./Datasets/GoPro/val', type=str, help='Directory of validation images')
parser.add_argument('--model_save_dir', default='./checkpoints', type=str, help='Path to save weights')
parser.add_argument('--pretrain_weights', default='./checkpoints/model_best.pth', type=str, help='Path to pretrain-weights')
parser.add_argument('--mode', default='Deblurring', type=str)
parser.add_argument('--session', default='DeepRFT_gopro', type=str, help='session')
parser.add_argument('--patch_size', default=256, type=int, help='patch size, for paper: [GoPro, HIDE, RealBlur]=256, [DPDD]=512')
parser.add_argument('--num_epochs', default=3000, type=int, help='num_epochs')
parser.add_argument('--batch_size', default=16, type=int, help='batch_size')
parser.add_argument('--val_epochs', default=20, type=int, help='val_epochs')
args = parser.parse_args()

mode = args.mode
session = args.session
patch_size = args.patch_size

model_dir = os.path.join(args.model_save_dir, mode, 'models', session)
utils.mkdir(model_dir)

train_dir = args.train_dir
val_dir = args.val_dir

num_epochs = args.num_epochs
batch_size = args.batch_size
val_epochs = args.val_epochs

start_lr = 2e-4
end_lr = 1e-6

######### Model ###########
model_restoration = myNet()

# print number of model
get_parameter_number(model_restoration)

model_restoration.cuda()

device_ids = [i for i in range(torch.cuda.device_count())]
if torch.cuda.device_count() > 1:
    print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")

optimizer = optim.Adam(model_restoration.parameters(), lr=start_lr, betas=(0.9, 0.999), eps=1e-8)

######### Scheduler ###########
warmup_epochs = 3
scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs-warmup_epochs, eta_min=end_lr)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)

RESUME = False
Pretrain = False
# Honour the --pretrain_weights flag (it used to be parsed and then ignored).
model_pre_dir = args.pretrain_weights
######### Pretrain ###########
if Pretrain:
    utils.load_checkpoint(model_restoration, model_pre_dir)

    print('------------------------------------------------------------------------------')
    print("==> Retrain Training with: " + model_pre_dir)
    print('------------------------------------------------------------------------------')

######### Resume ###########
if RESUME:
    path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
    utils.load_checkpoint(model_restoration,path_chk_rest)
    start_epoch = utils.load_start_epoch(path_chk_rest) + 1
    utils.load_optim(optimizer, path_chk_rest)

    # Replay the schedule up to the epoch we are resuming from.
    for _ in range(1, start_epoch):
        scheduler.step()
    # Read the rate the optimizer will actually use; get_lr() is not meant
    # to be called outside scheduler.step().
    new_lr = optimizer.param_groups[0]['lr']
    print('------------------------------------------------------------------------------')
    print("==> Resuming Training with learning rate:", new_lr)
    print('------------------------------------------------------------------------------')

if len(device_ids)>1:
    model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)

######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()
criterion_fft = losses.fftLoss()
######### DataLoaders ###########
train_dataset = get_training_data(train_dir, {'patch_size':patch_size})
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=False, pin_memory=True)

val_dataset = get_validation_data(val_dir, {'patch_size':patch_size})
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)

print('===> Start Epoch {} End Epoch {}'.format(start_epoch, num_epochs + 1))
print('===> Loading datasets')


def _save_checkpoint(epoch, filename):
    """Write the current model/optimizer state for *epoch* to ``model_dir/filename``."""
    torch.save({'epoch': epoch,
                'state_dict': model_restoration.state_dict(),
                'optimizer' : optimizer.state_dict()
                }, os.path.join(model_dir, filename))


best_psnr = 0
best_epoch = 0
writer = SummaryWriter(model_dir)
# Counts optimizer updates across all epochs; used as the TensorBoard x-axis.
global_step = 0

for epoch in range(start_epoch, num_epochs + 1):
    epoch_start_time = time.time()
    epoch_loss = 0

    model_restoration.train()
    for data in tqdm(train_loader):

        # zero_grad
        for param in model_restoration.parameters():
            param.grad = None

        target_ = data[0].cuda()
        input_ = data[1].cuda()
        # Three-level target pyramid, matched index-by-index with the three
        # outputs returned by the network.
        target = kornia.geometry.transform.build_pyramid(target_, 3)
        restored = model_restoration(input_)

        loss_fft = sum(criterion_fft(restored[s], target[s]) for s in range(3))
        loss_char = sum(criterion_char(restored[s], target[s]) for s in range(3))
        loss_edge = sum(criterion_edge(restored[s], target[s]) for s in range(3))
        loss = loss_char + 0.01 * loss_fft + 0.05 * loss_edge
        loss.backward()
        optimizer.step()
        epoch_loss +=loss.item()
        global_step += 1
        writer.add_scalar('loss/fft_loss', loss_fft, global_step)
        writer.add_scalar('loss/char_loss', loss_char, global_step)
        writer.add_scalar('loss/edge_loss', loss_edge, global_step)
        writer.add_scalar('loss/iter_loss', loss, global_step)
    writer.add_scalar('loss/epoch_loss', epoch_loss, epoch)

    #### Evaluation ####
    if epoch % val_epochs == 0:
        model_restoration.eval()
        psnr_val_rgb = []
        for data_val in val_loader:
            target = data_val[0].cuda()
            input_ = data_val[1].cuda()

            with torch.no_grad():
                restored = model_restoration(input_)

            # Only the first network output is scored against the target.
            for res,tar in zip(restored[0], target):
                psnr_val_rgb.append(utils.torchPSNR(res, tar))

        psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
        writer.add_scalar('val/psnr', psnr_val_rgb, epoch)
        if psnr_val_rgb > best_psnr:
            best_psnr = psnr_val_rgb
            best_epoch = epoch
            _save_checkpoint(epoch, "model_best.pth")

        print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))

        _save_checkpoint(epoch, f"model_epoch_{epoch}.pth")

    scheduler.step()

    print("------------------------------------------------------------------")
    # Report the rate held by the optimizer after the scheduler update.
    print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, optimizer.param_groups[0]['lr']))
    print("------------------------------------------------------------------")

    # Saved after scheduler.step() so a resumed run sees the updated rate.
    _save_checkpoint(epoch, "model_latest.pth")

writer.close()
"""Train DeepRFT with plain cosine annealing (no warm-up phase).

The script is configured through the command-line flags defined below plus
the ``RESUME`` / ``Pretrain`` switches. Checkpoints and TensorBoard logs are
written to ``<model_save_dir>/<mode>/models/<session>``.
"""
import os
# Set before torch is imported so the device selection is in place early.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'

import torch
torch.backends.cudnn.benchmark = True

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import random
import time
import numpy as np

import utils
from data_RGB import get_training_data, get_validation_data
from DeepRFT_MIMO import DeepRFT as myNet
import losses
from tqdm import tqdm
from get_parameter_number import get_parameter_number
import kornia
from torch.utils.tensorboard import SummaryWriter
import argparse

######### Set Seeds ###########
random.seed(1234)
np.random.seed(1234)
torch.manual_seed(1234)
torch.cuda.manual_seed_all(1234)

start_epoch = 1

parser = argparse.ArgumentParser(description='Image Deblurring')

parser.add_argument('--train_dir', default='./Datasets/GoPro/train', type=str, help='Directory of train images')
parser.add_argument('--val_dir', default='./Datasets/GoPro/val', type=str, help='Directory of validation images')
parser.add_argument('--model_save_dir', default='./checkpoints', type=str, help='Path to save weights')
parser.add_argument('--pretrain_weights', default='./checkpoints/model_best.pth', type=str, help='Path to pretrain-weights')
parser.add_argument('--mode', default='Deblurring', type=str)
parser.add_argument('--session', default='DeepRFT_gopro', type=str, help='session')
parser.add_argument('--patch_size', default=256, type=int, help='patch size, for paper: [GoPro, HIDE, RealBlur]=256, [DPDD]=512')
parser.add_argument('--num_epochs', default=3000, type=int, help='num_epochs')
parser.add_argument('--batch_size', default=16, type=int, help='batch_size')
parser.add_argument('--val_epochs', default=20, type=int, help='val_epochs')
args = parser.parse_args()

mode = args.mode
session = args.session
patch_size = args.patch_size

model_dir = os.path.join(args.model_save_dir, mode, 'models', session)
utils.mkdir(model_dir)

train_dir = args.train_dir
val_dir = args.val_dir

num_epochs = args.num_epochs
batch_size = args.batch_size
val_epochs = args.val_epochs

start_lr = 2e-4
end_lr = 1e-6

######### Model ###########
model_restoration = myNet()
# print number of model
get_parameter_number(model_restoration)

model_restoration.cuda()

device_ids = [i for i in range(torch.cuda.device_count())]
if torch.cuda.device_count() > 1:
    print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")

optimizer = optim.Adam(model_restoration.parameters(), lr=start_lr, betas=(0.9, 0.999), eps=1e-8)

######### Scheduler ###########
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs, eta_min=end_lr)

RESUME = False
Pretrain = False
# Honour the --pretrain_weights flag (it used to be parsed and then ignored).
model_pre_dir = args.pretrain_weights
######### Pretrain ###########
if Pretrain:
    utils.load_checkpoint(model_restoration, model_pre_dir)

    print('------------------------------------------------------------------------------')
    print("==> Retrain Training with: " + model_pre_dir)
    print('------------------------------------------------------------------------------')

######### Resume ###########
if RESUME:
    path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
    utils.load_checkpoint(model_restoration,path_chk_rest)
    start_epoch = utils.load_start_epoch(path_chk_rest) + 1
    utils.load_optim(optimizer, path_chk_rest)

    # Replay the schedule up to the epoch we are resuming from.
    for _ in range(1, start_epoch):
        scheduler.step()
    # CosineAnnealingLR.get_lr() is only meaningful inside step(); read the
    # rate the optimizer actually holds instead.
    new_lr = optimizer.param_groups[0]['lr']
    print('------------------------------------------------------------------------------')
    print("==> Resuming Training with learning rate:", new_lr)
    print('------------------------------------------------------------------------------')

if len(device_ids)>1:
    model_restoration = nn.DataParallel(model_restoration, device_ids=device_ids)

######### Loss ###########
criterion_char = losses.CharbonnierLoss()
criterion_edge = losses.EdgeLoss()
criterion_fft = losses.fftLoss()
######### DataLoaders ###########
train_dataset = get_training_data(train_dir, {'patch_size':patch_size})
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, drop_last=False, pin_memory=True)

val_dataset = get_validation_data(val_dir, {'patch_size':patch_size})
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)

print('===> Start Epoch {} End Epoch {}'.format(start_epoch, num_epochs + 1))
print('===> Loading datasets')


def _save_checkpoint(epoch, filename):
    """Write the current model/optimizer state for *epoch* to ``model_dir/filename``."""
    torch.save({'epoch': epoch,
                'state_dict': model_restoration.state_dict(),
                'optimizer' : optimizer.state_dict()
                }, os.path.join(model_dir, filename))


best_psnr = 0
best_epoch = 0
writer = SummaryWriter(model_dir)
# Counts optimizer updates across all epochs; used as the TensorBoard x-axis.
global_step = 0
for epoch in range(start_epoch, num_epochs + 1):
    epoch_start_time = time.time()
    epoch_loss = 0

    model_restoration.train()
    for data in tqdm(train_loader):

        # zero_grad
        for param in model_restoration.parameters():
            param.grad = None

        target_ = data[0].cuda()
        input_ = data[1].cuda()
        # Three-level target pyramid, matched index-by-index with the three
        # outputs returned by the network.
        target = kornia.geometry.transform.build_pyramid(target_, 3)
        restored = model_restoration(input_)

        loss_fft = sum(criterion_fft(restored[s], target[s]) for s in range(3))
        loss_char = sum(criterion_char(restored[s], target[s]) for s in range(3))
        loss_edge = sum(criterion_edge(restored[s], target[s]) for s in range(3))
        loss = loss_char + 0.01 * loss_fft + 0.05 * loss_edge
        loss.backward()
        optimizer.step()
        epoch_loss +=loss.item()
        global_step += 1
        writer.add_scalar('loss/fft_loss', loss_fft, global_step)
        writer.add_scalar('loss/char_loss', loss_char, global_step)
        writer.add_scalar('loss/edge_loss', loss_edge, global_step)
        writer.add_scalar('loss/iter_loss', loss, global_step)
    writer.add_scalar('loss/epoch_loss', epoch_loss, epoch)

    #### Evaluation ####
    if epoch % val_epochs == 0:
        model_restoration.eval()
        psnr_val_rgb = []
        for data_val in val_loader:
            target = data_val[0].cuda()
            input_ = data_val[1].cuda()

            with torch.no_grad():
                restored = model_restoration(input_)

            # Only the first network output is scored against the target.
            for res,tar in zip(restored[0], target):
                psnr_val_rgb.append(utils.torchPSNR(res, tar))

        psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
        writer.add_scalar('val/psnr', psnr_val_rgb, epoch)
        if psnr_val_rgb > best_psnr:
            best_psnr = psnr_val_rgb
            best_epoch = epoch
            _save_checkpoint(epoch, "model_best.pth")

        print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))

        _save_checkpoint(epoch, f"model_epoch_{epoch}.pth")

    scheduler.step()

    print("------------------------------------------------------------------")
    # Report the rate held by the optimizer after the scheduler update.
    print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, optimizer.param_groups[0]['lr']))
    print("------------------------------------------------------------------")

    # Saved after scheduler.step() so a resumed run sees the updated rate.
    _save_checkpoint(epoch, "model_latest.pth")

writer.close()
class MixUp_AUG:
    """MixUp augmentation for paired (ground-truth, degraded) image batches."""

    def __init__(self):
        # Mixing coefficients are drawn from Beta(0.6, 0.6).
        self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6]))

    def aug(self, rgb_gt, rgb_noisy):
        """Blend every sample with a randomly chosen partner from the same batch.

        The same permutation and the same per-sample coefficient are applied
        to both tensors so the pairs stay aligned.

        Args:
            rgb_gt: batch of target images; indexed along dim 0 and broadcast
                against a ``(bs, 1, 1, 1)`` coefficient, so 4-D input is expected.
            rgb_noisy: batch of degraded images with the same batch size.

        Returns:
            Tuple ``(rgb_gt, rgb_noisy)`` of mixed batches.
        """
        bs = rgb_gt.size(0)
        indices = torch.randperm(bs)
        rgb_gt2 = rgb_gt[indices]
        rgb_noisy2 = rgb_noisy[indices]

        # Follow the input's device instead of hard-coding .cuda(), so the
        # augmentation also works on CPU and on non-default GPUs.
        lam = self.dist.rsample((bs, 1)).view(-1, 1, 1, 1).to(rgb_gt.device)

        rgb_gt = lam * rgb_gt + (1 - lam) * rgb_gt2
        rgb_noisy = lam * rgb_noisy + (1 - lam) * rgb_noisy2

        return rgb_gt, rgb_noisy


def mkdirs(paths):
    """Create one directory, or every directory in a list/tuple of paths."""
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """Create *path* (including parents); do nothing if it already exists."""
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present, possibly created concurrently by another process
        # between a check and the call; that is not an error here.
        pass


def get_last_path(path, session):
    """Return the naturally-sorted last entry in *path* whose name ends with *session*.

    Raises:
        IndexError: if no entry matches.
    """
    pattern = os.path.join(path, '*%s' % session)
    candidates = natsorted(glob(pattern))
    if not candidates:
        raise IndexError('no file matching %r was found' % pattern)
    return candidates[-1]
def torchPSNR(tar_img, prd_img):
    """Return the PSNR in dB between two tensors, using a peak value of 1.

    Both inputs are clamped to [0, 1] before the comparison.
    """
    diff = torch.clamp(prd_img, 0, 1) - torch.clamp(tar_img, 0, 1)
    rmse = (diff ** 2).mean().sqrt()
    return 20 * torch.log10(1 / rmse)


def save_img(filepath, img):
    """Write an RGB image to *filepath*, converting to OpenCV's BGR order first."""
    cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))


def numpyPSNR(tar_img, prd_img):
    """Return the PSNR in dB between two arrays, using a peak value of 255."""
    diff = np.float32(prd_img) - np.float32(tar_img)
    rmse = np.sqrt(np.mean(diff ** 2))
    return 20 * np.log10(255 / rmse)


def freeze(model):
    """Disable gradient tracking for every parameter of *model*."""
    for p in model.parameters():
        p.requires_grad = False


def unfreeze(model):
    """Enable gradient tracking for every parameter of *model*."""
    for p in model.parameters():
        p.requires_grad = True


def is_frozen(model):
    """Return True if at least one parameter of *model* has ``requires_grad`` off."""
    return not all(p.requires_grad for p in model.parameters())


def save_checkpoint(model_dir, state, session):
    """Save *state* as ``model_epoch_<epoch>_<session>.pth`` inside *model_dir*."""
    epoch = state['epoch']
    model_out_path = os.path.join(model_dir, "model_epoch_{}_{}.pth".format(epoch, session))
    torch.save(state, model_out_path)


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys that have it."""
    prefix = 'module.'
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


def load_checkpoint(model, weights):
    """Load ``checkpoint["state_dict"]`` from *weights* into *model*.

    The keys are tried as stored first; if that fails they are retried with
    the ``module.`` prefix (added by ``nn.DataParallel``) removed.
    """
    # map_location keeps loading independent of the device the file was saved from.
    checkpoint = torch.load(weights, map_location='cpu')
    state_dict = checkpoint["state_dict"]
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict reports key/shape mismatches as RuntimeError.
        model.load_state_dict(_strip_module_prefix(state_dict))


def load_checkpoint_compress_doconv(model, weights):
    """Load a checkpoint while folding each ``D``/``D_diag``/``W`` triple into one kernel.

    For every key ending in ``W`` that has a sibling ``...D`` entry, the
    composed kernel ``einsum('ims,ois->oim', D + D_diag, W)`` is stored under
    the ``W`` key and the ``D`` / ``D_diag`` entries are dropped. A ``W``
    without a sibling is only reshaped. Both results are 4-D
    ``(out, in, M, M)`` with ``M = int(sqrt(W.shape[2]))``.
    All other entries are copied unchanged.
    """
    checkpoint = torch.load(weights, map_location='cpu')
    state_dict = _strip_module_prefix(checkpoint["state_dict"])

    do_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.endswith('W') and k[:-1] + 'D' in state_dict:
            k_D = k[:-1] + 'D'
            D = state_dict[k_D] + state_dict[k_D + '_diag']
            out_channels, in_channels, MN = v.shape
            M = int(np.sqrt(MN))
            DoW = torch.einsum('ims,ois->oim', D, v)
            do_state_dict[k] = torch.reshape(DoW, (out_channels, in_channels, M, M))
        elif k.endswith('D') or k.endswith('D_diag'):
            # Already folded into the matching W entry.
            continue
        elif k.endswith('W'):
            out_channels, in_channels, MN = v.shape
            M = int(np.sqrt(MN))
            do_state_dict[k] = torch.reshape(v, (out_channels, in_channels, M, M))
        else:
            do_state_dict[k] = v
    model.load_state_dict(do_state_dict)


def load_checkpoint_hin(model, weights):
    """Load a checkpoint file that *is* the state dict (no ``"state_dict"`` wrapper).

    Falls back to the keys with the ``module.`` prefix removed if the direct
    load fails.
    """
    state_dict = torch.load(weights, map_location='cpu')
    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        model.load_state_dict(_strip_module_prefix(state_dict))


def load_checkpoint_multigpu(model, weights):
    """Load a checkpoint into a bare model, removing any ``module.`` key prefix."""
    checkpoint = torch.load(weights, map_location='cpu')
    model.load_state_dict(_strip_module_prefix(checkpoint["state_dict"]))


def load_start_epoch(weights):
    """Return the ``"epoch"`` entry stored in the checkpoint at *weights*."""
    checkpoint = torch.load(weights, map_location='cpu')
    return checkpoint["epoch"]


def load_optim(optimizer, weights):
    """Restore *optimizer* from the ``'optimizer'`` entry of the checkpoint at *weights*."""
    checkpoint = torch.load(weights, map_location='cpu')
    optimizer.load_state_dict(checkpoint['optimizer'])