├── kuangjia.png
├── networks3
│   ├── SubBlocks.py
│   ├── OptimizedBlock.py
│   ├── DisResBlock.py
│   ├── RRDBNet_arch.py
│   ├── UNetG.py
│   ├── util.py
│   ├── UNetD.py
│   ├── Discriminator.py
│   └── __init__.py
├── datasets
│   ├── __init__.py
│   ├── data_tools.py
│   └── DenoisingDatasets.py
├── configs
│   └── DANet_v5.json
├── loss_util.py
├── README.md
├── utils.py
├── GaussianSmoothLayer.py
├── train_v6.py
└── loss.py

-------------------------------------------------------------------------------- /kuangjia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linxin0/SCPGabNet/HEAD/kuangjia.png
-------------------------------------------------------------------------------- /networks3/SubBlocks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-09-01 19:19:32 4 | 5 | import torch 6 | import torch.nn as nn 7 | import sys 8 | import torch.nn.functional as F 9 | import torch.nn.utils as utils 10 | 11 | def conv3x3(in_chn, out_chn, bias=True): 12 | layer = nn.Conv2d(in_chn, out_chn, kernel_size=3, stride=1, padding=1, bias=bias) 13 | return layer 14 | 15 | def conv_down(in_chn, out_chn, bias=False): 16 | layer = nn.Conv2d(in_chn, out_chn, kernel_size=4, stride=2, padding=1, bias=bias) 17 | return layer 18 | 19 | def conv1x1(in_chn, out_chn, bias=True): 20 | # padding must be 0 for a 1x1 convolution to preserve the spatial size 21 | layer = nn.Conv2d(in_chn, out_chn, kernel_size=1, stride=1, padding=0, bias=bias) 22 | return layer 23 | 
-------------------------------------------------------------------------------- /networks3/OptimizedBlock.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # https://github.com/XHChen0528/SNGAN_Projection_Pytorch 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn import init 10 | from torch.nn import utils 11 | 12 | class OptimizedBlock(nn.Module): 13 | 14 | def __init__(self, in_ch, out_ch, ksize=3, pad=1, activation=F.relu): 15 | super(OptimizedBlock, self).__init__() 16 | self.activation = activation 17 | 18 | self.c1 = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, ksize, 1, pad)) 19 | self.c2 = utils.spectral_norm(nn.Conv2d(out_ch, out_ch, ksize, 1, pad)) 20 | self.c_sc = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 1, 1, 0)) 21 | 22 | self._initialize() 23 | 24 | def _initialize(self): 25 | init.xavier_uniform_(self.c1.weight.data, math.sqrt(2)) 26 | init.xavier_uniform_(self.c2.weight.data, math.sqrt(2)) 27 | init.xavier_uniform_(self.c_sc.weight.data) 28 | 29 | def forward(self, x): 30 | return self.shortcut(x) + self.residual(x) 31 | 32 | def shortcut(self, x): 33 | return self.c_sc(F.avg_pool2d(x, 2)) 34 | 35 | def residual(self, x): 36 | h = self.activation(self.c1(x)) 37 | return F.avg_pool2d(self.c2(h), 2) 38 | 
-------------------------------------------------------------------------------- /configs/DANet_v5.json: -------------------------------------------------------------------------------- 1 | { 2 | # training settings 3 | 4 | "batch_size": 12, 5 | "patch_size": 128, 6 | "epochs": 400, 7 | "lr_D": 1e-4, 8 | "lr_G": 1e-4, 9 | "lr_P": 1e-4, 10 | "lr_E": 1e-4, 11 | "lr": 1e-4, 12 | "lr_decay": 50, 13 | "gamma": 0.5, 14 | 15 | 16 | 17 | 18 | "print_freq": 10, 19 | "num_workers": 0, 20 | "gpu_id": "0", 21 | 22 | "resume":"", 23 | # support vanilla || lsgan 24 | "gan_mode": "lsgan", 25 | 26 | "milestones": [50,100,125,150,175,190], 27 | "weight_decay": 0, 28 | 29 | ############### network architecture ############### 30 | # number of filters of the first convolution in UNet 31 | "wf": 32, 32 | # depth of UNet 33 | "depth": 5, 34 | # number of filters of the first convolution in Discriminator 35 | "ndf": 64, 36 | 37 | ######### training and validation data path ######## 38 | "SIDD_train_h5_noisy": "./dataset/small_imgs_train_im_noisy.hdf5", 39 | "SIDD_train_h5_gt": "./dataset/small_imgs_train_im_gt.hdf5", 40 | "SIDD_test_h5": "./dataset/small_imgs_test.hdf5", 41 | 42 | # saving models and logs 43 | "model_dir": "./model_ours_12", 44 | "log_dir": 
"./logs_DANet", 45 | "save_dir": "./logs_DANet", 46 | "model_dir_test":"./model_ours_12", 47 | 48 | ########### hyper-parameters of our model ########## 49 | "alpha": 0.5, 50 | # kernel size for the Gauss filter in loss function 51 | "ksize": 5, 52 | "lambda_gp": 10, 53 | "tau_D": 1000, 54 | "tau_G": 10, 55 | "rec_x": 1, 56 | "rec_y": 1, 57 | "l1_loss": 1, 58 | "idt": 1, 59 | "adversarial_loss_factor": 1, 60 | "perceptual_loss_factor": 1, 61 | "bgm_loss": 6, 62 | "num_critic":1, 63 | "mse_loss_factor":1 64 | } 65 | -------------------------------------------------------------------------------- /networks3/DisResBlock.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # https://github.com/XHChen0528/SNGAN_Projection_Pytorch 4 | 5 | import math 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.nn import init 11 | from torch.nn import utils 12 | 13 | 14 | class DisResBlock(nn.Module): 15 | 16 | def __init__(self, in_ch, out_ch, h_ch=None, ksize=3, pad=1, 17 | activation=F.relu, downsample=False): 18 | super(DisResBlock, self).__init__() 19 | 20 | self.activation = activation 21 | self.downsample = downsample 22 | 23 | self.learnable_sc = (in_ch != out_ch) or downsample 24 | if h_ch is None: 25 | h_ch = in_ch 26 | else: 27 | h_ch = out_ch 28 | 29 | self.c1 = utils.spectral_norm(nn.Conv2d(in_ch, h_ch, ksize, 1, pad)) 30 | self.c2 = utils.spectral_norm(nn.Conv2d(h_ch, out_ch, ksize, 1, pad)) 31 | if self.learnable_sc: 32 | self.c_sc = utils.spectral_norm(nn.Conv2d(in_ch, out_ch, 1, 1, 0)) 33 | 34 | self._initialize() 35 | 36 | def _initialize(self): 37 | init.xavier_uniform_(self.c1.weight.data, math.sqrt(2)) 38 | init.xavier_uniform_(self.c2.weight.data, math.sqrt(2)) 39 | if self.learnable_sc: 40 | init.xavier_uniform_(self.c_sc.weight.data) 41 | 42 | def forward(self, x): 43 | return self.shortcut(x) + self.residual(x) 44 | 45 | def shortcut(self, x): 46 | if self.learnable_sc: 47 | x = self.c_sc(x) 48 | if self.downsample: 49 | return F.avg_pool2d(x, 2) 50 | return x 51 | 52 | def residual(self, x): 53 | h = self.c1(self.activation(x)) 54 | h = self.c2(self.activation(h)) 55 | if self.downsample: 56 | h = F.avg_pool2d(h, 2) 57 | return h 58 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-09-02 15:35:24 4 | 5 | import random 6 | import numpy as np 7 | import torch.utils.data as uData 8 | import h5py as h5 9 | import cv2 10 | 11 | # Base Datasets 12 | class BaseDataSetH5(uData.Dataset): 13 | def __init__(self, h5_path, length=None): 14 | ''' 15 | Args: 16 | h5_path (str): path of the hdf5 file 17 | length (int): length of Datasets 18 | ''' 19 | super(BaseDataSetH5, self).__init__() 20 | self.h5_path = h5_path 21 | self.length = length 22 | with h5.File(h5_path, 'r') as h5_file: 23 | self.keys = list(h5_file.keys()) 24 | self.num_images = len(self.keys) 25 | 26 | def __len__(self): 27 | if self.length == None: 28 | return self.num_images 29 | else: 30 | return self.length 31 | 32 | def crop_patch(self, imgs_sets): 33 | H, W, C2 = imgs_sets.shape 34 | # minus the bayer patter channel 35 | C = int(C2/2) 36 | ind_H = random.randint(0, H-self.pch_size) 37 | ind_W = random.randint(0, W-self.pch_size) 38 | im_noisy = 
np.array(imgs_sets[ind_H:ind_H+self.pch_size, ind_W:ind_W+self.pch_size, :C]) 39 | im_gt = np.array(imgs_sets[ind_H:ind_H+self.pch_size, ind_W:ind_W+self.pch_size, C:]) 40 | return im_gt, im_noisy 41 | 42 | class BaseDataSetFolder(uData.Dataset): 43 | def __init__(self, path_list, pch_size, length=None): 44 | ''' 45 | Args: 46 | path_list (str): path of the hdf5 file 47 | length (int): length of Datasets 48 | ''' 49 | super(BaseDataSetFolder, self).__init__() 50 | self.path_list = path_list 51 | self.length = length 52 | self.pch_size = pch_size 53 | self.num_images = len(path_list) 54 | 55 | def __len__(self): 56 | if self.length == None: 57 | return self.num_images 58 | else: 59 | return self.length 60 | 61 | def crop_patch(self, im): 62 | pch_size = self.pch_size 63 | H, W, _ = im.shape 64 | if H < self.pch_size or W < self.pch_size: 65 | H = max(pch_size, H) 66 | W = max(pch_size, W) 67 | im = cv2.resize(im, (W, H)) 68 | ind_H = random.randint(0, H-pch_size) 69 | ind_W = random.randint(0, W-pch_size) 70 | im_pch = im[ind_H:ind_H+pch_size, ind_W:ind_W+pch_size,] 71 | return im_pch 72 | -------------------------------------------------------------------------------- /networks3/RRDBNet_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def make_layer(block, n_layers): 8 | layers = [] 9 | for _ in range(n_layers): 10 | layers.append(block()) 11 | return nn.Sequential(*layers) 12 | 13 | 14 | class ResidualDenseBlock_5C(nn.Module): 15 | def __init__(self, nf=64, gc=32, bias=True): 16 | super(ResidualDenseBlock_5C, self).__init__() 17 | # gc: growth channel, i.e. intermediate channels 18 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 19 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 20 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 21 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 22 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 23 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 24 | 25 | # initialization 26 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 27 | 28 | def forward(self, x): 29 | x1 = self.lrelu(self.conv1(x)) 30 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 31 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 32 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 33 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 34 | return x5 * 0.2 + x 35 | 36 | 37 | class RRDB(nn.Module): 38 | '''Residual in Residual Dense Block''' 39 | 40 | def __init__(self, nf, gc=32): 41 | super(RRDB, self).__init__() 42 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 43 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 44 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 45 | 46 | def forward(self, x): 47 | out = self.RDB1(x) 48 | out = self.RDB2(out) 49 | out = self.RDB3(out) 50 | return out * 0.2 + x 51 | 52 | 53 | class RRDBNet(nn.Module): 54 | def __init__(self, in_nc, out_nc, nf, nb, gc=32): 55 | super(RRDBNet, self).__init__() 56 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 57 | 58 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 59 | self.RRDB_trunk = make_layer(RRDB_block_f, nb) 60 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 61 | #### upsampling 62 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 63 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 64 | self.HRconv = 
nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 65 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 66 | 67 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 68 | 69 | def forward(self, x): 70 | fea = self.conv_first(x) 71 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 72 | fea = fea + trunk 73 | 74 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 75 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 76 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 77 | 78 | return out 79 | -------------------------------------------------------------------------------- /loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 
86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 【ICCV 2023】 Unsupervised Image Denoising in Real-World Scenarios via Self-Collaboration Parallel Generative Adversarial Branches 2 | 3 | ### Xin Lin, Chao Ren, Xiao Liu, Jie Huang, Yinjie Lei 4 | 5 | [![paper](https://img.shields.io/badge/arXiv-Paper-green_yellow)]([https://arxiv.org/pdf/2308.06776.pdf](https://openaccess.thecvf.com/content/ICCV2023/papers/Lin_Unsupervised_Image_Denoising_in_Real-World_Scenarios_via_Self-Collaboration_Parallel_Generative_ICCV_2023_paper.pdf)) 6 | 7 | # The journal version is here [RSCP^2^GAN](https://arxiv.org/pdf/2408.09241). 8 | 9 | This is the official code of SCPGabNet. 10 | 11 | ![main_fig](./kuangjia.png) 12 | 13 | 14 | ## Abstract 15 | Deep learning methods have shown remarkable performance in image denoising, particularly when trained on large-scale paired datasets. However, acquiring such paired datasets for real-world scenarios poses a significant challenge. Although unsupervised approaches based on generative adversarial networks (GANs) offer a promising solution for denoising without paired datasets, they are difficult to surpass the performance limitations of conventional GAN-based unsupervised frameworks without significantly modifying existing structures or increase the computational complexity of denoisers. To address this problem, we propose a self-collaboration (SC) strategy for multiple denoisers. This strategy can achieve significant performance improvement without increasing the inference complexity of the GAN-based denoising framework. Its basic idea is to iteratively replace the previous less powerful denoiser in the filter-guided noise extraction module with the current powerful denoiser. This process generates better synthetic clean-noisy image pairs, leading to a more powerful denoiser for the next iteration. In addition, we propose a baseline method that includes parallel generative adversarial branches with complementary “self-synthesis” and “unpaired-synthesis” constraints. This baseline ensures the stability and effectiveness of the training network. The experimental results demonstrate the superiority of our method over state-of-the-art unsupervised methods. 16 | 17 | ## Requirements 18 | Our experiments are done with: 19 | 20 | - Python 3.7.13 21 | - PyTorch 1.13.0 22 | - numpy 1.21.5 23 | - opencv 4.6.0 24 | - scikit-image 0.19.3 25 | 26 | ## Dateset 27 | 28 | SIDD 29 | Train: https://pan.baidu.com/s/1c1iPIIJvSfq6s6_M7iyjPA 2oe5 30 | 31 | Test: https://pan.baidu.com/s/1yltsD684qpJa0SMJ9SdR5w 8qzf 32 | 33 | ## Pre-trained Models 34 | 35 | Google Drive: https://drive.google.com/file/d/1wzceTFvnoepftIEJn1DI73iHnGi_pvMl/view?usp=sharing 36 | 37 | Baidu Drive: https://pan.baidu.com/s/1EdXN7o9EW_ssDRHxDKFeXw icp1 38 | 39 | 40 | 41 | ## Train & Test 42 | You can get the complete SIDD validation dataset from https://www.eecs.yorku.ca/~kamel/sidd/benchmark.php. 43 | 44 | '.mat' files need to be converted to images ('.png'). 45 | 46 | train and test are both in `train_v6.py`. 47 | 48 | run `trainv6.py`. 
49 | 50 | ## Citation 51 | 52 | @inproceedings{scpgabnet, 53 | title={Unsupervised Image Denoising in Real-World Scenarios via Self-Collaboration Parallel Generative Adversarial Branches}, 54 | author={Xin Lin and Chao Ren and Xiao Liu and Jie Huang and Yinjie Lei}, 55 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 56 | year={2023} 57 | } 58 | 59 | ## Contact 60 | If you have any questions, please contact linxin@stu.scu.edu.cn 61 | -------------------------------------------------------------------------------- /datasets/data_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-09-02 15:11:05 4 | 5 | import cv2 6 | import numpy as np 7 | import random 8 | from math import ceil 9 | 10 | def data_augmentation(image, mode): 11 | ''' 12 | Performs data augmentation of the input image 13 | Input: 14 | image: a cv2 (OpenCV) image 15 | mode: int. Choice of transformation to apply to the image 16 | 0 - no transformation 17 | 1 - flip up and down 18 | 2 - rotate counterwise 90 degree 19 | 3 - rotate 90 degree and flip up and down 20 | 4 - rotate 180 degree 21 | 5 - rotate 180 degree and flip 22 | 6 - rotate 270 degree 23 | 7 - rotate 270 degree and flip 24 | ''' 25 | if mode == 0: 26 | # original 27 | out = image 28 | elif mode == 1: 29 | # flip up and down 30 | out = np.flipud(image) 31 | elif mode == 2: 32 | # rotate counterwise 90 degree 33 | out = np.rot90(image) 34 | elif mode == 3: 35 | # rotate 90 degree and flip up and down 36 | out = np.rot90(image) 37 | out = np.flipud(out) 38 | elif mode == 4: 39 | # rotate 180 degree 40 | out = np.rot90(image, k=2) 41 | elif mode == 5: 42 | # rotate 180 degree and flip 43 | out = np.rot90(image, k=2) 44 | out = np.flipud(out) 45 | elif mode == 6: 46 | # rotate 270 degree 47 | out = np.rot90(image, k=3) 48 | elif mode == 7: 49 | # rotate 270 degree and flip 50 | out = np.rot90(image, k=3) 51 | out = np.flipud(out) 52 | else: 53 | raise Exception('Invalid choice of image transformation') 54 | 55 | return out 56 | 57 | def inverse_data_augmentation(image, mode): 58 | ''' 59 | Performs inverse data augmentation of the input image 60 | ''' 61 | if mode == 0: 62 | # original 63 | out = image 64 | elif mode == 1: 65 | out = np.flipud(image) 66 | elif mode == 2: 67 | out = np.rot90(image, axes=(1,0)) 68 | elif mode == 3: 69 | out = np.flipud(image) 70 | out = np.rot90(out, axes=(1,0)) 71 | elif mode == 4: 72 | out = np.rot90(image, k=2, axes=(1,0)) 73 | elif mode == 5: 74 | out = np.flipud(image) 75 | out = np.rot90(out, k=2, axes=(1,0)) 76 | elif mode == 6: 77 | out = np.rot90(image, k=3, axes=(1,0)) 78 | elif mode == 7: 79 | # rotate 270 degree and flip 80 | out = np.flipud(image) 81 | out = np.rot90(out, k=3, axes=(1,0)) 82 | else: 83 | raise Exception('Invalid choice of image transformation') 84 | 85 | return out 86 | 87 | def random_augmentation(*args): 88 | out = [] 89 | if random.randint(0,1) == 1: 90 | flag_aug = random.randint(1,7) 91 | for data in args: 92 | out.append(data_augmentation(data, flag_aug).copy()) 93 | else: 94 | for data in args: 95 | out.append(data) 96 | return out 97 | 98 | def im_pad_fun(image, offset): 99 | ''' 100 | Input: 101 | image: numpy array, H x W x C 102 | ''' 103 | H, W, C = image.shape 104 | if (H % offset == 0) and (W % offset == 0): 105 | image_pad = image 106 | else: 107 | H_pad = H if (H % offset == 0) else (offset * ceil(H / offset)) 108 | W_pad = W if (W 
% offset == 0) else (offset * ceil(W / offset)) 109 | image_pad = np.zeros([H_pad, W_pad, C], dtype=image.dtype) 110 | image_pad[:H, :W] = image 111 | 112 | if (H % offset) != 0: 113 | image_pad[H:, :W] = image[(H%offset-offset):, ][::-1,] 114 | 115 | if (W % offset) != 0: 116 | image_pad[:, W:] = image_pad[:, (W-(offset-W%offset)):W][:, ::-1] 117 | 118 | return image_pad 119 | 120 | if __name__ == '__main__': 121 | # aa = np.random.randn(4,4) 122 | # for ii in range(8): 123 | # bb1 = data_augmentation(aa, ii) 124 | # bb2 = inverse_data_augmentation(bb1, ii) 125 | # if np.allclose(aa, bb2): 126 | # print('Flag: {:d}, Sccessed!'.format(ii)) 127 | # else: 128 | # print('Flag: {:d}, Failed!'.format(ii)) 129 | 130 | aa = np.random.randn(6, 6, 3) 131 | bb = im_pad_fun(aa, 3) 132 | print(aa[:, :, 0]) 133 | print(bb[:, :, 0]) 134 | 135 | 136 | -------------------------------------------------------------------------------- /networks3/UNetG.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-03-20 19:48:14 4 | # Adapted from https://github.com/jvanvugt/pytorch-unet 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from .SubBlocks import conv3x3, conv_down 10 | from .UNetD import UNetD 11 | from torch.nn import init 12 | class UNetG(UNetD): 13 | def __init__(self, in_chn, wf=32, depth=5, relu_slope=0.20): 14 | """ 15 | Reference: 16 | Ronneberger O., Fischer P., Brox T. (2015) U-Net: Convolutional Networks for Biomedical 17 | Image Segmentation. MICCAI 2015. 18 | ArXiv Version: https://arxiv.org/abs/1505.04597 19 | 20 | Args: 21 | in_chn (int): number of input channels, Default 3 22 | depth (int): depth of the network, Default 4 23 | wf (int): number of filters in the first layer, Default 32 24 | """ 25 | super(UNetG, self).__init__(in_chn, wf, depth, relu_slope) 26 | 27 | def get_input_chn(self, in_chn): 28 | return in_chn+3 29 | 30 | 31 | 32 | def sample_generator(netG, x): 33 | z = torch.randn([x.shape[0], 3, x.shape[2], x.shape[3]], device=x.device) 34 | # x1 = torch.cat([x, z], dim=1) 35 | out = netG(x,z) 36 | 37 | return out+x 38 | 39 | def sample_generator_1(netG, x,z): 40 | # z = torch.randn([x.shape[0], 1, x.shape[2], x.shape[3]], device=x.device) 41 | x1 = torch.cat([x, z], dim=1) 42 | out = netG(x1) 43 | 44 | return out+x 45 | 46 | 47 | class _Conv_Block(nn.Module): 48 | def __init__(self): 49 | super(_Conv_Block, self).__init__() 50 | 51 | self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 52 | self.in1 = nn.BatchNorm2d(64, affine=True) 53 | self.relu = nn.LeakyReLU(0.2, inplace=True) 54 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 55 | self.in2 = nn.BatchNorm2d(64, affine=True) 56 | 57 | def forward(self, x): 58 | identity_data = x 59 | output = self.relu(self.in1(self.conv1(x))) 60 | output = self.in2(self.conv2(output)) 61 | return output 62 | 63 | 64 | class _Residual_Block(nn.Module): 65 | def __init__(self): 66 | super(_Residual_Block, self).__init__() 67 | 68 | self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1) 69 | # self.in1 = nn.BatchNorm2d(64, affine=True) 70 | self.relu = nn.LeakyReLU(0.2, inplace=True) 71 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, ) 72 | 73 | # self.in2 = nn.BatchNorm2d(64, affine=True) 74 | 75 | def 
forward(self, x): 76 | identity_data = x 77 | output = self.relu((self.conv1(x))) 78 | 79 | output = self.relu((self.conv2(output))) 80 | output = torch.add(output, identity_data) 81 | return output 82 | 83 | # class _NetG_DOWN(nn.Module): 84 | # def __init__(self, stride=2): 85 | # super(_NetG_DOWN, self).__init__() 86 | # 87 | # self.conv_input = nn.Sequential( 88 | # nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7, stride=1, padding=3, ), 89 | # nn.LeakyReLU(0.2, inplace=True), 90 | # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=stride + 2, stride=stride, padding=1, ), 91 | # nn.LeakyReLU(0.2, inplace=True), 92 | # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=stride + 2, stride=stride, padding=1, ), 93 | # nn.LeakyReLU(0.2, inplace=True), 94 | # ) 95 | # 96 | # self.relu = nn.LeakyReLU(0.2, inplace=True) 97 | # 98 | # self.residual = self.make_layer(_Residual_Block, 6) 99 | # 100 | # self.conv_output = nn.Sequential( 101 | # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 102 | # nn.LeakyReLU(0.2, inplace=True), 103 | # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, ), 104 | # nn.LeakyReLU(0.2, inplace=True), 105 | # nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, stride=1, padding=3, ), 106 | # ) 107 | # 108 | # def make_layer(self, block, num_of_layer): 109 | # layers = [] 110 | # for _ in range(num_of_layer): 111 | # layers.append(block()) 112 | # return nn.Sequential(*layers) 113 | # 114 | # def forward(self, x): 115 | # out = self.conv_input(x) 116 | # 117 | # 118 | # out = self.residual(out) 119 | # 120 | # 121 | # out = self.conv_output(out) 122 | # 123 | # return out 124 | 125 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-01-22 22:07:08 4 | 5 | import math 6 | import torch 7 | import torch.nn.functional as F 8 | from skimage import img_as_ubyte 9 | from loss import get_gausskernel, gaussblur 10 | import numpy as np 11 | import cv2 12 | 13 | def ssim(img1, img2): 14 | C1 = (0.01 * 255)**2 15 | C2 = (0.03 * 255)**2 16 | 17 | img1 = img1.astype(np.float64) 18 | img2 = img2.astype(np.float64) 19 | kernel = cv2.getGaussianKernel(11, 1.5) 20 | window = np.outer(kernel, kernel.transpose()) 21 | 22 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 23 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 24 | mu1_sq = mu1**2 25 | mu2_sq = mu2**2 26 | mu1_mu2 = mu1 * mu2 27 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 28 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 29 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 30 | 31 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 32 | (sigma1_sq + sigma2_sq + C2)) 33 | return ssim_map.mean() 34 | 35 | def calculate_ssim(img1, img2, border=0): 36 | '''calculate SSIM 37 | the same outputs as MATLAB's 38 | img1, img2: [0, 255] 39 | ''' 40 | if not img1.shape == img2.shape: 41 | raise ValueError('Input images must have the same dimensions.') 42 | h, w = img1.shape[:2] 43 | img1 = img1[border:h-border, border:w-border] 44 | img2 = img2[border:h-border, border:w-border] 45 | 46 | if img1.ndim == 2: 47 | return ssim(img1, img2) 48 | elif img1.ndim == 3: 49 | if img1.shape[2] == 3: 50 | ssims = [] 51 | for i in range(3): 52 | 
ssims.append(ssim(img1[:,:,i], img2[:,:,i])) 53 | return np.array(ssims).mean() 54 | elif img1.shape[2] == 1: 55 | return ssim(np.squeeze(img1), np.squeeze(img2)) 56 | else: 57 | raise ValueError('Wrong input image dimensions.') 58 | 59 | def calculate_psnr(im1, im2, border=0): 60 | if not im1.shape == im2.shape: 61 | raise ValueError('Input images must have the same dimensions.') 62 | h, w = im1.shape[:2] 63 | im1 = im1[border:h-border, border:w-border] 64 | im2 = im2[border:h-border, border:w-border] 65 | 66 | im1 = im1.astype(np.float64) 67 | im2 = im2.astype(np.float64) 68 | mse = np.mean((im1 - im2)**2) 69 | if mse == 0: 70 | return float('inf') 71 | return 20 * math.log10(255.0 / math.sqrt(mse)) 72 | 73 | def batch_PSNR(img, imclean, border=0): 74 | Img = img.data.cpu().numpy() 75 | Iclean = imclean.data.cpu().numpy() 76 | Img = img_as_ubyte(Img) 77 | Iclean = img_as_ubyte(Iclean) 78 | PSNR = 0 79 | for i in range(Img.shape[0]): 80 | PSNR += calculate_psnr(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border) 81 | return (PSNR/Img.shape[0]) 82 | 83 | def batch_SSIM(img, imclean, border=0): 84 | Img = img.data.cpu().numpy() 85 | Iclean = imclean.data.cpu().numpy() 86 | Img = img_as_ubyte(Img) 87 | Iclean = img_as_ubyte(Iclean) 88 | SSIM = 0 89 | for i in range(Img.shape[0]): 90 | SSIM += calculate_ssim(Iclean[i,:,].transpose((1,2,0)), Img[i,:,].transpose((1,2,0)), border) 91 | return (SSIM/Img.shape[0]) 92 | 93 | def kl_gauss_zero_center(sigma_fake, sigma_real): 94 | ''' 95 | Input: 96 | sigma_fake: 1 x C x H x W, torch array 97 | sigma_real: 1 x C x H x W, torch array 98 | ''' 99 | div_sigma = torch.div(sigma_fake, sigma_real) 100 | div_sigma.clamp_(min=0.1, max=10) 101 | log_sigma = torch.log(1 / div_sigma) 102 | distance = 0.5 * torch.mean(log_sigma + div_sigma - 1.) 103 | return distance 104 | 105 | def estimate_sigma_gauss(img_noisy, img_gt): 106 | win_size = 7 107 | err2 = (img_noisy - img_gt) ** 2 108 | kernel = get_gausskernel(win_size, chn=3).to(img_gt.device) 109 | sigma = gaussblur(err2, kernel, win_size, chn=3) 110 | sigma.clamp_(min=1e-10) 111 | 112 | return sigma 113 | 114 | class PadUNet: 115 | ''' 116 | im: N x C x H x W torch tensor 117 | dep_U: depth of UNet 118 | ''' 119 | def __init__(self, im, dep_U, mode='reflect'): 120 | self.im_old = im 121 | self.dep_U = dep_U 122 | self.mode = mode 123 | self.H_old = im.shape[2] 124 | self.W_old = im.shape[3] 125 | 126 | def pad(self): 127 | lenU = 2 ** (self.dep_U-1) 128 | padH = 0 if ((self.H_old % lenU) == 0) else (lenU - (self.H_old % lenU)) 129 | padW = 0 if ((self.W_old % lenU) == 0) else (lenU - (self.W_old % lenU)) 130 | padding = (0, padW, 0, padH) 131 | out = F.pad(self.im_old, pad=padding, mode=self.mode) 132 | return out 133 | 134 | def pad_inverse(self, im_new): 135 | return im_new[:, :, :self.H_old, :self.W_old] 136 | 137 | from torch.nn import init 138 | def init_weights(net, init_type='normal', init_gain=0.02): 139 | """Initialize network weights. 140 | 141 | Parameters: 142 | net (network) -- network to be initialized 143 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 144 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 145 | 146 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 147 | work better for some applications. Feel free to try yourself. 
148 | """ 149 | def init_func(m): # define the initialization function 150 | classname = m.__class__.__name__ 151 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 152 | if init_type == 'normal': 153 | init.normal_(m.weight.data, 0.0, init_gain) 154 | elif init_type == 'xavier': 155 | init.xavier_normal_(m.weight.data, gain=init_gain) 156 | elif init_type == 'kaiming': 157 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 158 | elif init_type == 'orthogonal': 159 | init.orthogonal_(m.weight.data, gain=init_gain) 160 | else: 161 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 162 | if hasattr(m, 'bias') and m.bias is not None: 163 | init.constant_(m.bias.data, 0.0) 164 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 165 | init.normal_(m.weight.data, 1.0, init_gain) 166 | init.constant_(m.bias.data, 0.0) 167 | 168 | print('initialize network with %s' % init_type) 169 | net.apply(init_func) # apply the initialization function 170 | 171 | -------------------------------------------------------------------------------- /datasets/DenoisingDatasets.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-09-02 15:51:11 4 | 5 | import sys 6 | import torch 7 | import h5py as h5 8 | import random 9 | import cv2 10 | import os 11 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 12 | import numpy as np 13 | import torch.utils.data as uData 14 | from skimage import img_as_float32 as img_as_float 15 | # from . import BaseDataSetH5, BaseDataSetFolder 16 | class BaseDataSetH5(uData.Dataset): 17 | def __init__(self, h5_path, length=None): 18 | ''' 19 | Args: 20 | h5_path (str): path of the hdf5 file 21 | length (int): length of Datasets 22 | ''' 23 | super(BaseDataSetH5, self).__init__() 24 | self.h5_path = h5_path 25 | self.length = length 26 | with h5.File(h5_path, 'r') as h5_file: 27 | self.keys = list(h5_file.keys()) 28 | self.num_images = len(self.keys) 29 | 30 | def __len__(self): 31 | if self.length == None: 32 | return self.num_images 33 | else: 34 | return self.length 35 | 36 | def crop_patch(self, imgs_sets): 37 | H, W, C2 = imgs_sets.shape 38 | # minus the bayer patter channel 39 | C = int(C2/2) 40 | ind_H = random.randint(0, H-self.pch_size) 41 | ind_W = random.randint(0, W-self.pch_size) 42 | im_noisy = np.array(imgs_sets[ind_H:ind_H+self.pch_size, ind_W:ind_W+self.pch_size, :C]) 43 | im_gt = np.array(imgs_sets[ind_H:ind_H+self.pch_size, ind_W:ind_W+self.pch_size, C:]) 44 | return im_gt, im_noisy 45 | 46 | class BaseDataSetFolder(uData.Dataset): 47 | def __getitem__(self, path_list, pch_size, length=None): 48 | ''' 49 | Args: 50 | path_list (str): path of the hdf5 file 51 | length (int): length of Datasets 52 | ''' 53 | super(BaseDataSetFolder, self).__init__() 54 | self.path_list = path_list 55 | self.length = length 56 | self.pch_size = pch_size 57 | self.num_images = len(path_list) 58 | 59 | def __len__(self): 60 | if self.length == None: 61 | return self.num_images 62 | else: 63 | return self.length 64 | 65 | def crop_patch(self, im): 66 | pch_size = self.pch_size 67 | H, W, _ = im.shape 68 | if H < self.pch_size or W < self.pch_size: 69 | H = max(pch_size, H) 70 | W = max(pch_size, W) 71 | im = cv2.resize(im, (W, H)) 72 | ind_H = random.randint(0, H-pch_size) 73 | ind_W = random.randint(0, W-pch_size) 74 | im_pch 
= im[ind_H:ind_H+pch_size, ind_W:ind_W+pch_size,] 75 | return im_pch 76 | 77 | def data_augmentation(image, mode): 78 | ''' 79 | Performs data augmentation of the input image 80 | Input: 81 | image: a cv2 (OpenCV) image 82 | mode: int. Choice of transformation to apply to the image 83 | 0 - no transformation 84 | 1 - flip up and down 85 | 2 - rotate counterwise 90 degree 86 | 3 - rotate 90 degree and flip up and down 87 | 4 - rotate 180 degree 88 | 5 - rotate 180 degree and flip 89 | 6 - rotate 270 degree 90 | 7 - rotate 270 degree and flip 91 | ''' 92 | if mode == 0: 93 | # original 94 | out = image 95 | elif mode == 1: 96 | # flip up and down 97 | out = np.flipud(image) 98 | elif mode == 2: 99 | # rotate counterwise 90 degree 100 | out = np.rot90(image) 101 | elif mode == 3: 102 | # rotate 90 degree and flip up and down 103 | out = np.rot90(image) 104 | out = np.flipud(out) 105 | elif mode == 4: 106 | # rotate 180 degree 107 | out = np.rot90(image, k=2) 108 | elif mode == 5: 109 | # rotate 180 degree and flip 110 | out = np.rot90(image, k=2) 111 | out = np.flipud(out) 112 | elif mode == 6: 113 | # rotate 270 degree 114 | out = np.rot90(image, k=3) 115 | elif mode == 7: 116 | # rotate 270 degree and flip 117 | out = np.rot90(image, k=3) 118 | out = np.flipud(out) 119 | else: 120 | raise Exception('Invalid choice of image transformation') 121 | 122 | return out 123 | 124 | def random_augmentation(*args): 125 | out = [] 126 | if random.randint(0,1) == 1: 127 | flag_aug = random.randint(1,7) 128 | for data in args: 129 | out.append(data_augmentation(data, flag_aug).copy()) 130 | else: 131 | for data in args: 132 | out.append(data) 133 | return out 134 | 135 | # Benchmardk Datasets: and SIDD 136 | class BenchmarkTrain(BaseDataSetH5): 137 | def __init__(self, h5_file, length, pch_size=128, mask=False): 138 | super(BenchmarkTrain, self).__init__(h5_file, length) 139 | self.pch_size = pch_size 140 | self.mask = mask 141 | 142 | 143 | def __getitem__(self, index): 144 | num_images = self.num_images 145 | ind_im = random.randint(0, num_images - 1) 146 | 147 | with h5.File(self.h5_path, 'r') as h5_file: 148 | imgs_sets = h5_file[self.keys[ind_im]] 149 | im_gt, im_noisy = self.crop_patch(imgs_sets) 150 | im_gt = img_as_float(im_gt) 151 | im_noisy = img_as_float(im_noisy) 152 | 153 | # data augmentation 154 | im_gt, im_noisy = random_augmentation(im_gt, im_noisy) 155 | 156 | im_gt = torch.from_numpy(im_gt.transpose((2, 0, 1))) 157 | im_noisy = torch.from_numpy(im_noisy.transpose((2, 0, 1))) 158 | 159 | if self.mask: 160 | return im_noisy, im_gt, torch.ones((1, 1, 1), dtype=torch.float32) 161 | else: 162 | return im_noisy, im_gt 163 | 164 | class BenchmarkTest(BaseDataSetH5): 165 | def __getitem__(self, index): 166 | with h5.File(self.h5_path, 'r') as h5_file: 167 | imgs_sets = h5_file[self.keys[index]] 168 | C2 = imgs_sets.shape[2] 169 | C = int(C2/2) 170 | im_noisy = np.array(imgs_sets[:, :, :C]) 171 | im_gt = np.array(imgs_sets[:, :, C:]) 172 | im_gt = img_as_float(im_gt) 173 | im_noisy = img_as_float(im_noisy) 174 | 175 | im_gt = torch.from_numpy(im_gt.transpose((2, 0, 1))) 176 | im_noisy = torch.from_numpy(im_noisy.transpose((2, 0, 1))) 177 | 178 | return im_noisy, im_gt 179 | 180 | class FakeTrain(BaseDataSetFolder): 181 | def __init__(self, path_list, length, pch_size=128): 182 | super(FakeTrain, self).__init__(path_list, pch_size, length) 183 | 184 | def __getitem__(self, index): 185 | num_images = self.num_images 186 | ind_im = random.randint(0, num_images-1) 187 | 188 | im_gt = 
img_as_float(cv2.imread(self.path_list[ind_im], 1)[:, :, ::-1]) 189 | im_gt = self.crop_patch(im_gt) 190 | 191 | # data augmentation 192 | im_gt = random_augmentation(im_gt)[0] 193 | 194 | im_gt = torch.from_numpy(im_gt.transpose((2, 0, 1))) 195 | 196 | return im_gt, im_gt, torch.zeros((1,1,1), dtype=torch.float32) 197 | 198 | class PolyuTrain(BaseDataSetFolder): 199 | def __init__(self, path_list, length, pch_size=128, mask=False): 200 | super(PolyuTrain, self).__init__(path_list, pch_size, length) 201 | self.mask = mask 202 | 203 | def __getitem__(self, index): 204 | num_images = self.num_images 205 | ind_im = random.randint(0, num_images-1) 206 | 207 | path_noisy = self.path_list[ind_im] 208 | head, tail = os.path.split(path_noisy) 209 | path_gt = os.path.join(head, tail.replace('real', 'mean')) 210 | im_noisy = img_as_float(cv2.imread(path_noisy, 1)[:, :, ::-1]) 211 | im_gt = img_as_float(cv2.imread(path_gt, 1)[:, :, ::-1]) 212 | im_noisy, im_gt = self.crop_patch(im_noisy, im_gt) 213 | 214 | # data augmentation 215 | im_gt, im_noisy = random_augmentation(im_gt, im_noisy) 216 | 217 | im_gt = torch.from_numpy(im_gt.transpose((2, 0, 1))) 218 | im_noisy = torch.from_numpy(im_noisy.transpose((2, 0, 1))) 219 | 220 | if self.mask: 221 | return im_noisy, im_gt, torch.ones((1,1,1), dtype=torch.float32) 222 | else: 223 | return im_noisy, im_gt 224 | 225 | def crop_patch(self, im_noisy, im_gt): 226 | pch_size = self.pch_size 227 | H, W, _ = im_noisy.shape 228 | ind_H = random.randint(0, H-pch_size) 229 | ind_W = random.randint(0, W-pch_size) 230 | im_pch_noisy = im_noisy[ind_H:ind_H+pch_size, ind_W:ind_W+pch_size,] 231 | im_pch_gt = im_gt[ind_H:ind_H+pch_size, ind_W:ind_W+pch_size,] 232 | return im_pch_noisy, im_pch_gt 233 | 
-------------------------------------------------------------------------------- /networks3/util.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 8 | 9 | 10 | def np2tensor(n:np.array): 11 | ''' 12 | transform numpy array (image) to torch Tensor 13 | BGR -> RGB 14 | (h,w,c) -> (c,h,w) 15 | ''' 16 | # gray: (h,w) -> (1,h,w) 17 | if len(n.shape) == 2: 18 | return torch.from_numpy(np.ascontiguousarray(n)).unsqueeze(0) 19 | # BGR -> RGB 20 | elif len(n.shape) == 3: 21 | return torch.from_numpy(np.ascontiguousarray(np.transpose(np.flip(n, axis=2), (2,0,1)))) 22 | else: 23 | raise RuntimeError('wrong numpy dimensions : %s'%(n.shape,)) 24 | 25 | def tensor2np(t:torch.Tensor): 26 | ''' 27 | transform torch Tensor to numpy having opencv image form. 
28 | RGB -> BGR 29 | (c,h,w) -> (h,w,c) 30 | ''' 31 | t = t.cpu().detach() 32 | 33 | # gray: a (h,w) tensor is already in image form 34 | if len(t.shape) == 2: 35 | return t.numpy() 36 | # RGB -> BGR 37 | elif len(t.shape) == 3: 38 | return np.flip(t.permute(1,2,0).numpy(), axis=2) 39 | # image batch 40 | elif len(t.shape) == 4: 41 | return np.flip(t.permute(0,2,3,1).numpy(), axis=3) 42 | else: 43 | raise RuntimeError('wrong tensor dimensions : %s'%(t.shape,)) 44 | 45 | def imwrite_tensor(t, name='test.png'): 46 | cv2.imwrite('./%s'%name, tensor2np(t.cpu())) 47 | 48 | def imread_tensor(name='test'): 49 | return np2tensor(cv2.imread('./%s'%name)) 50 | 51 | def rot_hflip_img(img:torch.Tensor, rot_times:int=0, hflip:int=0): 52 | ''' 53 | rotate '90 x times degree' & horizontal flip image 54 | (shape of img: b,c,h,w or c,h,w) 55 | ''' 56 | b=0 if len(img.shape)==3 else 1 57 | # no flip 58 | if hflip % 2 == 0: 59 | # 0 degrees 60 | if rot_times % 4 == 0: 61 | return img 62 | # 90 degrees 63 | elif rot_times % 4 == 1: 64 | return img.flip(b+1).transpose(b+1,b+2) 65 | # 180 degrees 66 | elif rot_times % 4 == 2: 67 | return img.flip(b+2).flip(b+1) 68 | # 270 degrees 69 | else: 70 | return img.flip(b+2).transpose(b+1,b+2) 71 | # horizontal flip 72 | else: 73 | # 0 degrees 74 | if rot_times % 4 == 0: 75 | return img.flip(b+2) 76 | # 90 degrees 77 | elif rot_times % 4 == 1: 78 | return img.flip(b+1).flip(b+2).transpose(b+1,b+2) 79 | # 180 degrees 80 | elif rot_times % 4 == 2: 81 | return img.flip(b+1) 82 | # 270 degrees 83 | else: 84 | return img.transpose(b+1,b+2) 85 | 86 | def pixel_shuffle_down_sampling(x:torch.Tensor, f:int, pad:int=0, pad_value:float=0.): 87 | ''' 88 | pixel-shuffle down-sampling (PD) from "When AWGN-denoiser meets real-world noise." (AAAI 2019) 89 | Args: 90 | x (Tensor) : input tensor 91 | f (int) : factor of PD 92 | pad (int) : number of pad between each down-sampled images 93 | pad_value (float) : padding value 94 | Return: 95 | pd_x (Tensor) : down-shuffled image tensor with pad or not 96 | ''' 97 | # single image tensor 98 | if len(x.shape) == 3: 99 | c,w,h = x.shape 100 | unshuffled = F.pixel_unshuffle(x, f) 101 | if pad != 0: unshuffled = F.pad(unshuffled, (pad, pad, pad, pad), value=pad_value) 102 | return unshuffled.view(c,f,f,w//f+2*pad,h//f+2*pad).permute(0,1,3,2,4).reshape(c, w+2*f*pad, h+2*f*pad) 103 | # batched image tensor 104 | else: 105 | b,c,w,h = x.shape 106 | unshuffled = F.pixel_unshuffle(x, f) 107 | if pad != 0: unshuffled = F.pad(unshuffled, (pad, pad, pad, pad), value=pad_value) 108 | return unshuffled.view(b,c,f,f,w//f+2*pad,h//f+2*pad).permute(0,1,2,4,3,5).reshape(b,c,w+2*f*pad, h+2*f*pad) 109 | 110 | def pixel_shuffle_up_sampling(x:torch.Tensor, f:int, pad:int=0): 111 | ''' 112 | inverse of pixel-shuffle down-sampling (PD) 113 | see more details about PD in pixel_shuffle_down_sampling() 114 | Args: 115 | x (Tensor) : input tensor 116 | f (int) : factor of PD 117 | pad (int) : number of pad will be removed 118 | ''' 119 | # single image tensor 120 | if len(x.shape) == 3: 121 | c,w,h = x.shape 122 | before_shuffle = x.view(c,f,w//f,f,h//f).permute(0,1,3,2,4).reshape(c*f*f,w//f,h//f) 123 | if pad != 0: before_shuffle = before_shuffle[..., pad:-pad, pad:-pad] 124 | return F.pixel_shuffle(before_shuffle, f) 125 | # batched image tensor 126 | else: 127 | b,c,w,h = x.shape 128 | before_shuffle = x.view(b,c,f,w//f,f,h//f).permute(0,1,2,4,3,5).reshape(b,c*f*f,w//f,h//f) 129 | if pad != 0: before_shuffle = before_shuffle[..., pad:-pad, pad:-pad] 130 | return 
F.pixel_shuffle(before_shuffle, f) 131 | 132 | def human_format(num): 133 | magnitude=0 134 | while abs(num)>=1000: 135 | magnitude+=1 136 | num/=1000.0 137 | return '%.1f%s'%(num,['','K','M','G','T','P'][magnitude]) 138 | 139 | def psnr(img1, img2): 140 | ''' 141 | image value range : [0 - 255] 142 | clipping for model output 143 | ''' 144 | if len(img1.shape) == 4: 145 | img1 = img1[0] 146 | if len(img2.shape) == 4: 147 | img2 = img2[0] 148 | 149 | # tensor to numpy 150 | if isinstance(img1, torch.Tensor): 151 | img1 = tensor2np(img1) 152 | if isinstance(img2, torch.Tensor): 153 | img2 = tensor2np(img2) 154 | 155 | # numpy value cliping & chnage type to uint8 156 | img1 = np.clip(img1, 0, 255) 157 | img2 = np.clip(img2, 0, 255) 158 | 159 | return peak_signal_noise_ratio(img1, img2, data_range=255) 160 | 161 | def ssim(img1, img2): 162 | ''' 163 | image value range : [0 - 255] 164 | clipping for model output 165 | ''' 166 | if len(img1.shape) == 4: 167 | img1 = img1[0] 168 | if len(img2.shape) == 4: 169 | img2 = img2[0] 170 | 171 | # tensor to numpy 172 | if isinstance(img1, torch.Tensor): 173 | img1 = tensor2np(img1) 174 | if isinstance(img2, torch.Tensor): 175 | img2 = tensor2np(img2) 176 | 177 | # numpy value cliping 178 | img2 = np.clip(img2, 0, 255) 179 | img1 = np.clip(img1, 0, 255) 180 | 181 | return structural_similarity(img1, img2, multichannel=True, data_range=255) 182 | 183 | def get_gaussian_2d_filter(window_size, sigma, channel=1, device=torch.device('cpu')): 184 | ''' 185 | return 2d gaussian filter window as tensor form 186 | Arg: 187 | window_size : filter window size 188 | sigma : standard deviation 189 | ''' 190 | gauss = torch.ones(window_size, device=device) 191 | for x in range(window_size): gauss[x] = exp(-(x - window_size//2)**2/float(2*sigma**2)) 192 | gauss = gauss.unsqueeze(1) 193 | #gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)], device=device).unsqueeze(1) 194 | filter2d = gauss.mm(gauss.t()).float() 195 | filter2d = (filter2d/filter2d.sum()).unsqueeze(0).unsqueeze(0) 196 | return filter2d.expand(channel, 1, window_size, window_size) 197 | 198 | def get_mean_2d_filter(window_size, channel=1, device=torch.device('cpu')): 199 | ''' 200 | return 2d mean filter as tensor form 201 | Args: 202 | window_size : filter window size 203 | ''' 204 | window = torch.ones((window_size, window_size), device=device) 205 | window = (window/window.sum()).unsqueeze(0).unsqueeze(0) 206 | return window.expand(channel, 1, window_size, window_size) 207 | 208 | def mean_conv2d(x, window_size=None, window=None, filter_type='gau', sigma=None, keep_sigma=False, padd=True): 209 | ''' 210 | color channel-wise 2d mean or gaussian convolution 211 | Args: 212 | x : input image 213 | window_size : filter window size 214 | filter_type(opt) : 'gau' or 'mean' 215 | sigma : standard deviation of gaussian filter 216 | ''' 217 | b_x = x.unsqueeze(0) if len(x.shape) == 3 else x 218 | 219 | if window is None: 220 | if sigma is None: sigma = (window_size-1)/6 221 | if filter_type == 'gau': 222 | window = get_gaussian_2d_filter(window_size, sigma=sigma, channel=b_x.shape[1], device=x.device) 223 | else: 224 | window = get_mean_2d_filter(window_size, channel=b_x.shape[1], device=x.device) 225 | else: 226 | window_size = window.shape[-1] 227 | 228 | if padd: 229 | pl = (window_size-1)//2 230 | b_x = F.pad(b_x, (pl,pl,pl,pl), 'reflect') 231 | 232 | m_b_x = F.conv2d(b_x, window, groups=b_x.shape[1]) 233 | 234 | if keep_sigma: 235 | m_b_x /= 
(window**2).sum().sqrt() 236 | 237 | if len(x.shape) == 4: 238 | return m_b_x 239 | elif len(x.shape) == 3: 240 | return m_b_x.squeeze(0) 241 | else: 242 | raise ValueError('input image shape is not correct') 243 | -------------------------------------------------------------------------------- /networks3/UNetD.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-03-20 19:48:14 4 | # Adapted from https://github.com/jvanvugt/pytorch-unet 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from .SubBlocks import conv3x3, conv_down, conv1x1 10 | 11 | 12 | class UNetD(nn.Module): 13 | def __init__(self, in_chn, wf=32, depth=5, relu_slope=0.2): 14 | super(UNetD, self).__init__() 15 | self.depth = depth 16 | self.down_path = nn.ModuleList() 17 | prev_channels = self.get_input_chn(in_chn) 18 | for i in range(depth): 19 | downsample = True if (i+1) < depth else False 20 | self.down_path.append(UNetConvBlock(prev_channels, (2**i)*wf, downsample, relu_slope)) 21 | prev_channels = (2**i) * wf 22 | 23 | self.up_path = nn.ModuleList() 24 | for i in reversed(range(depth - 1)): 25 | self.up_path.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope)) 26 | prev_channels = (2**i)*wf 27 | 28 | self.last = conv3x3(prev_channels, in_chn, bias=True) 29 | # self._initialize() 30 | 31 | def forward(self, x1): 32 | res = x1 33 | blocks = [] 34 | for i, down in enumerate(self.down_path): 35 | if (i+1) < self.depth: 36 | x1, x1_up = down(x1) 37 | blocks.append(x1_up) 38 | else: 39 | x1 = down(x1) 40 | 41 | for i, up in enumerate(self.up_path): 42 | x1 = up(x1, blocks[-i-1]) 43 | 44 | out = self.last(x1) 45 | return out+res 46 | 47 | def get_input_chn(self, in_chn): 48 | return in_chn 49 | 50 | def _initialize(self): 51 | gain = nn.init.calculate_gain('leaky_relu', 0.20) 52 | for m in self.modules(): 53 | if isinstance(m, nn.Conv2d): 54 | nn.init.orthogonal_(m.weight, gain=gain) 55 | if not m.bias is None: 56 | nn.init.constant_(m.bias, 0) 57 | class spatial_attn_layer(nn.Module): 58 | def __init__(self, kernel_size=3): 59 | super(spatial_attn_layer, self).__init__() 60 | self.compress = ChannelPool() 61 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 62 | def forward(self, x): 63 | # import pdb;pdb.set_trace() 64 | x_compress = self.compress(x) 65 | x_out = self.spatial(x_compress) 66 | scale = torch.sigmoid(x_out) # broadcasting 67 | return x * scale 68 | 69 | class BasicConv(nn.Module): 70 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 71 | bn=False, bias=False): 72 | super(BasicConv, self).__init__() 73 | self.out_channels = out_planes 74 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 75 | dilation=dilation, groups=groups, bias=bias) 76 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 77 | self.relu = nn.ReLU() if relu else None 78 | 79 | def forward(self, x): 80 | x = self.conv(x) 81 | if self.bn is not None: 82 | x = self.bn(x) 83 | if self.relu is not None: 84 | x = self.relu(x) 85 | return x 86 | class ChannelPool(nn.Module): 87 | def forward(self, x): 88 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 89 | 90 | class CALayer(nn.Module): 91 | def __init__(self, channel, reduction=16): 92 | 
super(CALayer, self).__init__() 93 | # global average pooling: feature --> point 94 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 95 | # feature channel downscale and upscale --> channel weight 96 | self.conv_du = nn.Sequential( 97 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 98 | nn.ReLU(inplace=True), 99 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 100 | nn.Sigmoid() 101 | ) 102 | 103 | def forward(self, x): 104 | y = self.avg_pool(x) 105 | y = self.conv_du(y) 106 | return x * y 107 | class UNetConvBlock(nn.Module): 108 | def __init__(self, in_size, out_size, downsample, relu_slope): 109 | super(UNetConvBlock, self).__init__() 110 | self.downsample = downsample 111 | self.block = nn.Sequential( 112 | nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True), 113 | nn.LeakyReLU(relu_slope, inplace=True), 114 | nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True), 115 | nn.LeakyReLU(relu_slope, inplace=True)) 116 | if downsample: 117 | self.downsample = conv_down(out_size, out_size, bias=False) 118 | self.SA = spatial_attn_layer() ## Spatial Attention 119 | self.CA = CALayer(in_size, 8) ## Channel Attention 120 | self.conv1x1 = nn.Conv2d(2 * in_size, out_size, kernel_size=1) 121 | 122 | 123 | def forward(self, x): 124 | out = self.block(x) 125 | sa_branch = self.SA(out) 126 | ca_branch = self.CA(out) 127 | res = torch.cat([sa_branch, ca_branch], dim=1) 128 | out = self.conv1x1(res) 129 | if self.downsample: 130 | out_down = self.downsample(out) 131 | return out_down, out 132 | else: 133 | return out 134 | 135 | class UNetUpBlock(nn.Module): 136 | def __init__(self, in_size, out_size, relu_slope): 137 | super(UNetUpBlock, self).__init__() 138 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True) 139 | self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope) 140 | 141 | def forward(self, x, bridge): 142 | up = self.up(x) 143 | out = torch.cat([up, bridge], 1) 144 | out = self.conv_block(out) 145 | 146 | return out 147 | 148 | class DnCNN(nn.Module): 149 | def __init__(self, channels, num_of_layers=17): 150 | super(DnCNN, self).__init__() 151 | kernel_size = 3 152 | padding = 1 153 | features = 64 154 | layers = [] 155 | layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)) 156 | layers.append(nn.ReLU(inplace=True)) 157 | for _ in range(num_of_layers-2): 158 | layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)) 159 | layers.append(nn.BatchNorm2d(features)) 160 | layers.append(nn.ReLU(inplace=True)) 161 | layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False)) 162 | self.dncnn = nn.Sequential(*layers) 163 | def forward(self, x): 164 | out = self.dncnn(x) 165 | return out 166 | 167 | class tizao(nn.Module): 168 | def __init__(self, in_chn, wf=32, depth=5, relu_slope=0.2): 169 | super(tizao, self).__init__() 170 | self.depth = depth 171 | self.down_path = nn.ModuleList() 172 | prev_channels = self.get_input_chn(in_chn) 173 | for i in range(depth): 174 | downsample = True if (i+1) < depth else False 175 | self.down_path.append(UNetConvBlock(prev_channels, (2**i)*wf, downsample, relu_slope)) 176 | prev_channels = (2**i) * wf 177 | 178 | self.up_path = nn.ModuleList() 179 | for i in reversed(range(depth - 1)): 180 | self.up_path.append(UNetUpBlock(prev_channels, (2**i)*wf, 
relu_slope)) 181 | prev_channels = (2**i)*wf 182 | 183 | self.last = conv1x1(prev_channels, in_chn, bias=True) 184 | # self._initialize() 185 | 186 | def forward(self, x1): 187 | res = x1 188 | blocks = [] 189 | for i, down in enumerate(self.down_path): 190 | if (i+1) < self.depth: 191 | x1, x1_up = down(x1) 192 | blocks.append(x1_up) 193 | else: 194 | x1 = down(x1) 195 | 196 | for i, up in enumerate(self.up_path): 197 | x1 = up(x1, blocks[-i-1]) 198 | 199 | out = self.last(x1) 200 | return out 201 | 202 | def get_input_chn(self, in_chn): 203 | return in_chn 204 | 205 | def _initialize(self): 206 | gain = nn.init.calculate_gain('leaky_relu', 0.20) 207 | for m in self.modules(): 208 | if isinstance(m, nn.Conv2d): 209 | nn.init.orthogonal_(m.weight, gain=gain) 210 | if m.bias is not None: 211 | nn.init.constant_(m.bias, 0) 212 | 213 | 214 | class shengcheng(nn.Module): # NOTE: appears to be an abandoned sketch; conv_input, residual and conv_output are never defined in __init__, so forward() cannot run as written. The generator actually used in train_v6.py is _NetG_DOWN. 215 | def __init__(self, stride=2): 216 | super(shengcheng, self).__init__() 217 | def forward(self, x, y): 218 | # z = torch.cat([x, self.scale*y], dim=1) 219 | 220 | # concatenate the clean estimate with the noise map 221 | z = torch.cat([x, y], dim=1) 222 | out = self.conv_input(z) 223 | 224 | out = self.residual(out) 225 | 226 | out = self.conv_output(out) 227 | 228 | return out + x 229 | 230 | -------------------------------------------------------------------------------- /GaussianSmoothLayer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import torch 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import cv2 9 | import torchvision.transforms as transforms 10 | from PIL import Image 11 | 12 | from skimage import img_as_ubyte 13 | import time 14 | from skimage import io 15 | 16 | class GaussionSmoothLayer(nn.Module): 17 | def __init__(self, channel, kernel_size, sigma, dim = 2): 18 | super(GaussionSmoothLayer, self).__init__() 19 | kernel_x = cv2.getGaussianKernel(kernel_size, sigma) 20 | kernel_y = cv2.getGaussianKernel(kernel_size, sigma) 21 | kernel = kernel_x * kernel_y.T 22 | self.kernel_data = kernel 23 | self.groups = channel 24 | if dim == 1: 25 | self.conv = nn.Conv1d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, \ 26 | groups=channel, bias=False) 27 | elif dim == 2: 28 | self.conv = nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, \ 29 | groups=channel, bias=False) 30 | elif dim == 3: 31 | self.conv = nn.Conv3d(in_channels=channel, out_channels=channel, kernel_size=kernel_size, \ 32 | groups=channel, bias=False) 33 | else: raise RuntimeError( 34 | 'input dim is not supported, please check it!'
35 | ) 36 | self.conv.weight.requires_grad = False 37 | for name, f in self.named_parameters(): 38 | f.data.copy_(torch.from_numpy(kernel)) 39 | self.pad = int((kernel_size - 1) / 2) 40 | def forward(self, input): 41 | intdata = input 42 | intdata = F.pad(intdata, (self.pad, self.pad, self.pad, self.pad), mode='reflect') 43 | 44 | output = self.conv(intdata) 45 | return output 46 | 47 | class LapLasGradient(nn.Module): 48 | def __init__(self, indim, outdim): 49 | super(LapLasGradient, self).__init__() 50 | # @ define the sobel filter for x and y axis 51 | kernel = torch.tensor( 52 | [[0, -1, 0], 53 | [-1, 4, -1], 54 | [0, -1, 0] 55 | ] 56 | ) 57 | kernel2 = torch.tensor( 58 | [[0, 1, 0], 59 | [1, -4, 1], 60 | [0, 1, 0] 61 | ] 62 | ) 63 | kernel3 = torch.tensor( 64 | [[-1, -1, -1], 65 | [-1, 8, -1], 66 | [-1, -1, -1] 67 | ] 68 | ) 69 | kernel4 = torch.tensor( 70 | [[1, 1, 1], 71 | [1, -8, 1], 72 | [1, 1, 1] 73 | ] 74 | ) 75 | self.conv = nn.Conv2d(indim, outdim, 3, 1, padding= 1, bias=False) 76 | self.conv.weight.data.copy_(kernel4) 77 | self.conv.weight.requires_grad = False 78 | 79 | def forward(self, x): 80 | grad = self.conv(x) 81 | return grad 82 | 83 | 84 | 85 | class GradientLoss(nn.Module): 86 | def __init__(self, indim, outdim): 87 | super(GradientLoss, self).__init__() 88 | # @ define the sobel filter for x and y axis 89 | x_kernel = torch.tensor( 90 | [ [1, 0, -1], 91 | [2, 0, -2], 92 | [1, 0, -1] 93 | ] 94 | ) 95 | y_kernel = torch.tensor( 96 | [[1, 2, 1], 97 | [0, 0, 0], 98 | [-1, -2, -1] 99 | ] 100 | ) 101 | self.conv_x = nn.Conv2d(indim, outdim, 3, 1, padding= 1, bias=False) 102 | self.conv_y = nn.Conv2d(indim, outdim, 3, 1, padding= 1, bias=False) 103 | 104 | self.conv_x.weight.data.copy_(x_kernel) 105 | self.conv_y.weight.data.copy_(y_kernel) 106 | self.conv_x.weight.requires_grad = False 107 | self.conv_y.weight.requires_grad = False 108 | 109 | def forward(self, x): 110 | grad_x = self.conv_x(x) 111 | grad_y = self.conv_y(x) 112 | gradient = torch.sqrt(torch.pow(grad_x, 2) + torch.pow(grad_y, 2)) 113 | return gradient 114 | class GradientLoss_v1(nn.Module): 115 | def __init__(self, indim, outdim): 116 | super(GradientLoss_v1, self).__init__() 117 | # @ define the sobel filter for x and y axis 118 | x_kernel = torch.tensor( 119 | [[0, -1, 0], 120 | [0, 0, 0], 121 | [0, 1, 0]] 122 | ) 123 | y_kernel = torch.tensor( 124 | [[0, 0, 0], 125 | [-1, 0, 1], 126 | [0, 0, 0]] 127 | ) 128 | self.conv_x = nn.Conv2d(indim, outdim, 3, 1, padding= 1, bias=False) 129 | self.conv_y = nn.Conv2d(indim, outdim, 3, 1, padding= 1, bias=False) 130 | 131 | self.conv_x.weight.data.copy_(x_kernel) 132 | self.conv_y.weight.data.copy_(y_kernel) 133 | self.conv_x.weight.requires_grad = False 134 | self.conv_y.weight.requires_grad = False 135 | 136 | def forward(self, x): 137 | grad_x = self.conv_x(x) 138 | grad_y = self.conv_y(x) 139 | gradient = torch.sqrt(torch.pow(grad_x, 2) + torch.pow(grad_y, 2)) 140 | return gradient 141 | def Norm(x): 142 | Max_item = torch.max(x) 143 | Min_item = torch.min(x) 144 | return (x-Min_item)/(Max_item-Min_item) 145 | 146 | class Get_gradient(nn.Module): 147 | def __init__(self): 148 | super(Get_gradient, self).__init__() 149 | kernel_v = [[0, -1, 0], 150 | [0, 0, 0], 151 | [0, 1, 0]] 152 | kernel_h = [[0, 0, 0], 153 | [-1, 0, 1], 154 | [0, 0, 0]] 155 | kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0) 156 | kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0) 157 | self.weight_h = nn.Parameter(data = kernel_h, requires_grad = False) 158 | 
self.weight_v = nn.Parameter(data = kernel_v, requires_grad = False) 159 | 160 | def forward(self, x): 161 | x0 = x[:, 0] 162 | x1 = x[:, 1] 163 | x2 = x[:, 2] 164 | x0_v = F.conv2d(x0.unsqueeze(1), self.weight_v, padding=1) 165 | x0_h = F.conv2d(x0.unsqueeze(1), self.weight_h, padding=1) 166 | 167 | x1_v = F.conv2d(x1.unsqueeze(1), self.weight_v, padding=1) 168 | x1_h = F.conv2d(x1.unsqueeze(1), self.weight_h, padding=1) 169 | 170 | x2_v = F.conv2d(x2.unsqueeze(1), self.weight_v, padding=1) 171 | x2_h = F.conv2d(x2.unsqueeze(1), self.weight_h, padding=1) 172 | 173 | x0 = torch.sqrt(torch.pow(x0_v, 2) + torch.pow(x0_h, 2) + 1e-6) 174 | x1 = torch.sqrt(torch.pow(x1_v, 2) + torch.pow(x1_h, 2) + 1e-6) 175 | x2 = torch.sqrt(torch.pow(x2_v, 2) + torch.pow(x2_h, 2) + 1e-6) 176 | 177 | x = torch.cat([x0, x1, x2], dim=1) 178 | return x 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | def main(file1, file2): 187 | mat = cv2.imread(file1) 188 | nmat = cv2.imread(file2) 189 | tensor = torch.from_numpy(mat).float() 190 | tensor1 = torch.from_numpy(nmat).float() 191 | 192 | # candidate kernel sizes: 11, 17, 25, 50 193 | blurkernel = GaussionSmoothLayer(3, 11, 50) 194 | gradloss = GradientLoss(3, 3) 195 | 196 | 197 | tensor = tensor.permute(2, 0, 1) 198 | tensor = torch.unsqueeze(tensor, dim = 0) 199 | 200 | tensor1 = tensor1.permute(2, 0, 1) 201 | tensor1 = torch.unsqueeze(tensor1, dim = 0) 202 | 203 | out = blurkernel(tensor) 204 | out1 = blurkernel(tensor1) 205 | 206 | loss = gradloss(out) 207 | loss1 = gradloss(out1) 208 | 209 | out = out.permute(0, 2, 3, 1).int() 210 | out = out.numpy().squeeze().astype(np.uint8) 211 | 212 | out1 = out1.permute(0, 2, 3, 1).int() 213 | out1 = out1.numpy().squeeze().astype(np.uint8) 214 | 215 | cv2.imshow("1", out) 216 | cv2.imshow("2", out1) 217 | cv2.waitKey(0) 218 | 219 | 220 | 221 | def testPIL(file1, file2): 222 | transform = transforms.Compose([ 223 | transforms.ToTensor() 224 | ]) 225 | image11 = transform(Image.open(file1).convert('RGB')).unsqueeze(0) 226 | image22 = transform(Image.open(file2).convert('RGB')).unsqueeze(0) 227 | # image1 = image11 - F.avg_pool2d( 228 | # F.pad(image11, (2, 2, 2, 2), mode='reflect'), 5, 1, padding=0) 229 | 230 | blurkernel = GaussionSmoothLayer(3, 15, 25) 231 | # blurkernel = Get_gradient() 232 | # blurkerne2 = GradientLoss_v1(3,3) 233 | # image1 = image11-blurkernel(image11) 234 | image2 = blurkernel(image11) 235 | 236 | 237 | # # print('custom layer elapsed time: {:.8f}s'.format(time.time() - t1)) 238 | # image1 = Norm(image1) 239 | # image2 = Norm(image2) 240 | # print(F.mse_loss(image11, image22+image1)) 241 | 242 | 243 | image2 = image2.numpy().squeeze() 244 | 245 | 246 | 247 | 248 | image2 = np.transpose(image2, (1, 2, 0)) 249 | 250 | io.imsave('NOI_SRGB_%d_%d.png' % (1, 1), np.uint8(np.round(image2*255))) 251 | 252 | 253 | # image2 = ((image22+image1).clamp_(0,1)).numpy().squeeze() 254 | 255 | # image2 = np.transpose(image2, (1,2,0)) 256 | 257 | # img3 = np.transpose(img3, (1, 2, 0)) 258 | 259 | 260 | 261 | # plt.figure('1') 262 | # plt.imshow(img_as_ubyte(image11), interpolation='nearest') 263 | # plt.figure('2') 264 | # plt.imshow(img_as_ubyte(image2), interpolation='nearest') 265 | # 266 | # plt.figure('3') 267 | # # plt.imshow(img_as_ubyte(img3), interpolation='nearest') 268 | # plt.show() 269 | 270 | def adjust_learning_rate(epoch): 271 | """Halves the initial LR (1e-4) every 40 epochs; returned as str for printing""" 272 | lr = 0.0001 * (0.5 ** (epoch // 40)) 273 | # lr = opt['lr'] 274 | return str(lr)
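# --- Editor's note: a hedged illustration (not original code) of the step
# decay implemented by adjust_learning_rate above, base LR 1e-4 halved every
# 40 epochs:
#     epoch   0..39   -> 0.0001
#     epoch  40..79   -> 0.00005
#     epoch  80..119  -> 0.000025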
275 | if __name__ == "__main__": 276 | # for i in range(0, 400): 277 | # print("The learning rate at epoch {:0>3d} is: ".format(i) + adjust_learning_rate(i)) 278 | file2 = './figs/NOISY_SRGB_0_0.png' 279 | file1 = './figs/NOISY_SRGB_0_0.png' 280 | # # main(file1, file2) 281 | testPIL(file1, file2) 282 | 283 | 284 | -------------------------------------------------------------------------------- /networks3/Discriminator.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-09-18 22:31:45 4 | 5 | import torch.nn as nn 6 | import torch 7 | import torch.nn.init as init 8 | import torch.nn.utils as utils 9 | import torch.nn.functional as F 10 | from .SubBlocks import conv_down 11 | from GaussianSmoothLayer import GaussionSmoothLayer 12 | 13 | class DiscriminatorLinear(nn.Module): 14 | def __init__(self, in_chn, ndf=64, slope=0.2): 15 | ''' 16 | ndf: number of filters 17 | ''' 18 | super(DiscriminatorLinear, self).__init__() 19 | self.ndf = ndf 20 | # input is N x C x 128 x 128 21 | main_module = [conv_down(in_chn, ndf, bias=False), 22 | nn.LeakyReLU(slope, inplace=True)] 23 | # state size: N x ndf x 64 x 64 24 | main_module.append(conv_down(ndf, ndf*2, bias=False)) 25 | main_module.append(nn.LeakyReLU(slope, inplace=True)) 26 | # state size: N x (ndf*2) x 32 x 32 27 | main_module.append(conv_down(ndf*2, ndf*4, bias=False)) 28 | main_module.append(nn.LeakyReLU(slope, inplace=True)) 29 | # state size: N x (ndf*4) x 16 x 16 30 | main_module.append(conv_down(ndf*4, ndf*8, bias=False)) 31 | main_module.append(nn.LeakyReLU(slope, inplace=True)) 32 | # state size: N x (ndf*8) x 8 x 8 33 | main_module.append(conv_down(ndf*8, ndf*16, bias=False)) 34 | main_module.append(nn.LeakyReLU(slope, inplace=True)) 35 | # state size: N x (ndf*16) x 4 x 4 36 | main_module.append(nn.Conv2d(ndf*16, ndf*32, 4, stride=1, padding=0, bias=False)) 37 | main_module.append(nn.LeakyReLU(slope, inplace=True)) 38 | # state size: N x (ndf*32) x 1 x 1 39 | self.main = nn.Sequential(*main_module) 40 | self.output = nn.Linear(ndf*32, 1) 41 | 42 | self._initialize() 43 | 44 | def forward(self, x): 45 | x = torch.cat([x, x - F.avg_pool2d( 46 | F.pad(x, (1, 1, 1, 1), mode='reflect'), 3, 1, padding=0)], dim=1) 47 | feature = self.main(x) 48 | feature = feature.view(-1, self.ndf*32) 49 | out = self.output(feature) 50 | return out.view(-1) 51 | 52 | def _initialize(self): 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 55 | init.normal_(m.weight.data, 0., 0.02) 56 | if m.bias is not None: 57 | init.constant_(m.bias, 0) 58 | 59 | 60 | class _NetD(nn.Module): 61 | def __init__(self, stride=1): 62 | super(_NetD, self).__init__() 63 | 64 | self.Gas = GaussionSmoothLayer(3, 15, 9) 65 | 66 | self.features = nn.Sequential( 67 | 68 | # input is (6) x 96 x 96: the image concatenated with its high-frequency residual 69 | nn.Conv2d(in_channels=6, out_channels=64, kernel_size=4, stride=stride, padding=1), 70 | nn.LeakyReLU(0.2, inplace=True), 71 | 72 | # state size. (64) x 96 x 96 73 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=stride, padding=1, bias=False), 74 | nn.BatchNorm2d(128), 75 | nn.LeakyReLU(0.2, inplace=True), 76 | 77 | # state size. (128) x 96 x 96 78 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=stride, padding=1, bias=False), 79 | nn.BatchNorm2d(256), 80 | nn.LeakyReLU(0.2, inplace=True), 81 | 82 | # state size.
(64) x 48 x 48 83 | nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=False), 84 | nn.BatchNorm2d(512), 85 | nn.LeakyReLU(0.2, inplace=True), 86 | 87 | # state size. (128) x 48 x 48 88 | nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1), 89 | ) 90 | 91 | def forward(self, input): 92 | input = torch.cat([input, input - self.Gas(input)], dim=1) 93 | 94 | # input = torch.cat([input, input - F.avg_pool2d( 95 | # F.pad(input, (1, 1, 1, 1), mode='reflect'), 3, 1, padding=0)], dim=1) 96 | 97 | out = self.features(input) 98 | return out # self.sigmoid(out)#.view(-1, 1).squeeze(1) 99 | 100 | class Discriminator1(nn.Module): 101 | def __init__(self, num_conv_block=4): 102 | super(Discriminator1, self).__init__() 103 | 104 | block = [] 105 | 106 | in_channels = 6 107 | out_channels = 64 108 | 109 | for _ in range(num_conv_block): 110 | block += [nn.ReflectionPad2d(1), 111 | nn.Conv2d(in_channels, out_channels, 3), 112 | nn.LeakyReLU(), 113 | nn.BatchNorm2d(out_channels)] 114 | in_channels = out_channels 115 | 116 | block += [nn.ReflectionPad2d(1), 117 | nn.Conv2d(in_channels, out_channels, 3, 2), 118 | nn.LeakyReLU()] 119 | out_channels *= 2 120 | 121 | out_channels //= 2 122 | in_channels = out_channels 123 | 124 | block += [nn.Conv2d(in_channels, out_channels, 3), 125 | nn.LeakyReLU(0.2), 126 | nn.Conv2d(out_channels, out_channels, 3)] 127 | 128 | self.feature_extraction = nn.Sequential(*block) 129 | 130 | self.avgpool = nn.AdaptiveAvgPool2d((512, 512)) 131 | 132 | self.classification = nn.Sequential( 133 | nn.Linear(8192, 100), 134 | nn.Linear(100, 1) 135 | ) 136 | 137 | def forward(self, x): 138 | x = torch.cat([x, x - F.avg_pool2d(x, 3, 1, padding=1)], dim=1) 139 | 140 | x = self.feature_extraction(x) 141 | x = x.view(x.size(0), -1) 142 | x = self.classification(x) 143 | return x 144 | 145 | class VGGStyleDiscriminator128(nn.Module): 146 | """VGG style discriminator with input size 128 x 128. 147 | 148 | It is used to train SRGAN and ESRGAN. 149 | 150 | Args: 151 | num_in_ch (int): Channel number of inputs. Default: 3. 152 | num_feat (int): Channel number of base intermediate features. 153 | Default: 64. 
154 | """ 155 | 156 | def __init__(self, num_in_ch=6, num_feat=64): 157 | super(VGGStyleDiscriminator128, self).__init__() 158 | 159 | self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True) 160 | self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False) 161 | self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True) 162 | 163 | self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False) 164 | self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True) 165 | self.conv1_1 = nn.Conv2d( 166 | num_feat * 2, num_feat * 2, 4, 2, 1, bias=False) 167 | self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True) 168 | 169 | self.conv2_0 = nn.Conv2d( 170 | num_feat * 2, num_feat * 4, 3, 1, 1, bias=False) 171 | self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True) 172 | self.conv2_1 = nn.Conv2d( 173 | num_feat * 4, num_feat * 4, 4, 2, 1, bias=False) 174 | self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True) 175 | 176 | self.conv3_0 = nn.Conv2d( 177 | num_feat * 4, num_feat * 8, 3, 1, 1, bias=False) 178 | self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True) 179 | self.conv3_1 = nn.Conv2d( 180 | num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) 181 | self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True) 182 | 183 | self.conv4_0 = nn.Conv2d( 184 | num_feat * 8, num_feat * 8, 3, 1, 1, bias=False) 185 | self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True) 186 | self.conv4_1 = nn.Conv2d( 187 | num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) 188 | self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True) 189 | 190 | self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100) 191 | self.linear2 = nn.Linear(100, 1) 192 | 193 | # activation function 194 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 195 | 196 | def forward(self, x): 197 | assert x.size(2) == 128 and x.size(3) == 128, ( 198 | f'Input spatial size must be 128x128, ' 199 | f'but received {x.size()}.') 200 | x = torch.cat([x, x - F.avg_pool2d(x, 3, 1, padding=1)], dim=1) 201 | 202 | feat = self.lrelu(self.conv0_0(x)) 203 | feat = self.lrelu(self.bn0_1( 204 | self.conv0_1(feat))) # output spatial size: (64, 64) 205 | 206 | feat = self.lrelu(self.bn1_0(self.conv1_0(feat))) 207 | feat = self.lrelu(self.bn1_1( 208 | self.conv1_1(feat))) # output spatial size: (32, 32) 209 | 210 | feat = self.lrelu(self.bn2_0(self.conv2_0(feat))) 211 | feat = self.lrelu(self.bn2_1( 212 | self.conv2_1(feat))) # output spatial size: (16, 16) 213 | 214 | feat = self.lrelu(self.bn3_0(self.conv3_0(feat))) 215 | feat = self.lrelu(self.bn3_1( 216 | self.conv3_1(feat))) # output spatial size: (8, 8) 217 | 218 | feat = self.lrelu(self.bn4_0(self.conv4_0(feat))) 219 | feat = self.lrelu(self.bn4_1( 220 | self.conv4_1(feat))) # output spatial size: (4, 4) 221 | 222 | feat = feat.view(feat.size(0), -1) 223 | feat = self.lrelu(self.linear1(feat)) 224 | out = self.linear2(feat) 225 | return out 226 | 227 | import functools 228 | 229 | class NLayerDiscriminator(nn.Module): 230 | """Defines a PatchGAN discriminator""" 231 | 232 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d): 233 | """Construct a PatchGAN discriminator 234 | 235 | Parameters: 236 | input_nc (int) -- the number of channels in input images 237 | ndf (int) -- the number of filters in the last conv layer 238 | n_layers (int) -- the number of conv layers in the discriminator 239 | norm_layer -- normalization layer 240 | """ 241 | super(NLayerDiscriminator, self).__init__() 242 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine 
parameters 243 | use_bias = norm_layer.func == nn.InstanceNorm2d 244 | else: 245 | use_bias = norm_layer == nn.InstanceNorm2d 246 | 247 | kw = 4 248 | padw = 1 249 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)] 250 | nf_mult = 1 251 | nf_mult_prev = 1 252 | for n in range(1, n_layers): # gradually increase the number of filters 253 | nf_mult_prev = nf_mult 254 | nf_mult = min(2 ** n, 8) 255 | sequence += [ 256 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 257 | norm_layer(ndf * nf_mult), 258 | nn.LeakyReLU(0.2, False) 259 | ] 260 | 261 | nf_mult_prev = nf_mult 262 | nf_mult = min(2 ** n_layers, 8) 263 | sequence += [ 264 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 265 | norm_layer(ndf * nf_mult), 266 | nn.LeakyReLU(0.2, False) 267 | ] 268 | 269 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 270 | self.model = nn.Sequential(*sequence) 271 | # self.Gas = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=7, stride=1, padding=3) 272 | self.Gas = GaussionSmoothLayer(3, 15, 9) 273 | # kernel = 15 274 | # self.k = kernel // 2 275 | 276 | def forward(self, input): 277 | """Standard forward.""" 278 | input = torch.cat([input, input - self.Gas(input)], dim=1) 279 | # input = input 280 | # input = torch.cat([input, input - F.avg_pool2d( 281 | # F.pad(input, (self.k, self.k, self.k, self.k), mode='reflect'), 15, 1, padding=0)], dim=1) 282 | 283 | return self.model(input) 284 | def print_network(net): 285 | num_params = 0 286 | for param in net.parameters(): 287 | num_params += param.numel() 288 | print(net) 289 | print('Total number of parameters: %d' % num_params) 290 | # input = torch.rand(1,3,128,128).cuda() 291 | # Net = NLayerDiscriminator(6).cuda() 292 | # print_network(Net) 293 | # print(Net(input).size()) 294 | 295 | 296 | 297 | -------------------------------------------------------------------------------- /train_v6.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 3 | import sys 4 | import time 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.utils.data as uData 8 | from networks3 import _NetG_DOWN,NLayerDiscriminator,Deam 9 | from datasets.DenoisingDatasets import BenchmarkTrain, BenchmarkTest 10 | from math import ceil 11 | from utils import * 12 | from loss import get_gausskernel, GANLoss, log_SSIM_loss 13 | import warnings 14 | from pathlib import Path 15 | import commentjson as json 16 | from GaussianSmoothLayer import GaussionSmoothLayer 17 | # filter warnings 18 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 19 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 20 | device = torch.device('cuda:0'if torch.cuda.is_available()else'cpu') 21 | warnings.simplefilter('ignore', Warning, lineno=0) 22 | torch.set_default_dtype(torch.float32) 23 | _C = 3 24 | _modes = ['train', 'val'] 25 | BGBlur_kernel = [3, 9, 15] 26 | BlurWeight = [0.01,0.1,1.] 
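# --- Editor's note: a hedged sketch (not original code) of how the kernel
# sizes and weights above are combined with the BlurNet bank defined just
# below; it mirrors the BGM-loss loop inside train_epoch further down.
# img_a / img_b are hypothetical (B, 3, H, W) CUDA tensors.
#
#     def bgm_l1(img_a, img_b):
#         l1 = nn.functional.l1_loss
#         return sum(w * l1(blur(img_a), blur(img_b))
#                    for w, blur in zip(BlurWeight, BlurNet))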
27 | # Gaussian blur layers used by the BGM loss 28 | BlurNet = [GaussionSmoothLayer(3, k_size, 25).cuda() for k_size in BGBlur_kernel] 29 | 30 | def main(): 31 | with open('./configs/DANet_v5.json', 'r') as f: 32 | args = json.load(f) 33 | 34 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 35 | os.environ['CUDA_VISIBLE_DEVICES'] = args["gpu_id"] 36 | torch.backends.cudnn.enabled = True 37 | torch.backends.cudnn.benchmark = True 38 | # build up the filter network E used in the self-collaboration (SC) stage 39 | netE = torch.nn.DataParallel(Deam(1)).cuda() 40 | # build up the denoiser 42 | netD = torch.nn.DataParallel(Deam(1)).cuda() 44 | # build up the generator 45 | netG = torch.nn.DataParallel(_NetG_DOWN(stride=1)).cuda() 46 | # build up the discriminator 47 | netP = torch.nn.DataParallel(NLayerDiscriminator(6)).cuda() 48 | 49 | 50 | criterionGAN = GANLoss(args['gan_mode']).cuda() 51 | init_weights(netG, init_type='normal', init_gain=0.02) 52 | init_weights(netP, init_type='normal', init_gain=0.02) 53 | # NOTE: before the SC stage, netE and its optimizer are not used 54 | net = {'E': netE, 'D': netD, 'G': netG, 'P': netP} 55 | # optimizer 56 | optimizerG = optim.Adam(netG.parameters(), lr=args['lr_G']) 57 | optimizerD = optim.Adam(netD.parameters(), lr=args['lr_D']) 58 | optimizerP = optim.Adam(netP.parameters(), lr=args['lr_P']) 59 | optimizer = {'D': optimizerD, 'G': optimizerG, 'P': optimizerP} 60 | if args['resume']: 61 | if Path(args['resume']).is_file(): 62 | print('=> Loading checkpoint {:s}'.format(str(Path(args['resume'])))) 63 | checkpoint = torch.load(str(Path(args['resume'])), map_location='cpu') 64 | args['epoch_start'] = checkpoint['epoch'] 65 | # args['epoch_start'] = 3 66 | # optimizerE.load_state_dict(checkpoint['optimizer_state_dict']['E']) 67 | # optimizerD.load_state_dict(checkpoint['optimizer_state_dict']['D']) 68 | # optimizerG.load_state_dict(checkpoint['optimizer_state_dict']['G']) 69 | # optimizerP.load_state_dict(checkpoint['optimizer_state_dict']['P']) 70 | netE.load_state_dict(checkpoint['model_state_dict']['D'])  # E starts from the saved denoiser weights 71 | netD.load_state_dict(checkpoint['model_state_dict']['D']) 72 | netG.load_state_dict(checkpoint['model_state_dict']['G']) 73 | netP.load_state_dict(checkpoint['model_state_dict']['P']) 74 | print('=> Loaded checkpoint {:s} (epoch {:d})'.format(args['resume'], checkpoint['epoch'])) 75 | else: 76 | sys.exit('Please provide a correct model path!') 77 | else: 78 | args['epoch_start'] = 0 79 | if not Path(args['log_dir']).is_dir(): 80 | Path(args['log_dir']).mkdir() 81 | if not Path(args['model_dir']).is_dir(): 82 | Path(args['model_dir']).mkdir() 83 | 84 | for key, value in args.items(): 85 | print('{:<15s}: {:s}'.format(key, str(value))) 86 | 87 | # build the datasets; ours are stored as HDF5 files 88 | datasets = {'train': BenchmarkTrain(h5_file=args['SIDD_train_h5_noisy'], 89 | length=2000 * args['batch_size'] * args['num_critic'], 90 | pch_size=args['patch_size'], 91 | mask=False), 92 | 'val': BenchmarkTest(args['SIDD_test_h5'])} 93 | 94 | # build the Gaussian kernel for loss 95 | global kernel 96 | kernel = get_gausskernel(args['ksize'], chn=_C).cuda() 97 | # train model 98 | print('\nBegin training with GPU: ' + (args['gpu_id'])) 99 | train_epoch(net, datasets, optimizer, args, criterionGAN) 100 | 101 | def train_epoch(net, datasets, optimizer, args, criterionGAN): 102 | criterion = nn.L1Loss().cuda() 103 | loss_ssim = log_SSIM_loss().cuda() 104 | batch_size = {'train': args['batch_size'], 'val': 4} 105 | data_loader = {phase: uData.DataLoader(datasets[phase], batch_size=batch_size[phase], 106 | shuffle=True,
num_workers=0, pin_memory=True) for phase in 107 | _modes} 108 | data_set_gt = BenchmarkTrain(h5_file=args['SIDD_train_h5_gt'], 109 | length=2000 * args['batch_size'] * args['num_critic'], 110 | pch_size=args['patch_size'], 111 | mask=False) 112 | # TODO: the gt dataset has no key() 113 | data_loader_gt = uData.DataLoader(data_set_gt, batch_size=batch_size['train'], 114 | shuffle=True, num_workers=0, pin_memory=True) 115 | 116 | num_data = {phase: len(datasets[phase]) for phase in _modes} 117 | num_iter_epoch = {phase: ceil(num_data[phase] / batch_size[phase]) for phase in _modes} 118 | 119 | for epoch in range(args['epoch_start'], args['epochs']): 120 | loss_epoch = {x: 0 for x in ['PL', 'DL', 'GL']} 121 | subloss_epoch = {x: 0 for x in 122 | ['loss_GAN_DG', 'loss_l1', 'perceptual_loss', 'loss_bgm', 'loss_GAN_P_real', 'loss_GAN_P_fake']} 123 | mae_epoch = {'train': 0, 'val': 0} 124 | 125 | optD, optP, optG = optimizer['D'], optimizer['P'], optimizer['G'] 126 | 127 | 128 | tic = time.time() 129 | # train stage 130 | net['D'].train() 131 | net['G'].train() 132 | net['P'].train() 133 | 134 | lr_D = optimizer['D'].param_groups[0]['lr'] 135 | lr_G = optimizer['G'].param_groups[0]['lr'] 136 | lr_P = optimizer['P'].param_groups[0]['lr'] 137 | 138 | if lr_D < 1e-6: 139 | sys.exit('Reach the minimal learning rate') 140 | phase = 'train' 141 | 142 | for ii, (data, data1) in enumerate(zip(data_loader[phase], data_loader_gt)): 143 | 144 | im_noisy, _ = [x.cuda() for x in data] 145 | _, im_gt = [x.cuda() for x in data1] 146 | ################################ 147 | # training generator 148 | ############################## 149 | optimizer['G'].zero_grad() 150 | optimizer['D'].zero_grad() 151 | # !!! first stage (before SC) 152 | # fake_im_noisy1 = net['G'](im_gt, im_noisy) 153 | # rec_x1 = net['D'](fake_im_noisy1.detach()) 154 | # rec_x2 = net['D'](im_noisy.detach()) 155 | # fake_im_noisy2 = net['G'](rec_x2, im_noisy) 156 | # fake_im_noisy3 = net['G'](rec_x2, fake_im_noisy1) 157 | # fake_im_noisy4 = net['G'](rec_x1, fake_im_noisy1) 158 | 159 | 160 | # SC stage 161 | tizao_1 = net['E'](im_noisy) 162 | rec_x2 = net['D'](im_noisy.detach()) 163 | fake_im_noisy1 = net['G'](im_gt, (im_noisy - tizao_1)) 164 | fake_im_noisy2 = net['G'](rec_x2, (im_noisy - tizao_1)) 165 | rec_x1 = net['D'](fake_im_noisy1.detach()) 166 | tizao_2 = net['E'](fake_im_noisy1) 167 | fake_im_noisy3 = net['G'](rec_x2, (fake_im_noisy1 - tizao_2)) 168 | fake_im_noisy4 = net['G'](rec_x1, (fake_im_noisy1 - tizao_2)) 169 | 170 | 171 | set_requires_grad([net['P']], False) 172 | 173 | subloss_epoch['perceptual_loss'] += 0  # perceptual loss is not used 174 | adversarial_loss1 = criterionGAN(net['P'](fake_im_noisy1), True) 175 | 176 | 177 | adversarial_loss = adversarial_loss1 178 | identity_loss = 0 179 | bgm_loss1 = 0 180 | bgm_loss2 = 0 181 | bgm_loss = 0 182 | # BGM loss 183 | for index, weight in enumerate(BlurWeight): 184 | out_b1 = BlurNet[index](im_gt) 185 | out_real_b1 = BlurNet[index](fake_im_noisy1) 186 | out_b2 = BlurNet[index](rec_x2) 187 | out_real_b2 = BlurNet[index](fake_im_noisy2) 188 | grad_loss_b1 = criterion(out_b1, out_real_b1) 189 | grad_loss_b2 = criterion(out_b2, out_real_b2) 190 | bgm_loss1 += weight * (grad_loss_b1) 191 | bgm_loss2 += weight * (grad_loss_b2) 192 | bgm_loss += bgm_loss1 + bgm_loss2  # running partial sums; used only for logging 193 | loss_G = adversarial_loss * args['adversarial_loss_factor'] + \ 194 | bgm_loss1 * args['bgm_loss'] + \ 195 | bgm_loss2 * args['bgm_loss'] 196 | 197 |
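# --- Editor's summary (hedged, inferred from the surrounding code): in the
# SC (self-collaboration) stage the four networks are wired as
#     E (filter)        : tizao_1 = E(im_noisy), tizao_2 = E(fake_im_noisy1)
#     D (denoiser)      : rec_x2 = D(im_noisy), rec_x1 = D(fake_im_noisy1)
#     G (generator)     : re-synthesizes noisy images from a clean estimate
#                         plus a noise map such as (im_noisy - tizao_1)
#     P (discriminator) : scores real vs. synthesized noisy patches.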
198 | los_ssim = loss_ssim(rec_x1, im_gt) 199 | loss_recon = criterion(rec_x1, im_gt) 200 | # first stage (before SC) 201 | # loss_D = loss_recon + los_ssim 202 | 203 | # SC stage 204 | los_ssim1 = loss_ssim(rec_x2, tizao_1) 205 | loss_recon1 = criterion(rec_x2, tizao_1) 206 | los_ssim2 = loss_ssim(rec_x1, tizao_2) 207 | loss_recon2 = criterion(rec_x1, tizao_2) 208 | loss_D = loss_recon + los_ssim + loss_recon2 + los_ssim2 + los_ssim1 + loss_recon1 209 | 210 | 211 | loss_G.backward(retain_graph=True) 212 | 213 | loss_D.backward(retain_graph=True) 214 | optimizer['G'].step() 215 | optimizer['D'].step() 216 | loss_epoch['DL'] += loss_D.item() 217 | loss_epoch['GL'] += loss_G.item() 218 | 219 | subloss_epoch['loss_GAN_DG'] += adversarial_loss.item() 220 | subloss_epoch['loss_bgm'] += bgm_loss.item() 221 | 222 | ########################## 223 | # training discriminator # 224 | ########################## 225 | if (ii+1) % args['num_critic'] == 0: 226 | set_requires_grad([net['P']], True) 227 | 228 | pred_real1 = net['P'](im_noisy) 229 | loss_P_real = criterionGAN(pred_real1, True) 230 | pred_fake = net['P'](fake_im_noisy1.detach()) 231 | loss_P_fake = criterionGAN(pred_fake, False) 232 | 233 | # Combined loss and calculate gradients 234 | loss_P = (loss_P_real + loss_P_fake) * 0.5 235 | loss_P.backward() 236 | optimizer['P'].step() 237 | optimizer['P'].zero_grad() 238 | 239 | loss_epoch['PL'] += loss_P.item() 240 | subloss_epoch['loss_GAN_P_real'] += loss_P_real.item() 241 | subloss_epoch['loss_GAN_P_fake'] += loss_P_fake.item() 242 | 243 | if (ii + 1) % args['print_freq'] == 0: 244 | template = '[Epoch:{:>2d}/{:<3d}] {:s}:{:0>5d}/{:0>5d},' + \ 245 | ' PL:{:>6.6f}, GL:{:>6.6f}, DL:{:>6.6f}, ' \ 246 | 'loss_GAN_G:{:>6.6f},' + \ 247 | 'loss_bgm:{:>6.9f}, loss_P_real:{:>6.4f}, ' \ 248 | 'loss_P_fake:{:>6.4f}, identity_loss:{:>6.4f}' 249 | print(template.format(epoch + 1, args['epochs'], phase, ii + 1, num_iter_epoch[phase], 250 | loss_P.item(), loss_G.item(), loss_D.item(), 251 | # loss_P1.item(), loss_G.item(), loss_D.item(), 252 | adversarial_loss.item(), bgm_loss1.item(), loss_P_real.item(), loss_P_fake.item(), identity_loss)) 253 | 254 | loss_epoch['GL'] /= (ii + 1) 255 | 256 | subloss_epoch['loss_GAN_DG'] /= (ii + 1) 257 | subloss_epoch['loss_bgm'] /= (ii + 1) 258 | 259 | loss_epoch['PL'] /= (ii + 1) 260 | subloss_epoch['loss_GAN_P_real'] /= (ii + 1) 261 | subloss_epoch['loss_GAN_P_fake'] /= (ii + 1) 262 | 263 | template = '{:s}: PL:{:>6.6f}, GL:{:>6.6f}, loss_GAN_DG:{:>6.6f}, ' + \ 264 | ' loss_bgm:{:>6.4f}, loss_P_real:{:>6.4f}, ' \ 265 | 'loss_P_fake:{:>6.4f}, lrDG/P:{:.2e}/{:.2e}' 266 | print(template.format(phase, loss_epoch['PL'], loss_epoch['GL'], subloss_epoch['loss_GAN_DG'], 267 | subloss_epoch['loss_bgm'], 268 | subloss_epoch['loss_GAN_P_real'], 269 | subloss_epoch['loss_GAN_P_fake'], lr_D, lr_P)) 270 | 271 | net['G'].eval() 272 | print('Epoch [{0}]\t' 273 | 'lr: {lr:.6f}\t' 274 | 'Loss: {loss:.5f}' 275 | .format( 276 | epoch, 277 | lr=lr_D, 278 | loss=loss_epoch['DL'])) 279 | 280 | print('-' * 150) 281 | 282 | # test stage 283 | net['D'].eval() 284 | psnr_per_epoch = ssim_per_epoch = 0 285 | phase = 'val' 286 | for ii, data in enumerate(data_loader[phase]): 287 | im_noisy, im_gt = [x.cuda() for x in data] 288 | with torch.set_grad_enabled(False): 289 | im_denoise = net['D'](im_noisy) 290 | 291 | mae_iter = F.l1_loss(im_denoise, im_gt) 292 | im_denoise.clamp_(0.0, 1.0) 293 | mae_epoch[phase] += mae_iter 294 | psnr_iter = batch_PSNR(im_denoise, im_gt) 295 | psnr_per_epoch += psnr_iter 296 | ssim_iter = batch_SSIM(im_denoise, im_gt) 297 | ssim_per_epoch += ssim_iter 298 | if (ii + 1) % 50 == 0:
299 | log_str = '[Epoch:{:>2d}/{:<2d}] {:s}:{:0>3d}/{:0>3d}, mae={:.2e}, ' + \ 300 | 'psnr={:4.2f}, ssim={:5.4f}' 301 | print(log_str.format(epoch + 1, args['epochs'], phase, ii + 1, num_iter_epoch[phase], 302 | mae_iter, psnr_iter, ssim_iter)) 303 | 304 | psnr_per_epoch /= (ii + 1) 305 | ssim_per_epoch /= (ii + 1) 306 | mae_epoch[phase] /= (ii + 1) 307 | print('{:s}: mae={:.3e}, PSNR={:4.2f}, SSIM={:5.4f}'.format(phase, mae_epoch[phase], 308 | psnr_per_epoch, ssim_per_epoch)) 309 | print('-' * 150) 310 | 311 | # save model 312 | model_prefix = 'model_' 313 | save_path_model = str(Path(args['model_dir']) / (model_prefix + str(epoch + 1))) 314 | torch.save({ 315 | 'epoch': epoch + 1, 316 | 'model_state_dict': {x: net[x].state_dict() for x in ['E','D', 'G', 'P']}, 317 | 'optimizer_state_dict': {x: optimizer[x].state_dict() for x in ['D', 'P', 'G']}, 318 | }, save_path_model) 319 | model_prefix = 'model_state_' 320 | save_path_model = str(Path(args['model_dir']) / (model_prefix + str(epoch + 1) + 'PSNR{:.2f}_SSIM{:.4f}'. 321 | format(psnr_per_epoch, ssim_per_epoch) + '.pt')) 322 | torch.save({x: net[x].state_dict() for x in ['E','D', 'G', 'P']}, save_path_model) 323 | 324 | toc = time.time() 325 | print('This epoch take time {:.2f}'.format(toc - tic)) 326 | 327 | print('Reach the maximal epochs! Finish training') 328 | 329 | def set_requires_grad(nets, requires_grad=False): 330 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 331 | Parameters: 332 | nets (network list) -- a list of networks 333 | requires_grad (bool) -- whether the networks require gradients or no 334 | """ 335 | if not isinstance(nets, list): 336 | nets = [nets] 337 | for net in nets: 338 | if net is not None: 339 | for param in net.parameters(): 340 | param.requires_grad = requires_grad 341 | 342 | 343 | def adjust_learning_rate(epoch, opt): 344 | """Sets the learning rate to the initial LR decayed by 10 every 10 epochs""" 345 | lr = opt['lr'] * (opt['gamma'] ** ((epoch) // opt['lr_decay'])) 346 | # lr = opt['lr'] 347 | return lr 348 | 349 | 350 | if __name__ == '__main__': 351 | main() 352 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Power by Zongsheng Yue 2019-10-31 21:31:50 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import functools 8 | from math import exp 9 | import cv2 10 | import numpy as np 11 | import torch.nn as nn 12 | from torchvision.models.vgg import vgg19 13 | from loss_util import weighted_loss 14 | from GaussianSmoothLayer import GaussionSmoothLayer 15 | _reduction_modes = ['none', 'mean', 'sum'] 16 | 17 | from torch.autograd import Variable 18 | 19 | 20 | @weighted_loss 21 | def l1_loss(pred, target): 22 | return F.l1_loss(pred, target, reduction='none') 23 | 24 | 25 | @weighted_loss 26 | def mse_loss(pred, target): 27 | return F.mse_loss(pred, target, reduction='none') 28 | 29 | def gradient_penalty(real_data, generated_data, netP, lambda_gp): 30 | batch_size = real_data.size()[0] 31 | 32 | # Calculate interpolation 33 | alpha = torch.rand(batch_size, 1, 1, 1) 34 | alpha = alpha.expand_as(real_data) 35 | alpha = alpha.to(real_data.device) 36 | interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data 37 | interpolated.requires_grad=True 38 | 39 | # Calculate probability of interpolated examples 40 | prob_interpolated = netP(interpolated) 41 | 42 | # 
Calculate gradients of probabilities with respect to examples 43 | grad_outputs = torch.ones(prob_interpolated.size(), device=real_data.device, dtype=torch.float32) 44 | gradients = torch.autograd.grad(outputs=prob_interpolated, inputs=interpolated, 45 | grad_outputs=grad_outputs, create_graph=True, retain_graph=True)[0] 46 | 47 | # Gradients have shape (batch_size, num_channels, img_width, img_height), 48 | # so flatten to easily take norm per example in batch 49 | gradients = gradients.view(batch_size, -1) 50 | 51 | # Derivatives of the gradient close to 0 can cause problems because of 52 | # the square root, so manually calculate norm and add epsilon 53 | gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) 54 | 55 | # Return gradient penalty 56 | return lambda_gp * ((gradients_norm - 1) ** 2).mean() 57 | 58 | def get_gausskernel(p, chn=3): 59 | ''' 60 | Build a 2-dimensional Gaussian filter with size p 61 | ''' 62 | x = cv2.getGaussianKernel(p, sigma=-1) # p x 1 63 | y = np.matmul(x, x.T)[np.newaxis, np.newaxis,] # 1x 1 x p x p 64 | out = np.tile(y, (chn, 1, 1, 1)) # chn x 1 x p x p 65 | 66 | return torch.from_numpy(out).type(torch.float32) 67 | 68 | def gaussblur(x, kernel, p=5, chn=3): 69 | x_pad = F.pad(x, pad=[int((p-1)/2),]*4, mode='reflect') 70 | y = F.conv2d(x_pad, kernel, padding=0, stride=1, groups=chn) 71 | 72 | return y 73 | 74 | def var_match(x, y, fake_y, kernel, chn=3): 75 | p = kernel.shape[2] 76 | # estimate the real distribution 77 | err_real = y - x 78 | mu_real = gaussblur(err_real, kernel, p, chn) 79 | err2_real = (err_real-mu_real)**2 80 | var_real = gaussblur(err2_real, kernel, p, chn) 81 | var_real = torch.where(var_real<1e-10, torch.ones_like(fake_y)*1e-10, var_real) 82 | # estimate the fake distribution 83 | err_fake = fake_y - x 84 | mu_fake = gaussblur(err_fake, kernel, p, chn) 85 | err2_fake = (err_fake-mu_fake)**2 86 | var_fake = gaussblur(err2_fake, kernel, p, chn) 87 | var_fake = torch.where(var_fake<1e-10, torch.ones_like(fake_y)*1e-10, var_fake) 88 | 89 | # calculate the loss 90 | loss_err = F.l1_loss(mu_real, mu_fake, reduction='mean') 91 | loss_var = F.l1_loss(var_real, var_fake, reduction='mean') 92 | 93 | return loss_err, loss_var 94 | 95 | def mean_match(x, fake_y,y,fake_x, kernel, chn=3): 96 | p = kernel.shape[2] 97 | # estimate the real distribution 98 | err_real = fake_y - x 99 | mu_real = gaussblur(err_real, kernel, p, chn) 100 | err_fake = y - fake_x 101 | mu_fake = gaussblur(err_fake, kernel, p, chn) 102 | loss = F.l1_loss(mu_real, mu_fake, reduction='mean') 103 | 104 | return loss 105 | 106 | def mean_match_1(y, fake_y, kernel, chn=3): 107 | p = kernel.shape[2] 108 | # estimate the real distribution 109 | # err_real = y - x 110 | mu_real = gaussblur(y, kernel, p, chn) 111 | 112 | mu_fake = gaussblur(fake_y, kernel, p, chn) 113 | loss = F.l1_loss(mu_real, mu_fake, reduction='mean') 114 | 115 | return loss 116 | 117 | class GANLoss(nn.Module): 118 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 119 | super(GANLoss, self).__init__() 120 | self.gan_type = gan_type.lower() 121 | self.real_label_val = real_label_val 122 | self.fake_label_val = fake_label_val 123 | 124 | if self.gan_type == 'vanilla': 125 | self.loss = nn.BCEWithLogitsLoss() 126 | elif self.gan_type == 'lsgan': 127 | self.loss = nn.MSELoss() 128 | elif self.gan_type == 'wgan-gp': 129 | 130 | def wgan_loss(input, target): 131 | # target is boolean 132 | return -1 * input.mean() if target else input.mean() 133 | 134 | self.loss = wgan_loss 135 
| else: 136 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 137 | 138 | def get_target_label(self, input, target_is_real): 139 | if self.gan_type == 'wgan-gp': 140 | return target_is_real 141 | if target_is_real: 142 | return torch.empty_like(input).fill_(self.real_label_val) 143 | else: 144 | return torch.empty_like(input).fill_(self.fake_label_val) 145 | 146 | def forward(self, input, target_is_real): 147 | target_label = self.get_target_label(input, target_is_real) 148 | loss = self.loss(input, target_label) 149 | return loss 150 | 151 | 152 | class GANLoss_v2(nn.Module): 153 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): 154 | super(GANLoss_v2, self).__init__() 155 | self.register_buffer('real_label', torch.tensor(target_real_label)) 156 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 157 | 158 | if use_lsgan: 159 | self.loss = nn.MSELoss() 160 | else: 161 | self.loss = nn.BCELoss() 162 | 163 | def get_target_tensor(self, input, target_is_real): 164 | if target_is_real: 165 | target_tensor = self.real_label 166 | else: 167 | target_tensor = self.fake_label 168 | return target_tensor.expand_as(input) 169 | 170 | def __call__(self, input, target_is_real): 171 | target_tensor = self.get_target_tensor(input, target_is_real) 172 | return self.loss(input, target_tensor) 173 | class GANLoss_v3(nn.Module): 174 | """Define GAN loss. 175 | 176 | Args: 177 | gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. 178 | real_label_val (float): The value for real label. Default: 1.0. 179 | fake_label_val (float): The value for fake label. Default: 0.0. 180 | loss_weight (float): Loss weight. Default: 1.0. 181 | Note that loss_weight is only for generators; and it is always 1.0 182 | for discriminators. 183 | """ 184 | 185 | def __init__(self, 186 | gan_type, 187 | real_label_val=1.0, 188 | fake_label_val=0.0, 189 | loss_weight=1.0): 190 | super(GANLoss_v3, self).__init__() 191 | self.gan_type = gan_type 192 | self.loss_weight = loss_weight 193 | self.real_label_val = real_label_val 194 | self.fake_label_val = fake_label_val 195 | 196 | if self.gan_type == 'vanilla': 197 | self.loss = nn.BCEWithLogitsLoss() 198 | elif self.gan_type == 'lsgan': 199 | self.loss = nn.MSELoss() 200 | elif self.gan_type == 'wgan': 201 | self.loss = self._wgan_loss 202 | elif self.gan_type == 'wgan_softplus': 203 | self.loss = self._wgan_softplus_loss 204 | elif self.gan_type == 'hinge': 205 | self.loss = nn.ReLU() 206 | else: 207 | raise NotImplementedError( 208 | f'GAN type {self.gan_type} is not implemented.') 209 | 210 | def _wgan_loss(self, input, target): 211 | """wgan loss. 212 | 213 | Args: 214 | input (Tensor): Input tensor. 215 | target (bool): Target label. 216 | 217 | Returns: 218 | Tensor: wgan loss. 219 | """ 220 | return -input.mean() if target else input.mean() 221 | 222 | def _wgan_softplus_loss(self, input, target): 223 | """wgan loss with soft plus. softplus is a smooth approximation to the 224 | ReLU function. 225 | 226 | In StyleGAN2, it is called: 227 | Logistic loss for discriminator; 228 | Non-saturating loss for generator. 229 | 230 | Args: 231 | input (Tensor): Input tensor. 232 | target (bool): Target label. 233 | 234 | Returns: 235 | Tensor: wgan loss. 236 | """ 237 | return F.softplus(-input).mean() if target else F.softplus( 238 | input).mean() 239 | 240 | def get_target_label(self, input, target_is_real): 241 | """Get target label. 242 | 243 | Args: 244 | input (Tensor): Input tensor. 
245 | target_is_real (bool): Whether the target is real or fake. 246 | 247 | Returns: 248 | (bool | Tensor): Target tensor. Return bool for wgan, otherwise, 249 | return Tensor. 250 | """ 251 | 252 | if self.gan_type in ['wgan', 'wgan_softplus']: 253 | return target_is_real 254 | target_val = ( 255 | self.real_label_val if target_is_real else self.fake_label_val) 256 | return input.new_ones(input.size()) * target_val 257 | 258 | def forward(self, input, target_is_real, is_disc=False): 259 | """ 260 | Args: 261 | input (Tensor): The input for the loss module, i.e., the network 262 | prediction. 263 | target_is_real (bool): Whether the target is real or fake. 264 | is_disc (bool): Whether the loss is for discriminators or not. 265 | Default: False. 266 | 267 | Returns: 268 | Tensor: GAN loss value. 269 | """ 270 | target_label = self.get_target_label(input, target_is_real) 271 | if self.gan_type == 'hinge': 272 | if is_disc: # for discriminators in hinge-gan 273 | input = -input if target_is_real else input 274 | loss = self.loss(1 + input).mean() 275 | else: # for generators in hinge-gan 276 | loss = -input.mean() 277 | else: # other gan types 278 | loss = self.loss(input, target_label) 279 | 280 | # loss_weight is always 1.0 for discriminators 281 | return loss if is_disc else loss * self.loss_weight 282 | 283 | 284 | class PerceptualLoss(nn.Module): 285 | def __init__(self): 286 | super(PerceptualLoss, self).__init__() 287 | 288 | vgg = vgg19(pretrained=True) 289 | loss_network = nn.Sequential(*list(vgg.features)[:35]).eval() 290 | for param in loss_network.parameters(): 291 | param.requires_grad = False 292 | self.loss_network = loss_network 293 | self.l1_loss = nn.L1Loss() 294 | 295 | def forward(self, high_resolution, fake_high_resolution): 296 | perception_loss = self.l1_loss(self.loss_network(high_resolution), self.loss_network(fake_high_resolution)) 297 | return perception_loss 298 | 299 | 300 | class L1Loss(nn.Module): 301 | """L1 (mean absolute error, MAE) loss. 302 | 303 | Args: 304 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 305 | reduction (str): Specifies the reduction to apply to the output. 306 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 307 | """ 308 | 309 | def __init__(self, loss_weight=1.0, reduction='mean'): 310 | super(L1Loss, self).__init__() 311 | if reduction not in ['none', 'mean', 'sum']: 312 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 313 | f'Supported ones are: {_reduction_modes}') 314 | 315 | self.loss_weight = loss_weight 316 | self.reduction = reduction 317 | self.Gas = GaussionSmoothLayer(3, 15, 9) 318 | 319 | def forward(self, pred, target, weight=None, **kwargs): 320 | """ 321 | Args: 322 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 323 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 324 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 325 | weights. Default: None. 326 | """ 327 | # pred = pred - F.avg_pool2d( 328 | # F.pad(pred, (1, 1, 1, 1), mode='reflect'), 3, 1, padding=0) 329 | # target = target - F.avg_pool2d( 330 | # F.pad(target, (1, 1, 1, 1), mode='reflect'), 3, 1, padding=0) 331 | pred = pred - self.Gas(pred) 332 | target = target - self.Gas(target) 333 | return self.loss_weight * l1_loss( 334 | pred, target, weight, reduction=self.reduction) 335 | 336 | 337 | class MSELoss(nn.Module): 338 | """MSE (L2) loss. 339 | 340 | Args: 341 | loss_weight (float): Loss weight for MSE loss. Default: 1.0.
342 |         reduction (str): Specifies the reduction to apply to the output.
343 |             Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
344 |     """
345 | 
346 |     def __init__(self, loss_weight=1.0, reduction='mean'):
347 |         super(MSELoss, self).__init__()
348 |         if reduction not in ['none', 'mean', 'sum']:
349 |             raise ValueError(f'Unsupported reduction mode: {reduction}. '
350 |                              "Supported ones are: 'none' | 'mean' | 'sum'")
351 | 
352 |         self.loss_weight = loss_weight
353 |         self.reduction = reduction
354 | 
355 |     def forward(self, pred, target, weight=None, **kwargs):
356 |         """
357 |         Args:
358 |             pred (Tensor): of shape (N, C, H, W). Predicted tensor.
359 |             target (Tensor): of shape (N, C, H, W). Ground truth tensor.
360 |             weight (Tensor, optional): of shape (N, C, H, W). Element-wise
361 |                 weights. Default: None.
362 |         """
363 |         return self.loss_weight * mse_loss(
364 |             pred, target, weight, reduction=self.reduction)
365 | 
366 | 
367 | def gaussian(window_size, sigma):
368 |     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
369 |     return gauss / gauss.sum()
370 | 
371 | 
372 | def create_window(window_size, channel):
373 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
374 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
375 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
376 |     return window
377 | 
378 | 
379 | class log_SSIM_loss(nn.Module):
380 |     def __init__(self, window_size=11, channel=3, is_cuda=True, size_average=True):
381 |         super(log_SSIM_loss, self).__init__()
382 |         self.window_size = window_size
383 |         self.channel = channel
384 |         self.size_average = size_average
385 |         self.window = create_window(window_size, channel)
386 |         if is_cuda:
387 |             self.window = self.window.cuda()
388 | 
389 | 
390 |     def forward(self, img1, img2):
391 |         mu1 = F.conv2d(img1, self.window, padding=self.window_size // 2, groups=self.channel)
392 |         mu2 = F.conv2d(img2, self.window, padding=self.window_size // 2, groups=self.channel)
393 | 
394 |         mu1_sq = mu1.pow(2)
395 |         mu2_sq = mu2.pow(2)
396 |         mu1_mu2 = mu1 * mu2
397 | 
398 |         sigma1_sq = F.conv2d(img1 * img1, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_sq
399 |         sigma2_sq = F.conv2d(img2 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu2_sq
400 |         sigma12 = F.conv2d(img1 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_mu2
401 | 
402 |         C1 = 0.01 ** 2
403 |         C2 = 0.03 ** 2
404 | 
405 |         ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
406 | 
407 |         return -torch.log10(ssim_map.mean())
408 | 
409 | 
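# Illustrative sketch (not part of the original file): how the SSIM-based
# criteria in this file are typically called. `pred`/`target` are hypothetical
# (B, 3, H, W) tensors scaled to [0, 1], which matches the standard SSIM
# constants C1 = 0.01**2 and C2 = 0.03**2 used above.
def _example_log_ssim(pred, target):
    criterion = log_SSIM_loss(window_size=11, channel=3, is_cuda=False)
    return criterion(pred, target)  # -log10(mean SSIM); 0 when the images match exactly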
410 | class negative_SSIM_loss(nn.Module):
411 |     def __init__(self, window_size=11, channel=3, is_cuda=True, size_average=True):
412 |         super(negative_SSIM_loss, self).__init__()
413 |         self.window_size = window_size
414 |         self.channel = channel
415 |         self.size_average = size_average
416 |         self.window = create_window(window_size, channel)
417 |         if is_cuda:
418 |             self.window = self.window.cuda()
419 | 
420 | 
421 |     def forward(self, img1, img2):
422 |         mu1 = F.conv2d(img1, self.window, padding=self.window_size // 2, groups=self.channel)
423 |         mu2 = F.conv2d(img2, self.window, padding=self.window_size // 2, groups=self.channel)
424 | 
425 |         mu1_sq = mu1.pow(2)
426 |         mu2_sq = mu2.pow(2)
427 |         mu1_mu2 = mu1 * mu2
428 | 
429 |         sigma1_sq = F.conv2d(img1 * img1, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_sq
430 |         sigma2_sq = F.conv2d(img2 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu2_sq
431 |         sigma12 = F.conv2d(img1 * img2, self.window, padding=self.window_size // 2, groups=self.channel) - mu1_mu2
432 | 
433 |         C1 = 0.01 ** 2
434 |         C2 = 0.03 ** 2
435 | 
436 |         ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
437 | 
438 |         return 1.0 - ssim_map.mean()
439 | 
440 | 
441 | class GRAD_loss(nn.Module):
442 |     def __init__(self, channel=3, is_cuda=True):
443 |         super(GRAD_loss, self).__init__()
444 |         self.edge_conv = nn.Conv2d(channel, channel*2, kernel_size=3, stride=1, padding=1, groups=channel, bias=False)
445 |         edge_kx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
446 |         edge_ky = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
447 |         edge_k = []
448 |         for i in range(channel):
449 |             edge_k.append(edge_kx)
450 |             edge_k.append(edge_ky)
451 | 
452 |         edge_k = np.stack(edge_k)
453 | 
454 |         edge_k = torch.from_numpy(edge_k).float().view(channel*2, 1, 3, 3)
455 |         self.edge_conv.weight = nn.Parameter(edge_k)
456 |         for param in self.parameters():
457 |             param.requires_grad = False
458 | 
459 |         if is_cuda: self.edge_conv.cuda()
460 | 
461 |     def forward(self, img1, img2):
462 |         img1_grad = self.edge_conv(img1)
463 |         img2_grad = self.edge_conv(img2)
464 | 
465 |         return F.l1_loss(img1_grad, img2_grad)
466 | 
467 | 
468 | class exp_GRAD_loss(nn.Module):
469 |     def __init__(self, channel=3, is_cuda=True):
470 |         super(exp_GRAD_loss, self).__init__()
471 |         self.edge_conv = nn.Conv2d(channel, channel*2, kernel_size=3, stride=1, padding=1, groups=channel, bias=False)
472 |         edge_kx = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
473 |         edge_ky = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
474 |         edge_k = []
475 |         for i in range(channel):
476 |             edge_k.append(edge_kx)
477 |             edge_k.append(edge_ky)
478 | 
479 |         edge_k = np.stack(edge_k)
480 | 
481 |         edge_k = torch.from_numpy(edge_k).float().view(channel*2, 1, 3, 3)
482 |         self.edge_conv.weight = nn.Parameter(edge_k)
483 |         for param in self.parameters():
484 |             param.requires_grad = False
485 | 
486 |         if is_cuda: self.edge_conv.cuda()
487 | 
488 |     def forward(self, img1, img2):
489 |         img1_grad = self.edge_conv(img1)
490 |         img2_grad = self.edge_conv(img2)
491 | 
492 |         return torch.exp(F.l1_loss(img1_grad, img2_grad)) - 1
493 | 
494 | 
495 | class log_PSNR_loss(torch.nn.Module):
496 |     def __init__(self):
497 |         super(log_PSNR_loss, self).__init__()
498 | 
499 |     def forward(self, img1, img2):
500 |         diff = img1 - img2
501 |         mse = (diff * diff).mean()
502 |         return -torch.log10(1.0 - mse)  # assumes intensities in [0, 1], so mse < 1
503 | 
504 | 
505 | class MSE_loss(torch.nn.Module):
506 |     def __init__(self):
507 |         super(MSE_loss, self).__init__()
508 | 
509 |     def forward(self, img1, img2):
510 |         return F.mse_loss(img1, img2)
511 | 
512 | 
513 | class L1_loss(torch.nn.Module):
514 |     def __init__(self):
515 |         super(L1_loss, self).__init__()
516 | 
517 |     def forward(self, img1, img2):
518 |         return F.l1_loss(img1, img2)
519 | 
520 | 
521 | loss_dict = {
522 |     'l1': L1_loss,
523 |     'mse': MSE_loss,
524 |     'grad': GRAD_loss,
525 |     'exp_grad': exp_GRAD_loss,
526 |     'log_ssim': log_SSIM_loss,
527 |     'neg_ssim': negative_SSIM_loss,
528 |     'log_psnr': log_PSNR_loss,
529 | }
530 | 
531 | 
--------------------------------------------------------------------------------
/networks3/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # Power by Zongsheng Yue
2019-09-01 20:56:15 4 | 5 | from networks3.Discriminator import _NetD, DiscriminatorLinear, Discriminator1,\ 6 | VGGStyleDiscriminator128,NLayerDiscriminator 7 | from networks3.UNetG import UNetG, sample_generator,sample_generator_1 8 | from networks3.UNetD import UNetD,DnCNN 9 | from networks3.util import pixel_shuffle_down_sampling, pixel_shuffle_up_sampling 10 | from functools import partial 11 | from importlib import import_module 12 | import os 13 | # from GaussianSmoothLayer import GaussionSmoothLayer 14 | import torchvision.transforms as transforms 15 | import torch.nn as nn 16 | import torch 17 | import torch 18 | from torch import nn 19 | import torch.nn.functional as F 20 | # from .SubBlocks import conv3x3, conv_down 21 | #!/usr/bin/env python 22 | # -*- coding:utf-8 -*- 23 | # Power by Zongsheng Yue 2019-03-20 19:48:14 24 | # Adapted from https://github.com/jvanvugt/pytorch-unet 25 | # from .util import pixel_shuffle_down_sampling, pixel_shuffle_up_sampling 26 | import torch 27 | from torch import nn 28 | import torch.nn.functional as F 29 | from networks3.SubBlocks import conv3x3, conv_down 30 | torch_ver = torch.__version__[:3] 31 | import numpy as np 32 | import torch 33 | import math 34 | from torch.nn import Module, Sequential, Conv2d, ReLU,AdaptiveMaxPool2d, AdaptiveAvgPool2d, \ 35 | NLLLoss, BCELoss, CrossEntropyLoss, AvgPool2d, MaxPool2d, Parameter, Linear, Sigmoid, Softmax, Dropout, Embedding 36 | from torch.nn import PixelShuffle, PixelUnshuffle 37 | from torch.nn import functional as F 38 | from torch.autograd import Variable 39 | torch_ver = torch.__version__[:3] 40 | 41 | # __all__ = ['PAM_Module', 'CAM_Module'] 42 | 43 | 44 | 45 | 46 | class PAM_Module(Module): 47 | """ Position attention module""" 48 | #Ref from SAGAN 49 | def __init__(self, in_dim): 50 | super(PAM_Module, self).__init__() 51 | self.chanel_in = in_dim 52 | 53 | self.query_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 54 | self.key_conv = Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 55 | self.value_conv = Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 56 | self.gamma = Parameter(torch.zeros(1)) 57 | 58 | self.softmax = Softmax(dim=-1) 59 | def forward(self, x): 60 | """ 61 | inputs : 62 | x : input feature maps( B X C X H X W) 63 | returns : 64 | out : attention value + input feature 65 | attention: B X (HxW) X (HxW) 66 | """ 67 | m_batchsize, C, height, width = x.size() 68 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 69 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) 70 | energy = torch.bmm(proj_query, proj_key) 71 | attention = self.softmax(energy) 72 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) 73 | 74 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 75 | out = out.view(m_batchsize, C, height, width) 76 | 77 | out = self.gamma*out + x 78 | return out 79 | 80 | 81 | class CAM_Module(Module): 82 | """ Channel attention module""" 83 | def __init__(self, in_dim): 84 | super(CAM_Module, self).__init__() 85 | self.chanel_in = in_dim 86 | 87 | 88 | self.gamma = Parameter(torch.zeros(1)) 89 | self.softmax = Softmax(dim=-1) 90 | def forward(self,x): 91 | """ 92 | inputs : 93 | x : input feature maps( B X C X H X W) 94 | returns : 95 | out : attention value + input feature 96 | attention: B X C X C 97 | """ 98 | m_batchsize, C, height, width = x.size() 99 | proj_query = x.view(m_batchsize, C, -1) 100 | proj_key = x.view(m_batchsize, C, -1).permute(0, 
2, 1) 101 | energy = torch.bmm(proj_query, proj_key) 102 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy 103 | attention = self.softmax(energy_new) 104 | proj_value = x.view(m_batchsize, C, -1) 105 | 106 | out = torch.bmm(attention, proj_value) 107 | out = out.view(m_batchsize, C, height, width) 108 | 109 | out = self.gamma*out + x 110 | return out 111 | 112 | 113 | 114 | 115 | class UNetD(nn.Module): 116 | def __init__(self, in_chn, wf=64, depth=5, relu_slope=0.2): 117 | super(UNetD, self).__init__() 118 | self.depth = depth 119 | self.down_path = nn.ModuleList() 120 | prev_channels = self.get_input_chn(in_chn) 121 | for i in range(depth): 122 | downsample = True if (i+1) < depth else False 123 | self.down_path.append(UNetConvBlock(prev_channels, (2**i)*wf, downsample, relu_slope)) 124 | prev_channels = (2**i) * wf 125 | 126 | self.up_path = nn.ModuleList() 127 | for i in reversed(range(depth - 1)): 128 | self.up_path.append(UNetUpBlock(prev_channels, (2**i)*wf, relu_slope)) 129 | prev_channels = (2**i)*wf 130 | 131 | self.last = conv3x3(prev_channels, in_chn, bias=True) 132 | # self._initialize() 133 | 134 | def forward(self, x1): 135 | 136 | res = x1 137 | blocks = [] 138 | for i, down in enumerate(self.down_path): 139 | if (i+1) < self.depth: 140 | x1, x1_up = down(x1) 141 | blocks.append(x1_up) 142 | else: 143 | x1 = down(x1) 144 | 145 | for i, up in enumerate(self.up_path): 146 | x1 = up(x1, blocks[-i-1]) 147 | 148 | out = self.last(x1) 149 | return out+res 150 | 151 | def get_input_chn(self, in_chn): 152 | return in_chn 153 | 154 | def _initialize(self): 155 | gain = nn.init.calculate_gain('leaky_relu', 0.20) 156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | nn.init.orthogonal_(m.weight, gain=gain) 159 | if not m.bias is None: 160 | nn.init.constant_(m.bias, 0) 161 | 162 | class UNetConvBlock(nn.Module): 163 | def __init__(self, in_size, out_size, downsample, relu_slope): 164 | super(UNetConvBlock, self).__init__() 165 | self.downsample = downsample 166 | self.block = nn.Sequential( 167 | nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True), 168 | nn.LeakyReLU(relu_slope, inplace=True), 169 | nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True), 170 | nn.LeakyReLU(relu_slope, inplace=True)) 171 | self.SA = spatial_attn_layer() ## Spatial Attention 172 | self.CA = CALayer(out_size, 8) ## Channel Attention 173 | self.conv1x1 = nn.Conv2d(2 * out_size, out_size, kernel_size=1) 174 | 175 | if downsample: 176 | self.downsample = conv_down(out_size, out_size, bias=False) 177 | 178 | def forward(self, x): 179 | out = self.block(x) 180 | # sa_branch = self.SA(out) 181 | # ca_branch = self.CA(out) 182 | # res = torch.cat([sa_branch, ca_branch], dim=1) 183 | # out = self.conv1x1(res) 184 | if self.downsample: 185 | out_down = self.downsample(out) 186 | return out_down, out 187 | else: 188 | return out 189 | 190 | class UNetUpBlock(nn.Module): 191 | def __init__(self, in_size, out_size, relu_slope): 192 | super(UNetUpBlock, self).__init__() 193 | self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True) 194 | self.conv_block = UNetConvBlock(in_size, out_size, False, relu_slope) 195 | self.SA = spatial_attn_layer() ## Spatial Attention 196 | self.CA = CALayer(out_size, 8) ## Channel Attention 197 | self.conv1x1 = nn.Conv2d(2 * out_size, out_size, kernel_size=1) 198 | 199 | def forward(self, x, bridge): 200 | up = self.up(x) 201 | # sa_branch = self.SA(up) 202 | # ca_branch = self.CA(up) 203 
| # res = torch.cat([sa_branch, ca_branch], dim=1) 204 | # out = self.conv1x1(res) 205 | out = torch.cat([up, bridge], 1) 206 | out = self.conv_block(out) 207 | 208 | return out 209 | 210 | class DnCNN(nn.Module): 211 | def __init__(self, channels, num_of_layers=17): 212 | super(DnCNN, self).__init__() 213 | kernel_size = 3 214 | padding = 1 215 | features = 64 216 | layers = [] 217 | layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)) 218 | layers.append(nn.ReLU(inplace=True)) 219 | for _ in range(num_of_layers-2): 220 | layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False)) 221 | layers.append(nn.BatchNorm2d(features)) 222 | layers.append(nn.ReLU(inplace=True)) 223 | layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False)) 224 | self.dncnn = nn.Sequential(*layers) 225 | def forward(self, x): 226 | out = self.dncnn(x) 227 | return out 228 | 229 | 230 | 231 | class _Conv_Block(nn.Module): 232 | def __init__(self): 233 | super(_Conv_Block, self).__init__() 234 | 235 | self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 236 | self.in1 = nn.BatchNorm2d(64, affine=True) 237 | self.relu = nn.LeakyReLU(0.2, inplace=True) 238 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) 239 | self.in2 = nn.BatchNorm2d(64, affine=True) 240 | 241 | def forward(self, x): 242 | identity_data = x 243 | output = self.relu(self.in1(self.conv1(x))) 244 | output = self.in2(self.conv2(output)) 245 | return output 246 | 247 | 248 | class _Residual_Block(nn.Module): 249 | def __init__(self,n_feat = 64, reduction = 8): 250 | super(_Residual_Block, self).__init__() 251 | self.SA = spatial_attn_layer() ## Spatial Attention 252 | self.CA = CALayer(n_feat, reduction) ## Channel Attention 253 | self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1) 254 | # self.in1 = nn.BatchNorm2d(64, affine=True) 255 | self.relu = nn.LeakyReLU(0.2, inplace=True) 256 | self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, ) 257 | self.conv1x1 = nn.Conv2d(n_feat * 2, n_feat, kernel_size=1) 258 | # self.in2 = nn.BatchNorm2d(64, affine=True) 259 | 260 | def forward(self, x): 261 | identity_data = x 262 | output = self.relu((self.conv1(x))) 263 | 264 | output = ((self.conv2(output))) 265 | # sa_branch = self.SA(output) 266 | # ca_branch = self.CA(output) 267 | # res = torch.cat([sa_branch, ca_branch], dim=1) 268 | # res = self.conv1x1(res) 269 | 270 | output = torch.add(self.relu(output), identity_data) 271 | # output = torch.add(output,identity_data) 272 | 273 | 274 | return output 275 | 276 | class _NetG_DOWN(nn.Module): 277 | def __init__(self, stride=2): 278 | super(_NetG_DOWN, self).__init__() 279 | # self.Gas = GaussionSmoothLayer(3, 15, 9) 280 | self.Gas = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=7, stride=1, padding=3) 281 | # self.tiza = tizao(3) 282 | # self.conv_input = nn.Sequential( 283 | # nn.Conv2d(in_channels=6, out_channels=64, kernel_size=7, stride=1, padding=3, ), 284 | # nn.LeakyReLU(0.2, inplace=True), 285 | # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=stride + 2, stride=stride, padding=1, ), 286 | # nn.LeakyReLU(0.2, inplace=True), 287 | # nn.Conv2d(in_channels=64, out_channels=64, kernel_size=stride + 2, 
stride=stride, padding=1, ), 288 | # nn.LeakyReLU(0.2, inplace=True), 289 | # ) 290 | self.conv_input = nn.Sequential( 291 | nn.Conv2d(in_channels=6, out_channels=64, kernel_size=7, stride=1, padding=3, ), 292 | ) 293 | # self.relu = nn.LeakyReLU(0.2, inplace=True) 294 | 295 | self.residual = self.make_layer(_Residual_Block, 6) 296 | # self.dab = self.make_layer(DAB, 6) 297 | self.conv_output = nn.Sequential( 298 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 299 | nn.LeakyReLU(0.2, inplace=True), 300 | nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, ), 301 | nn.LeakyReLU(0.2, inplace=True), 302 | nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, stride=1, padding=3, ), 303 | ) 304 | self.scale = nn.Parameter(torch.randn(3,1,1),requires_grad=True) 305 | 306 | 307 | def make_layer(self, block, num_of_layer): 308 | 309 | layers = [] 310 | for _ in range(num_of_layer): 311 | layers.append(block()) 312 | return nn.Sequential(*layers) 313 | 314 | def forward(self, x, y): 315 | 316 | # z = torch.cat([x, self.scale*y], dim=1) 317 | 318 | # z = torch.cat([x, y], dim=1) 319 | # = torch.cat([x, y - self.tiza(y)], dim=1) 320 | z = torch.cat([x, y - self.Gas(y)], dim=1) 321 | 322 | 323 | out = self.conv_input(z) 324 | 325 | 326 | out = self.residual(out) 327 | 328 | out = self.conv_output(out) 329 | 330 | return out + x 331 | 332 | 333 | 334 | class spatial_attn_layer(nn.Module): 335 | def __init__(self, kernel_size=3): 336 | super(spatial_attn_layer, self).__init__() 337 | self.compress = ChannelPool() 338 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False) 339 | def forward(self, x): 340 | # import pdb;pdb.set_trace() 341 | x_compress = self.compress(x) 342 | x_out = self.spatial(x_compress) 343 | scale = torch.sigmoid(x_out) # broadcasting 344 | return x * scale 345 | 346 | class BasicConv(nn.Module): 347 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 348 | bn=False, bias=False): 349 | super(BasicConv, self).__init__() 350 | self.out_channels = out_planes 351 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 352 | dilation=dilation, groups=groups, bias=bias) 353 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 354 | self.relu = nn.ReLU() if relu else None 355 | 356 | def forward(self, x): 357 | x = self.conv(x) 358 | if self.bn is not None: 359 | x = self.bn(x) 360 | if self.relu is not None: 361 | x = self.relu(x) 362 | return x 363 | class ChannelPool(nn.Module): 364 | def forward(self, x): 365 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 366 | 367 | class CALayer(nn.Module): 368 | def __init__(self, channel, reduction=16): 369 | super(CALayer, self).__init__() 370 | # global average pooling: feature --> point 371 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 372 | # feature channel downscale and upscale --> channel weight 373 | self.conv_du = nn.Sequential( 374 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 375 | nn.ReLU(inplace=True), 376 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 377 | nn.Sigmoid() 378 | ) 379 | 380 | def forward(self, x): 381 | y = self.avg_pool(x) 382 | y = self.conv_du(y) 383 | return x * y 384 | model_class_dict = {} 385 | 386 | def regist_model(model_class): 387 | model_name = model_class.__name__.lower() 
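    # Illustrative note (not part of the original file): regist_model is meant
    # to be used as a class decorator, e.g.
    #
    #     @regist_model
    #     class MyBSN(nn.Module):   # hypothetical model class
    #         ...
    #
    # after which get_model_class('mybsn') retrieves the class by its
    # lowercased name.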
388 |     assert model_name not in model_class_dict, 'model %s is already registered in model_class_dict.' % model_name
389 |     model_class_dict[model_name] = model_class
390 | 
391 |     return model_class
392 | 
393 | def get_model_class(model_name: str):
394 |     model_name = model_name.lower()
395 |     return model_class_dict[model_name]
396 | 
397 | # import all python files in model folder
398 | 
399 | 
400 | class APBSN(nn.Module):
401 |     '''
402 |     Asymmetric PD Blind-Spot Network (AP-BSN)
403 |     '''
404 | 
405 |     def __init__(self, pd_a=5, pd_b=2, pd_pad=2, R3=True, R3_T=8, R3_p=0.16,
406 |                  bsn='DBSNl', in_ch=3, bsn_base_ch=128, bsn_num_module=9):
407 |         '''
408 |         Args:
409 |             pd_a : 'PD stride factor' during training
410 |             pd_b : 'PD stride factor' during inference
411 |             pd_pad : pad size between sub-images by PD process
412 |             R3 : flag of 'Random Replacing Refinement'
413 |             R3_T : number of masks for R3
414 |             R3_p : probability of R3
415 |             bsn : blind-spot network type
416 |             in_ch : number of input image channels
417 |             bsn_base_ch : number of bsn base channels
418 |             bsn_num_module : number of modules
419 |         '''
420 |         super().__init__()
421 | 
422 |         # network hyper-parameters
423 |         self.pd_a = pd_a
424 |         self.pd_b = pd_b
425 |         self.pd_pad = pd_pad
426 |         self.R3 = R3
427 |         self.R3_T = R3_T
428 |         self.R3_p = R3_p
429 | 
430 |         # define network
431 |         if bsn == 'DBSNl':
432 |             self.bsn = DBSNl(in_ch, in_ch, bsn_base_ch, bsn_num_module)
433 |         else:
434 |             raise NotImplementedError('bsn %s is not implemented' % bsn)
435 | 
436 |     def forward(self, img, pd=None):
437 |         '''
438 |         Forward function includes the sequence of PD, BSN and inverse PD processes.
439 |         Note that the denoise() function is used during inference time (for a different pd factor and R3).
440 |         '''
441 |         # default pd factor is training factor (a)
442 |         if pd is None: pd = self.pd_a
443 | 
444 |         # do PD
445 |         if pd > 1:
446 |             pd_img = pixel_shuffle_down_sampling(img, f=pd, pad=self.pd_pad)
447 |         else:
448 |             p = self.pd_pad
449 |             pd_img = F.pad(img, (p, p, p, p))
450 | 
451 |         # forward blind-spot network
452 |         pd_img_denoised = self.bsn(pd_img)
453 | 
454 |         # do inverse PD
455 |         if pd > 1:
456 |             img_pd_bsn = pixel_shuffle_up_sampling(pd_img_denoised, f=pd, pad=self.pd_pad)
457 |         else:
458 |             p = self.pd_pad
459 |             img_pd_bsn = pd_img_denoised[:, :, p:-p, p:-p]
460 | 
461 |         return img_pd_bsn
462 | 
463 |     def denoise(self, x):
464 |         '''
465 |         Denoising process for inference.
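        Illustrative usage (editor's sketch, not part of the original file):

            model = APBSN(pd_a=5, pd_b=2)      # hypothetical instantiation
            clean = model.denoise(noisy)       # noisy: (B, C, H, W)

        The input is zero-padded so that H and W are divisible by pd_b; without
        R3 the result is cropped back to (h, w), otherwise R3_T randomly
        re-masked passes through the BSN are averaged.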
466 |         '''
467 |         b, c, h, w = x.shape
468 | 
469 |         # pad images for PD process
470 |         if h % self.pd_b != 0:
471 |             x = F.pad(x, (0, 0, 0, self.pd_b - h % self.pd_b), mode='constant', value=0)
472 |         if w % self.pd_b != 0:
473 |             x = F.pad(x, (0, self.pd_b - w % self.pd_b, 0, 0), mode='constant', value=0)
474 | 
475 |         # forward PD-BSN process with inference pd factor
476 |         img_pd_bsn = self.forward(img=x, pd=self.pd_b)
477 | 
478 |         # Random Replacing Refinement
479 |         if not self.R3:
480 |             ''' Directly return the result (w/o R3) '''
481 |             return img_pd_bsn[:, :, :h, :w]
482 |         else:
483 |             denoised = torch.empty(*(x.shape), self.R3_T, device=x.device)
484 |             for t in range(self.R3_T):
485 |                 indice = torch.rand_like(x)
486 |                 mask = indice < self.R3_p
487 | 
488 |                 tmp_input = torch.clone(img_pd_bsn).detach()
489 |                 tmp_input[mask] = x[mask]
490 |                 p = self.pd_pad
491 |                 tmp_input = F.pad(tmp_input, (p, p, p, p), mode='reflect')
492 |                 if self.pd_pad == 0:
493 |                     denoised[..., t] = self.bsn(tmp_input)
494 |                 else:
495 |                     denoised[..., t] = self.bsn(tmp_input)[:, :, p:-p, p:-p]
496 | 
497 |             return torch.mean(denoised, dim=-1)
498 | 
499 |         '''
500 |         elif self.R3 == 'PD-refinement':
501 |             s = 2
502 |             denoised = torch.empty(*(x.shape), s**2, device=x.device)
503 |             for i in range(s):
504 |                 for j in range(s):
505 |                     tmp_input = torch.clone(x_mean).detach()
506 |                     tmp_input[:,:,i::s,j::s] = x[:,:,i::s,j::s]
507 |                     p = self.pd_pad
508 |                     tmp_input = F.pad(tmp_input, (p,p,p,p), mode='reflect')
509 |                     if self.pd_pad == 0:
510 |                         denoised[..., i*s+j] = self.bsn(tmp_input)
511 |                     else:
512 |                         denoised[..., i*s+j] = self.bsn(tmp_input)[:,:,p:-p,p:-p]
513 |             return_denoised = torch.mean(denoised, dim=-1)
514 |         else:
515 |             raise RuntimeError('post-processing type not supported')
516 |         '''
517 | 
518 | 
519 | class DBSNl(nn.Module):
520 |     '''
521 |     Dilated Blind-Spot Network (customized light version)
522 | 
523 |     self-implemented version of the network from "Unpaired Learning of Deep Image Denoising (ECCV 2020)"
524 |     and several modifications are included.
525 |     see the supplementary material of AP-BSN for more details.
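    Illustrative shape check (editor's sketch, not part of the original file):

        net = DBSNl(in_ch=3, out_ch=3, base_ch=96, num_module=8)
        y = net(torch.rand(1, 3, 64, 64))    # y.shape == (1, 3, 64, 64)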
526 |     '''
527 | 
528 |     def __init__(self, in_ch=3, out_ch=3, base_ch=96, num_module=8):
529 |         '''
530 |         Args:
531 |             in_ch : number of input channels
532 |             out_ch : number of output channels
533 |             base_ch : number of base channels
534 |             num_module : number of modules in the network
535 |         '''
536 |         super().__init__()
537 | 
538 |         assert base_ch % 2 == 0, "base channel should be divisible by 2"
539 | 
540 |         ly = []
541 |         ly += [nn.Conv2d(in_ch, base_ch, kernel_size=1)]
542 |         ly += [nn.ReLU(inplace=True)]
543 |         self.head = nn.Sequential(*ly)
544 | 
545 |         self.branch1 = DC_branchl(2, base_ch, num_module)
546 |         self.branch2 = DC_branchl(3, base_ch, num_module)
547 |         self.SA = spatial_attn_layer()  ## Spatial Attention
548 |         self.CA = CALayer(out_ch, 8)  ## Channel Attention
549 |         self.conv1x1 = nn.Conv2d(2 * out_ch, out_ch, kernel_size=1)
550 |         ly = []
551 |         ly += [nn.Conv2d(base_ch * 2, base_ch, kernel_size=1)]
552 |         ly += [nn.ReLU(inplace=True)]
553 |         ly += [nn.Conv2d(base_ch, out_ch, kernel_size=3, padding=1, bias=True)]
554 |         self.tail = nn.Sequential(*ly)
555 | 
556 |     def forward(self, x):
557 | 
558 |         x = self.head(x)
559 | 
560 |         br1 = self.branch1(x)
561 |         br2 = self.branch2(x)
562 | 
563 |         x = torch.cat([br1, br2], dim=1)
564 | 
565 |         return self.tail(x)
566 | 
567 |     def _initialize_weights(self):
568 |         # Liyong version
569 |         for m in self.modules():
570 |             if isinstance(m, nn.Conv2d):
571 |                 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
572 |                 m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5)
573 | 
574 | 
575 | class DC_branchl(nn.Module):
576 |     def __init__(self, stride, in_ch, num_module):
577 |         super().__init__()
578 | 
579 |         ly = []
580 |         ly += [nn.Conv2d(in_ch, in_ch, kernel_size=1)]
581 |         ly += [nn.ReLU(inplace=True)]
582 |         ly += [nn.Conv2d(in_ch, in_ch, kernel_size=1)]
583 |         ly += [nn.ReLU(inplace=True)]
584 | 
585 | 
586 |         ly += [DCl(stride, in_ch) for _ in range(num_module)]
587 | 
588 |         ly += [nn.Conv2d(in_ch, in_ch, kernel_size=1)]
589 |         ly += [nn.ReLU(inplace=True)]
590 | 
591 |         self.body = nn.Sequential(*ly)
592 |         self.SA = spatial_attn_layer()  ## Spatial Attention
593 |         self.CA = CALayer(in_ch, 8)  ## Channel Attention
594 |         self.conv1x1 = nn.Conv2d(2 * in_ch, in_ch, kernel_size=1)
595 | 
596 |     def forward(self, x):
597 |         y = self.body(x)
598 | 
599 |         return y
600 | 
601 | class DCl(nn.Module):
602 |     def __init__(self, stride, in_ch):
603 |         super().__init__()
604 | 
605 |         ly = []
606 |         ly += [nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1, bias=True)]
607 |         ly += [nn.ReLU(inplace=True)]
608 |         ly += [nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=1, padding=1, bias=True)]
609 |         ly += [nn.ReLU(inplace=True)]
610 |         # ly += [nn.Conv2d(in_ch, in_ch, kernel_size=1)]
611 | 
612 |         self.body = nn.Sequential(*ly)
613 |         self.SA = spatial_attn_layer()  ## Spatial Attention
614 |         self.CA = CALayer(in_ch, 8)  ## Channel Attention
615 |         self.conv1x1 = nn.Conv2d(2 * in_ch, in_ch, kernel_size=1)
616 |     def forward(self, x):
617 |         y = self.body(x)
618 | 
619 |         # z = y1 + x
620 |         # y2 = self.body(z)
621 |         # y3 = y2 + z
622 |         return y + x  # local residual connection
623 |         # sa_branch = self.SA(y)
624 |         # ca_branch = self.CA(y)
625 |         # res = torch.cat([sa_branch, ca_branch], dim=1)
626 |         # out = self.conv1x1(res)
627 | 
628 | 
629 | class CentralMaskedConv2d(nn.Conv2d):
630 |     def __init__(self, *args, **kwargs):
631 |         super().__init__(*args, **kwargs)
632 | 
633 |         self.register_buffer('mask', self.weight.data.clone())
634 |         _, _, kH, kW = self.weight.size()
635 |         self.mask.fill_(1)
636 |         self.mask[:, :, kH // 2, kW // 2] = 
0 637 | 638 | def forward(self, x): 639 | self.weight.data *= self.mask 640 | return super().forward(x) 641 | 642 | if __name__ == '__main__': 643 | net=DBSNl() 644 | para=sum(p.numel() for p in net.parameters()) 645 | print(para) 646 | 647 | class ConvLayer1(nn.Module): 648 | 649 | def __init__(self, in_channels, out_channels, kernel_size, stride): 650 | super(ConvLayer1, self).__init__() 651 | self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size-1)//2, stride= stride) 652 | 653 | nn.init.xavier_normal_(self.conv2d.weight.data) 654 | 655 | def forward(self, x): 656 | # out = self.reflection_pad(x) 657 | # out = self.conv2d(out) 658 | return self.conv2d(x) 659 | 660 | 661 | class ConvLayer(nn.Module): 662 | def __init__(self, in_channels, out_channels, kernel_size, stride): 663 | super(ConvLayer, self).__init__() 664 | padding = (kernel_size-1)//2 665 | self.block = nn.Sequential( 666 | nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, stride=stride), 667 | nn.ReLU() 668 | ) 669 | nn.init.xavier_normal_(self.block[0].weight.data) 670 | 671 | def forward(self, x): 672 | return self.block(x) 673 | 674 | 675 | class line(nn.Module): 676 | def __init__(self): 677 | super(line, self).__init__() 678 | self.delta = nn.Parameter(torch.randn(1, 1)) 679 | 680 | def forward(self, x ,y ): 681 | return torch.mul((1-self.delta), x) + torch.mul(self.delta, y) 682 | 683 | 684 | class Encoding_block(nn.Module): 685 | def __init__(self, base_filter, n_convblock): 686 | super(Encoding_block, self).__init__() 687 | self.n_convblock = n_convblock 688 | modules_body = [] 689 | for i in range(self.n_convblock-1): 690 | modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=1)) 691 | modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=2)) 692 | self.body = nn.Sequential(*modules_body) 693 | 694 | def forward(self, x): 695 | for i in range(self.n_convblock-1): 696 | x = self.body[i](x) 697 | ecode = x 698 | x = self.body[self.n_convblock-1](x) 699 | return ecode, x 700 | 701 | 702 | class UpsampleConvLayer(nn.Module): 703 | 704 | def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None): 705 | super(UpsampleConvLayer, self).__init__() 706 | self.upsample = upsample 707 | self.conv2d = ConvLayer(in_channels, out_channels, kernel_size, stride) 708 | 709 | def forward(self, x): 710 | x_in = x 711 | if self.upsample: 712 | x_in = torch.nn.functional.interpolate(x_in, scale_factor=self.upsample) 713 | out = self.conv2d(x_in) 714 | return out 715 | 716 | 717 | class upsample1(nn.Module): 718 | def __init__(self, base_filter): 719 | super(upsample1, self).__init__() 720 | self.conv1 = ConvLayer(base_filter, base_filter, 3, stride=1) 721 | self.ConvTranspose = UpsampleConvLayer(base_filter, base_filter, kernel_size=3, stride=1, upsample=2) 722 | self.cat = ConvLayer1(base_filter*2, base_filter, kernel_size=1, stride=1) 723 | 724 | def forward(self, x, y): 725 | y = self.ConvTranspose(y) 726 | x = self.conv1(x) 727 | return self.cat(torch.cat((x, y), dim=1)) 728 | 729 | 730 | class Decoding_block2(nn.Module): 731 | def __init__(self, base_filter, n_convblock): 732 | super(Decoding_block2, self).__init__() 733 | self.n_convblock = n_convblock 734 | self.upsample = upsample1(base_filter) 735 | modules_body = [] 736 | for i in range(self.n_convblock-1): 737 | modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=1)) 738 | modules_body.append(ConvLayer(base_filter, base_filter, 3, stride=1)) 739 | self.body = 
nn.Sequential(*modules_body) 740 | 741 | def forward(self, x, y): 742 | x = self.upsample(x, y) 743 | for i in range(self.n_convblock): 744 | x = self.body[i](x) 745 | return x 746 | 747 | #Corresponds to DEAM Module in NLO Sub-network 748 | class Attention_unet(nn.Module): 749 | def __init__(self, channel, reduction=16): 750 | super(Attention_unet, self).__init__() 751 | self.conv_du = nn.Sequential( 752 | ConvLayer1(in_channels=channel, out_channels=channel // reduction, kernel_size=3, stride=1), 753 | nn.ReLU(inplace=True), 754 | ConvLayer1(in_channels=channel // reduction, out_channels=channel, kernel_size=3, stride=1), 755 | nn.Sigmoid() 756 | ) 757 | self.cat = ConvLayer1(in_channels=channel * 2, out_channels=channel, kernel_size=1, stride=1) 758 | self.C = ConvLayer1(in_channels=channel, out_channels=channel, kernel_size=3, stride=1) 759 | self.ConvTranspose = UpsampleConvLayer(channel, channel, kernel_size=3, stride=1, upsample=2)#up-sampling 760 | 761 | def forward(self, x, g): 762 | up_g = self.ConvTranspose(g) 763 | weight = self.conv_du(self.cat(torch.cat([self.C(x), up_g], 1))) 764 | rich_x = torch.mul((1 - weight), up_g) + torch.mul(weight, x) 765 | return rich_x 766 | 767 | #Corresponds to NLO Sub-network 768 | class ziwangluo1(nn.Module): 769 | def __init__(self, base_filter, n_convblock_in, n_convblock_out): 770 | super(ziwangluo1, self).__init__() 771 | self.conv_dila1 = ConvLayer1(64, 64, 3, 1) 772 | self.conv_dila2 = ConvLayer1(64, 64, 5, 1) 773 | self.conv_dila3 = ConvLayer1(64, 64, 7, 1) 774 | 775 | self.cat1 = torch.nn.Conv2d(in_channels=64 * 3, out_channels=64, kernel_size=1, stride=1, padding=0, 776 | dilation=1, bias=True) 777 | nn.init.xavier_normal_(self.cat1.weight.data) 778 | self.e3 = Encoding_block(base_filter, n_convblock_in) 779 | self.e2 = Encoding_block(base_filter, n_convblock_in) 780 | self.e1 = Encoding_block(base_filter, n_convblock_in) 781 | self.e0 = Encoding_block(base_filter, n_convblock_in) 782 | 783 | 784 | self.attention3 = Attention_unet(base_filter) 785 | self.attention2 = Attention_unet(base_filter) 786 | self.attention1 = Attention_unet(base_filter) 787 | self.attention0 = Attention_unet(base_filter) 788 | self.mid = nn.Sequential(ConvLayer(base_filter, base_filter, 3, 1), 789 | ConvLayer(base_filter, base_filter, 3, 1)) 790 | self.de3 = Decoding_block2(base_filter, n_convblock_out) 791 | self.de2 = Decoding_block2(base_filter, n_convblock_out) 792 | self.de1 = Decoding_block2(base_filter, n_convblock_out) 793 | self.de0 = Decoding_block2(base_filter, n_convblock_out) 794 | 795 | self.final = ConvLayer1(base_filter, base_filter, 3, stride=1) 796 | 797 | def forward(self, x): 798 | _input = x 799 | encode0, down0 = self.e0(x) 800 | encode1, down1 = self.e1(down0) 801 | encode2, down2 = self.e2(down1) 802 | encode3, down3 = self.e3(down2) 803 | 804 | # media_end = self.Encoding_block_end(down3) 805 | media_end = self.mid(down3) 806 | 807 | g_conv3 = self.attention3(encode3, media_end) 808 | up3 = self.de3(g_conv3, media_end) 809 | g_conv2 = self.attention2(encode2, up3) 810 | up2 = self.de2(g_conv2, up3) 811 | 812 | g_conv1 = self.attention1(encode1, up2) 813 | up1 = self.de1(g_conv1, up2) 814 | 815 | g_conv0 = self.attention0(encode0, up1) 816 | up0 = self.de0(g_conv0, up1) 817 | 818 | final = self.final(up0) 819 | 820 | return _input+final 821 | 822 | 823 | class line(nn.Module): 824 | def __init__(self): 825 | super(line, self).__init__() 826 | self.delta = nn.Parameter(torch.randn(1, 1)) 827 | 828 | def forward(self, x, y): 829 | return 
torch.mul((1 - self.delta), x) + torch.mul(self.delta, y) 830 | 831 | 832 | class SCA(nn.Module): 833 | def __init__(self, channel, reduction=16): 834 | super(SCA, self).__init__() 835 | self.conv_du = nn.Sequential( 836 | ConvLayer1(in_channels=channel, out_channels=channel // reduction, kernel_size=3, stride=1), 837 | nn.ReLU(inplace=True), 838 | ConvLayer1(in_channels=channel // reduction, out_channels=channel, kernel_size=3, stride=1), 839 | nn.Sigmoid() 840 | ) 841 | 842 | def forward(self, x): 843 | y = self.conv_du(x) 844 | return y 845 | 846 | 847 | class Weight(nn.Module): 848 | def __init__(self, channel): 849 | super(Weight, self).__init__() 850 | self.cat =ConvLayer1(in_channels=channel*2, out_channels=channel, kernel_size=1, stride=1) 851 | self.C = ConvLayer1(in_channels=channel, out_channels=channel, kernel_size=3, stride=1) 852 | self.weight = SCA(channel) 853 | 854 | def forward(self, x, y): 855 | delta = self.weight(self.cat(torch.cat([self.C(y), x], 1))) 856 | return delta 857 | 858 | 859 | class transform_function(nn.Module): 860 | def __init__(self, in_channel, out_channel): 861 | super(transform_function, self).__init__() 862 | self.ext = ConvLayer1(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1) 863 | self.pre = torch.nn.Sequential( 864 | ConvLayer1(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1), 865 | nn.ReLU(inplace=True), 866 | ConvLayer1(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1), 867 | 868 | ) 869 | 870 | def forward(self, x): 871 | y = self.ext(x) 872 | return y+self.pre(y) 873 | 874 | 875 | class Inverse_transform_function(nn.Module): 876 | def __init__(self, in_channel, out_channel): 877 | super(Inverse_transform_function, self).__init__() 878 | self.ext = ConvLayer1(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1) 879 | self.pre = torch.nn.Sequential( 880 | ConvLayer1(in_channels=in_channel, out_channels=in_channel, kernel_size=3, stride=1), 881 | nn.ReLU(inplace=True), 882 | ConvLayer1(in_channels=in_channel, out_channels=in_channel, kernel_size=3, stride=1), 883 | ) 884 | 885 | def forward(self, x): 886 | x = self.pre(x)+x 887 | x = self.ext(x) 888 | return x 889 | 890 | 891 | class Deam(nn.Module): 892 | def __init__(self, Isreal): 893 | super(Deam, self).__init__() 894 | if Isreal: 895 | self.transform_function = transform_function(3, 64) 896 | self.inverse_transform_function = Inverse_transform_function(64, 3) 897 | else: 898 | self.transform_function = transform_function(1, 64) 899 | self.inverse_transform_function = Inverse_transform_function(64, 1) 900 | 901 | self.line11 = Weight(64) 902 | self.line22 = Weight(64) 903 | self.line33 = Weight(64) 904 | self.line44 = Weight(64) 905 | 906 | self.net2 = ziwangluo1(64, 3, 2) 907 | 908 | def forward(self, x): 909 | x = self.transform_function(x) 910 | y = x 911 | 912 | # Corresponds to NLO Sub-network 913 | x1 = self.net2(y) 914 | # Corresponds to DEAM Module 915 | delta_1 = self.line11(x1, y) 916 | x1 = torch.mul((1 - delta_1), x1) + torch.mul(delta_1, y) 917 | 918 | x2 = self.net2(x1) 919 | delta_2 = self.line22(x2, y) 920 | x2 = torch.mul((1 - delta_2), x2) + torch.mul(delta_2, y) 921 | 922 | x3 = self.net2(x2) 923 | delta_3 = self.line33(x3, y) 924 | x3 = torch.mul((1 - delta_3), x3) + torch.mul(delta_3, y) 925 | 926 | x4 = self.net2(x3) 927 | delta_4 = self.line44(x4, y) 928 | x4 = torch.mul((1 - delta_4), x4) + torch.mul(delta_4, y) 929 | x4 = self.inverse_transform_function(x4) 930 | 
return x4 931 | 932 | 933 | def print_network(net): 934 | num_params = 0 935 | for param in net.parameters(): 936 | num_params += param.numel() 937 | print(net) 938 | print('Total number of parameters: %d' % num_params) 939 | --------------------------------------------------------------------------------
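Illustrative sketch (not part of the repository): how the loss registry at the end of loss.py is typically consumed. The selected names and weights below are hypothetical, and constructing GRAD_loss or negative_SSIM_loss with their defaults assumes a CUDA device, since those classes move their fixed kernels to the GPU when is_cuda=True.

# build criteria from the loss_dict registry defined in loss.py
criteria = {name: loss_dict[name]() for name in ('l1', 'grad', 'neg_ssim')}

def total_loss(pred, target):
    # weighted sum of the selected criteria (weights are illustrative)
    return (1.0 * criteria['l1'](pred, target)
            + 0.1 * criteria['grad'](pred, target)
            + 0.1 * criteria['neg_ssim'](pred, target))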