├── .gitignore ├── .pre-commit-config.yaml ├── C2N ├── config │ ├── C2N_DIDN.yml │ └── C2N_DnCNN.yml ├── model │ ├── C2N.py │ ├── DIDN.py │ ├── DnCNN.py │ └── __init__.py └── util │ └── config.py ├── LICENSE ├── README.md ├── ckpt └── .gitkeep ├── data ├── .gitkeep ├── SIDD_clean_examples │ ├── SIDD1.png │ ├── SIDD2.png │ └── SIDD3.png ├── clean_ex1.png ├── clean_ex2.png ├── noisy_ex1_SIDD.png └── noisy_ex2_DND.png ├── imgs └── architecture.png ├── test_denoise.py └── test_generate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # checkpoint results 7 | ckpt/ 8 | !ckpt/.gitkeep 9 | 10 | # data 11 | data/ 12 | !data/.gitkeep 13 | !data/clean_ex1.png 14 | !data/clean_ex2.png 15 | !data/SIDD_clean_examples 16 | !data/noisy_ex1_SIDD.png 17 | !data/noisy_ex2_DND.png 18 | results/ 19 | 20 | # for macOS 21 | **/.DS_Store 22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-yaml 7 | - id: end-of-file-fixer 8 | - id: trailing-whitespace 9 | args: [--markdown-linebreak-ext=md] 10 | - id: double-quote-string-fixer 11 | # - id: no-commit-to-branch 12 | # args: ["--branch", "main"] 13 | - repo: https://gitlab.com/pycqa/flake8 14 | rev: 3.8.4 15 | hooks: 16 | - id: flake8 17 | args: 18 | # better length for research-purpose code, 19 | # also making isort line length compatible with black 20 | - "--max-line-length=88" 21 | # E501: really long comments 22 | - "--ignore=E501" 23 | # E241: multiple spaces after ',' 24 | - "--ignore=E241" 25 | 26 | - repo: https://github.com/pycqa/isort 27 | rev: 5.8.0 28 | hooks: 29 | - id: isort 30 | name: isort (python) 31 | -------------------------------------------------------------------------------- /C2N/config/C2N_DIDN.yml: -------------------------------------------------------------------------------- 1 | # region - model 2 | model: 3 | Generator: C2N_G 4 | Discriminator: C2N_D 5 | Denoiser: DIDN_8 6 | 7 | color: BGR 8 | 9 | GAN_n_r: 32 10 | # endregion - model 11 | 12 | # region - common train 13 | train: 14 | aug: "rotflip" 15 | # endregion - common train 16 | 17 | # region - train_GAN 18 | train_GAN: 19 | loss: 20 | GAN_adv: 1.*wGAN-GP 21 | GAN_con: 1e-2*batch_zero_mean 22 | 23 | dset_N: prep_SIDD_Medium_sRGB 24 | dset_CL: prep_SIDD_Medium_sRGB 25 | noise_additive: None 26 | salt_pepper: None 27 | crop_patch: True 28 | 29 | preload: False 30 | # endregion - train_GAN 31 | 32 | # region - train_DNer 33 | train_DNer: 34 | loss: 35 | DN: 1.*L1 36 | 37 | dset: prep_SIDD_Medium_sRGB 38 | noise_additive: None 39 | salt_pepper: None 40 | crop_patch: True 41 | 42 | dset_gen_CL: None 43 | n_patch_gen: 18000 44 | loadname_CLtoN: None 45 | # endregion - train_DNer 46 | -------------------------------------------------------------------------------- /C2N/config/C2N_DnCNN.yml: -------------------------------------------------------------------------------- 1 | # region - model 2 | model: 3 | Generator: C2N_G 4 | Discriminator: C2N_D 5 | Denoiser: CDnCNN_B 6 | 7 | color: BGR 8 | 9 | GAN_n_r: 32 10 | # endregion - model 11 | 12 | # region - common train 13 | train: 14 | aug: "rotflip" 15 | # endregion - common train 16 | 17 | # region - train_GAN 18 | train_GAN: 
19 | loss: 20 | GAN_adv: 1.*wGAN-GP 21 | GAN_con: 1e-2*batch_zero_mean 22 | 23 | dset_N: prep_SIDD_Medium_sRGB 24 | dset_CL: prep_SIDD_Medium_sRGB 25 | noise_additive: None 26 | salt_pepper: None 27 | crop_patch: True 28 | 29 | preload: False 30 | # endregion - train_GAN 31 | 32 | # region - train_DNer 33 | train_DNer: 34 | loss: 35 | DN: 1.*L1 36 | 37 | dset: prep_SIDD_Medium_sRGB 38 | noise_additive: None 39 | salt_pepper: None 40 | crop_patch: True 41 | 42 | dset_gen_CL: None 43 | n_patch_gen: 18000 44 | loadname_CLtoN: None 45 | # endregion - train_DNer 46 | -------------------------------------------------------------------------------- /C2N/model/C2N.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions as torch_distb 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | d = 1e-6 7 | 8 | 9 | class ResBlock(nn.Module): 10 | def __init__(self, n_ch_in, ksize=3, bias=True): 11 | self.n_ch = n_ch_in 12 | self.bias = bias 13 | 14 | super().__init__() 15 | 16 | layer = [] 17 | layer.append(nn.Conv2d(self.n_ch, self.n_ch, ksize, 18 | padding=(ksize // 2), bias=self.bias, 19 | padding_mode='reflect')) 20 | layer.append(nn.PReLU()) 21 | layer.append(nn.Conv2d(self.n_ch, self.n_ch, ksize, 22 | padding=(ksize // 2), bias=self.bias, 23 | padding_mode='reflect')) 24 | 25 | self.body = nn.Sequential(*layer) 26 | 27 | def forward(self, x): 28 | return x + self.body(x) 29 | 30 | 31 | class C2N_D(nn.Module): 32 | def __init__(self, n_ch_in): 33 | self.n_ch_unit = 64 34 | 35 | super().__init__() 36 | 37 | self.n_block = 6 38 | 39 | self.n_ch_in = n_ch_in 40 | self.head = nn.Sequential( 41 | nn.Conv2d(self.n_ch_in, self.n_ch_unit, 3, 42 | padding=1, bias=True, 43 | padding_mode='reflect'), 44 | nn.PReLU() 45 | ) 46 | 47 | layers = [ResBlock(self.n_ch_unit) for _ in range(self.n_block)] 48 | self.body = nn.Sequential(*layers) 49 | 50 | self.tail = nn.Conv2d(self.n_ch_unit, 1, 3, 51 | padding=1, bias=True, 52 | padding_mode='reflect') 53 | 54 | def forward(self, b_img_Gout): 55 | (N, C, H, W) = b_img_Gout.size() 56 | 57 | y = self.head(b_img_Gout) 58 | y = self.body(y) 59 | y = self.tail(y) 60 | 61 | return y 62 | 63 | 64 | class C2N_G(nn.Module): 65 | def __init__(self, n_ch_in=3, n_ch_out=3, n_r=32): 66 | self.n_ch_unit = 64 # number of base channel 67 | self.n_ext = 5 # number of residual blocks in feature extractor 68 | self.n_block_indep = 3 # number of residual blocks in independent module 69 | self.n_block_dep = 2 # number of residual blocks in dependent module 70 | 71 | self.n_ch_in = n_ch_in # number of input channels 72 | self.n_ch_out = n_ch_out # number of output channels 73 | self.n_r = n_r # length of r vector 74 | 75 | super().__init__() 76 | 77 | # feature extractor 78 | self.ext_head = nn.Sequential( 79 | nn.Conv2d(n_ch_in, self.n_ch_unit, 3, 80 | padding=1, bias=True, 81 | padding_mode='reflect'), 82 | nn.PReLU(), 83 | nn.Conv2d(self.n_ch_unit, self.n_ch_unit * 2, 3, 84 | padding=1, bias=True, 85 | padding_mode='reflect') 86 | ) 87 | self.ext_merge = nn.Sequential( 88 | nn.Conv2d((self.n_ch_unit * 2) + self.n_r, 2 * self.n_ch_unit, 3, 89 | padding=1, bias=True, 90 | padding_mode='reflect'), 91 | nn.PReLU() 92 | ) 93 | self.ext = nn.Sequential( 94 | *[ResBlock(2 * self.n_ch_unit) for _ in range(self.n_ext)] 95 | ) 96 | 97 | # pipe-indep 98 | self.indep_merge = nn.Conv2d(self.n_ch_unit, self.n_ch_unit, 1, 99 | padding=0, bias=True, 100 | padding_mode='reflect') 101 | self.pipe_indep_1 = 
nn.Sequential( 102 | *[ResBlock(self.n_ch_unit, ksize=1, bias=False) 103 | for _ in range(self.n_block_indep)] 104 | ) 105 | self.pipe_indep_3 = nn.Sequential( 106 | *[ResBlock(self.n_ch_unit, ksize=3, bias=False) 107 | for _ in range(self.n_block_indep)] 108 | ) 109 | 110 | # pipe-dep 111 | self.dep_merge = nn.Conv2d(self.n_ch_unit, self.n_ch_unit, 1, 112 | padding=0, bias=True, 113 | padding_mode='reflect') 114 | self.pipe_dep_1 = nn.Sequential( 115 | *[ResBlock(self.n_ch_unit, ksize=1, bias=False) 116 | for _ in range(self.n_block_dep)] 117 | ) 118 | self.pipe_dep_3 = nn.Sequential( 119 | *[ResBlock(self.n_ch_unit, ksize=3, bias=False) 120 | for _ in range(self.n_block_dep)] 121 | ) 122 | 123 | # T tail 124 | self.T_tail = nn.Conv2d(self.n_ch_unit, self.n_ch_out, 1, 125 | padding=0, bias=True, 126 | padding_mode='reflect') 127 | 128 | def forward(self, x, r_vector=None): 129 | (N, C, H, W) = x.size() 130 | 131 | # r map 132 | if r_vector is None: 133 | r_vector = torch.randn(N, self.n_r) 134 | r_map = r_vector.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, H, W) 135 | r_map = r_map.float().detach() 136 | r_map = r_map.to(x.device) 137 | 138 | # feat extractor 139 | feat_CL = self.ext_head(x) 140 | list_cat = [feat_CL, r_map] 141 | feat_CL = self.ext_merge(torch.cat(list_cat, 1)) 142 | feat_CL = self.ext(feat_CL) 143 | 144 | # make initial dep noise feature 145 | normal_scale = F.relu(feat_CL[:, self.n_ch_unit:, :, :]) + d 146 | get_feat_dep = torch_distb.Normal(loc=feat_CL[:, :self.n_ch_unit, :, :], 147 | scale=normal_scale) 148 | feat_noise_dep = get_feat_dep.rsample().to(x.device) 149 | 150 | # make initial indep noise feature 151 | feat_noise_indep = torch.rand_like(feat_noise_dep, requires_grad=True) 152 | feat_noise_indep = feat_noise_indep.to(x.device) 153 | 154 | # ===== 155 | 156 | # pipe-indep 157 | list_cat = [feat_noise_indep] 158 | feat_noise_indep = self.indep_merge(torch.cat(list_cat, 1)) 159 | feat_noise_indep = self.pipe_indep_1(feat_noise_indep) + \ 160 | self.pipe_indep_3(feat_noise_indep) 161 | 162 | # pipe-dep 163 | list_cat = [feat_noise_dep] 164 | feat_noise_dep = self.dep_merge(torch.cat(list_cat, 1)) 165 | feat_noise_dep = self.pipe_dep_1(feat_noise_dep) + \ 166 | self.pipe_dep_3(feat_noise_dep) 167 | 168 | feat_noise = feat_noise_indep + feat_noise_dep 169 | noise = self.T_tail(feat_noise) 170 | 171 | return x + noise 172 | 173 | 174 | if __name__ == '__main__': 175 | x = torch.randn(2, 3, 64, 64) 176 | 177 | c2nd = C2N_D(3) 178 | print(c2nd(x).shape) 179 | c2ng = C2N_G(3, 3, 32) 180 | print(c2ng(x).shape) 181 | -------------------------------------------------------------------------------- /C2N/model/DIDN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class _Residual_Block(nn.Module): 6 | def __init__(self): 7 | super(_Residual_Block, self).__init__() 8 | 9 | # res1 10 | self.conv1 = nn.Conv2d(in_channels=256, out_channels=256, 11 | kernel_size=3, stride=1, padding=1, bias=False) 12 | self.relu2 = nn.PReLU() 13 | self.conv3 = nn.Conv2d(in_channels=256, out_channels=256, 14 | kernel_size=3, stride=1, padding=1, bias=False) 15 | self.relu4 = nn.PReLU() 16 | # res1 17 | # concat1 18 | 19 | self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, 20 | kernel_size=3, stride=2, padding=1, bias=False) 21 | self.relu6 = nn.PReLU() 22 | 23 | # res2 24 | self.conv7 = nn.Conv2d(in_channels=512, out_channels=512, 25 | kernel_size=3, stride=1, padding=1, bias=False) 26 | self.relu8 = 
nn.PReLU() 27 | # res2 28 | # concat2 29 | 30 | self.conv9 = nn.Conv2d(in_channels=512, out_channels=1024, 31 | kernel_size=3, stride=2, padding=1, bias=False) 32 | self.relu10 = nn.PReLU() 33 | 34 | # res3 35 | self.conv11 = nn.Conv2d(in_channels=1024, out_channels=1024, 36 | kernel_size=3, stride=1, padding=1, bias=False) 37 | self.relu12 = nn.PReLU() 38 | # res3 39 | 40 | self.conv13 = nn.Conv2d(in_channels=1024, out_channels=2048, 41 | kernel_size=1, stride=1, padding=0, bias=False) 42 | self.up14 = nn.PixelShuffle(2) 43 | 44 | # concat2 45 | self.conv15 = nn.Conv2d(in_channels=1024, out_channels=512, 46 | kernel_size=1, stride=1, padding=0, bias=False) 47 | # res4 48 | self.conv16 = nn.Conv2d(in_channels=512, out_channels=512, 49 | kernel_size=3, stride=1, padding=1, bias=False) 50 | self.relu17 = nn.PReLU() 51 | # res4 52 | 53 | self.conv18 = nn.Conv2d(in_channels=512, out_channels=1024, 54 | kernel_size=1, stride=1, padding=0, bias=False) 55 | self.up19 = nn.PixelShuffle(2) 56 | 57 | # concat1 58 | self.conv20 = nn.Conv2d(in_channels=512, out_channels=256, 59 | kernel_size=1, stride=1, padding=0, bias=False) 60 | # res5 61 | self.conv21 = nn.Conv2d(in_channels=256, out_channels=256, 62 | kernel_size=3, stride=1, padding=1, bias=False) 63 | self.relu22 = nn.PReLU() 64 | self.conv23 = nn.Conv2d(in_channels=256, out_channels=256, 65 | kernel_size=3, stride=1, padding=1, bias=False) 66 | self.relu24 = nn.PReLU() 67 | # res5 68 | 69 | self.conv25 = nn.Conv2d(in_channels=256, out_channels=256, 70 | kernel_size=3, stride=1, padding=1, bias=False) 71 | 72 | def forward(self, x): 73 | res1 = x 74 | out = self.relu4(self.conv3(self.relu2(self.conv1(x)))) 75 | out = torch.add(res1, out) 76 | cat1 = out 77 | 78 | out = self.relu6(self.conv5(out)) 79 | res2 = out 80 | out = self.relu8(self.conv7(out)) 81 | out = torch.add(res2, out) 82 | cat2 = out 83 | 84 | out = self.relu10(self.conv9(out)) 85 | res3 = out 86 | 87 | out = self.relu12(self.conv11(out)) 88 | out = torch.add(res3, out) 89 | 90 | out = self.up14(self.conv13(out)) 91 | 92 | out = torch.cat([out, cat2], 1) 93 | out = self.conv15(out) 94 | res4 = out 95 | out = self.relu17(self.conv16(out)) 96 | out = torch.add(res4, out) 97 | 98 | out = self.up19(self.conv18(out)) 99 | 100 | out = torch.cat([out, cat1], 1) 101 | out = self.conv20(out) 102 | res5 = out 103 | out = self.relu24(self.conv23(self.relu22(self.conv21(out)))) 104 | out = torch.add(res5, out) 105 | 106 | out = self.conv25(out) 107 | out = torch.add(out, res1) 108 | 109 | return out 110 | 111 | 112 | class Recon_Block(nn.Module): 113 | def __init__(self): 114 | super(Recon_Block, self).__init__() 115 | 116 | self.conv1 = nn.Conv2d(in_channels=256, out_channels=256, 117 | kernel_size=3, stride=1, padding=1, bias=False) 118 | self.relu2 = nn.PReLU() 119 | self.conv3 = nn.Conv2d(in_channels=256, out_channels=256, 120 | kernel_size=3, stride=1, padding=1, bias=False) 121 | self.relu4 = nn.PReLU() 122 | 123 | self.conv5 = nn.Conv2d(in_channels=256, out_channels=256, 124 | kernel_size=3, stride=1, padding=1, bias=False) 125 | self.relu6 = nn.PReLU() 126 | self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, 127 | kernel_size=3, stride=1, padding=1, bias=False) 128 | self.relu8 = nn.PReLU() 129 | 130 | self.conv9 = nn.Conv2d(in_channels=256, out_channels=256, 131 | kernel_size=3, stride=1, padding=1, bias=False) 132 | self.relu10 = nn.PReLU() 133 | self.conv11 = nn.Conv2d(in_channels=256, out_channels=256, 134 | kernel_size=3, stride=1, padding=1, bias=False) 135 | self.relu12 
= nn.PReLU() 136 | 137 | self.conv13 = nn.Conv2d(in_channels=256, out_channels=256, 138 | kernel_size=3, stride=1, padding=1, bias=False) 139 | self.relu14 = nn.PReLU() 140 | self.conv15 = nn.Conv2d(in_channels=256, out_channels=256, 141 | kernel_size=3, stride=1, padding=1, bias=False) 142 | self.relu16 = nn.PReLU() 143 | 144 | self.conv17 = nn.Conv2d(in_channels=256, out_channels=256, 145 | kernel_size=3, stride=1, padding=1, bias=False) 146 | 147 | def forward(self, x): 148 | res1 = x 149 | output = self.relu4(self.conv3(self.relu2(self.conv1(x)))) 150 | output = torch.add(output, res1) 151 | 152 | res2 = output 153 | output = self.relu8(self.conv7(self.relu6(self.conv5(output)))) 154 | output = torch.add(output, res2) 155 | 156 | res3 = output 157 | output = self.relu12(self.conv11(self.relu10(self.conv9(output)))) 158 | output = torch.add(output, res3) 159 | 160 | res4 = output 161 | output = self.relu16(self.conv15(self.relu14(self.conv13(output)))) 162 | output = torch.add(output, res4) 163 | 164 | output = self.conv17(output) 165 | output = torch.add(output, res1) 166 | 167 | return output 168 | 169 | 170 | class DIDN(nn.Module): 171 | def __init__(self, n_ch_in, n_ch_out, n_DUB): 172 | self.n_DUB = n_DUB 173 | 174 | super(DIDN, self).__init__() 175 | 176 | self.conv_input = nn.Conv2d(in_channels=n_ch_in, out_channels=256, 177 | kernel_size=3, stride=1, padding=1, bias=False) 178 | self.relu1 = nn.PReLU() 179 | self.conv_down = nn.Conv2d(in_channels=256, out_channels=256, 180 | kernel_size=3, stride=2, padding=1, bias=False) 181 | self.relu2 = nn.PReLU() 182 | 183 | self.list_DUB = nn.ModuleList([_Residual_Block() for cnt in range(n_DUB)]) 184 | 185 | self.recon = Recon_Block() 186 | # concat 187 | 188 | self.conv_mid = nn.Conv2d(in_channels=256 * len(self.list_DUB), 189 | out_channels=256, 190 | kernel_size=1, stride=1, padding=0, bias=False) 191 | self.relu3 = nn.PReLU() 192 | self.conv_mid2 = nn.Conv2d(in_channels=256, out_channels=256, 193 | kernel_size=3, stride=1, padding=1, bias=False) 194 | self.relu4 = nn.PReLU() 195 | 196 | self.subpixel = nn.PixelShuffle(2) 197 | self.conv_output = nn.Conv2d(in_channels=64, out_channels=n_ch_out, 198 | kernel_size=3, stride=1, padding=1, bias=False) 199 | 200 | def forward(self, x): 201 | residual = x 202 | out = self.relu1(self.conv_input(x)) 203 | out = self.relu2(self.conv_down(out)) 204 | 205 | list_out = [] 206 | for itr in range(self.n_DUB): 207 | out = self.list_DUB[itr](out) 208 | list_out.append(out) 209 | 210 | list_recon = [] 211 | for itr in range(self.n_DUB): 212 | recon = self.recon(list_out[itr]) 213 | list_recon.append(recon) 214 | 215 | out = torch.cat(list_recon, 1) 216 | 217 | out = self.relu3(self.conv_mid(out)) 218 | residual2 = out 219 | out = self.relu4(self.conv_mid2(out)) 220 | out = torch.add(out, residual2) 221 | 222 | out = self.subpixel(out) 223 | out = self.conv_output(out) 224 | out = torch.add(out, residual) 225 | 226 | return out 227 | 228 | 229 | class DIDN_6(DIDN): 230 | def __init__(self, n_ch_in=3, n_ch_out=3): 231 | super(DIDN_6, self).__init__(n_ch_in, n_ch_out, n_DUB=6) 232 | 233 | 234 | class DIDN_8(DIDN): 235 | def __init__(self, n_ch_in=3, n_ch_out=3): 236 | super(DIDN_8, self).__init__(n_ch_in, n_ch_out, n_DUB=8) 237 | 238 | 239 | if __name__ == '__main__': 240 | x = torch.randn(2, 3, 64, 64) 241 | model = DIDN_8(n_ch_in=3, n_ch_out=3) 242 | print(model(x).shape) 243 | -------------------------------------------------------------------------------- /C2N/model/DnCNN.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class DnCNN(nn.Module): 6 | def __init__(self, n_ch_in, n_ch_out, n_block): 7 | self.n_block = n_block # = 17 for non-blind, 20 for blind 8 | self.n_ch_unit = 64 9 | self.conv_ksize = 3 10 | self.batch_norm = True 11 | 12 | super().__init__() 13 | 14 | layers = [] 15 | layers.append(nn.Conv2d(n_ch_in, self.n_ch_unit, self.conv_ksize, 16 | padding=(self.conv_ksize // 2), groups=1, bias=True)) 17 | layers.append(nn.ReLU(inplace=True)) 18 | 19 | for _ in range(self.n_block - 2): 20 | layers.append(nn.Conv2d(self.n_ch_unit, self.n_ch_unit, self.conv_ksize, 21 | padding=(self.conv_ksize // 2), 22 | groups=1, bias=False)) 23 | if self.batch_norm: 24 | layers.append(nn.BatchNorm2d(self.n_ch_unit)) 25 | layers.append(nn.ReLU(inplace=True)) 26 | 27 | layers.append(nn.Conv2d(self.n_ch_unit, n_ch_out, self.conv_ksize, 28 | padding=(self.conv_ksize // 2), groups=1, bias=True)) 29 | 30 | self.body = nn.Sequential(*layers) 31 | 32 | self._initialize_weights() 33 | 34 | def forward(self, x): 35 | return x - self.body(x) 36 | 37 | def _initialize_weights(self): 38 | # Liyong version 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv2d): 41 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5) 43 | if isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.normal_(0, (2 / (9.0 * 64)) ** 0.5) 45 | clip_b = 0.025 46 | w = m.weight.data.shape[0] 47 | for j in range(w): 48 | if m.weight.data[j] >= 0 and m.weight.data[j] < clip_b: 49 | m.weight.data[j] = clip_b 50 | elif m.weight.data[j] > -clip_b and m.weight.data[j] < 0: 51 | m.weight.data[j] = -clip_b 52 | m.running_var.fill_(0.01) 53 | 54 | 55 | class DnCNN_S(DnCNN): 56 | def __init__(self, n_ch_in=1, n_ch_out=1): 57 | super(DnCNN_S, self).__init__(n_ch_in, n_ch_out, n_block=17) 58 | 59 | 60 | class DnCNN_B(DnCNN): 61 | def __init__(self, n_ch_in=1, n_ch_out=1): 62 | super(DnCNN_B, self).__init__(n_ch_in, n_ch_out, n_block=20) 63 | 64 | 65 | class CDnCNN_S(DnCNN): 66 | def __init__(self, n_ch_in=3, n_ch_out=3): 67 | super(CDnCNN_S, self).__init__(n_ch_in, n_ch_out, n_block=17) 68 | 69 | 70 | class CDnCNN_B(DnCNN): 71 | def __init__(self, n_ch_in=3, n_ch_out=3): 72 | super(CDnCNN_B, self).__init__(n_ch_in, n_ch_out, n_block=20) 73 | 74 | 75 | if __name__ == '__main__': 76 | x = torch.randn(2, 3, 64, 64) 77 | 78 | cdncnnb = CDnCNN_B(3, 3) 79 | print(cdncnnb(x).shape) 80 | -------------------------------------------------------------------------------- /C2N/model/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | dict_modulename = { 4 | 'DnCNN_S': 'DnCNN', 5 | 'DnCNN_B': 'DnCNN', 6 | 'CDnCNN_S': 'DnCNN', 7 | 'CDnCNN_B': 'DnCNN', 8 | # DIDN 9 | 'DIDN_6': 'DIDN', 10 | 'DIDN_8': 'DIDN', 11 | # C2N 12 | 'C2N_D': 'C2N', 13 | 'C2N_G': 'C2N', 14 | } 15 | 16 | 17 | def get_model(name_model): 18 | if name_model is None: 19 | return None 20 | else: 21 | module_model = import_module('C2N.model.{}'.format(dict_modulename[name_model])) 22 | model_class = getattr(module_model, name_model) 23 | return model_class() 24 | -------------------------------------------------------------------------------- /C2N/util/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import yaml 4 | 5 | 6 | class ConfigParser: 7 | def __init__(self, 
args): 8 | # load model configuration
9 | path_config = os.path.join('C2N', 'config')
10 | fname_config = f'{os.path.splitext(os.path.basename(args.config))[0]}.yml'
11 | with open(os.path.join(path_config, fname_config)) as f:
12 | self.config = yaml.load(f, Loader=yaml.FullLoader)
13 |
14 | # load command-line arguments into the config
15 | for arg in args.__dict__:
16 | self.config[arg] = args.__dict__[arg]
17 |
18 | # convert string 'None' values to None
19 | self.convert_None(self.config)
20 |
21 | def __getitem__(self, name):
22 | return self.config[name]
23 |
24 | def convert_None(self, d):
25 | for key in d:
26 | if d[key] == 'None':
27 | d[key] = None
28 | if isinstance(d[key], dict):
29 | self.convert_None(d[key])
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 onwn
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C2N: Practical Generative Noise Modeling for Real-World Denoising - Official PyTorch release
2 |
3 | This is an official PyTorch release of the paper
4 | [**"C2N: Practical Generative Noise Modeling for Real-World Denoising"**](https://openaccess.thecvf.com/content/ICCV2021/papers/Jang_C2N_Practical_Generative_Noise_Modeling_for_Real-World_Denoising_ICCV_2021_paper.pdf)
5 | from **ICCV 2021**.
6 |
7 | ![architecture](./imgs/architecture.png)
8 |
9 | If you find C2N useful in your research, please cite our work as follows:
10 |
11 | ```
12 | @InProceedings{Jang_2021_ICCV,
13 | author = {Jang, Geonwoon and Lee, Wooseok and Son, Sanghyun and Lee, Kyoung Mu},
14 | title = {C2N: Practical Generative Noise Modeling for Real-World Denoising},
15 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
16 | month = {October},
17 | year = {2021},
18 | pages = {2350-2359}
19 | }
20 | ```
21 |
22 | [[PDF](https://openaccess.thecvf.com/content/ICCV2021/papers/Jang_C2N_Practical_Generative_Noise_Modeling_for_Real-World_Denoising_ICCV_2021_paper.pdf)]
23 | [[Supp](https://openaccess.thecvf.com/content/ICCV2021/supplemental/Jang_C2N_Practical_Generative_ICCV_2021_supplemental.pdf)]
24 | [[arXiv](https://arxiv.org/abs/2202.09533)]
25 |
26 | ---
27 |
28 | ## Setup
29 |
30 | ### Dependencies
31 |
32 | - Python 3.9.6
33 | - numpy >= 1.16.4
34 | - cudatoolkit >= 10. (if using GPU)
35 | - PyTorch 1.2.0
36 | - opencv-python
37 | - scikit-image >= 0.15.0
38 | - tqdm
39 | - pillow
40 | - pyyaml
41 | - imutils
42 |
43 |
50 |
51 | ### Data
52 |
53 | You can place custom images in `./data` and image datasets in subdirectories `./data/[name_of_dataset]`.
54 |
55 | The SIDD and DND benchmark images can be found at [SIDD Benchmark](https://www.eecs.yorku.ca/~kamel/sidd/benchmark.php) and [DND Benchmark]().
56 | Convert them into .png images and place them in each subdirectory.
57 |
58 | ### Pre-trained Models
59 |
60 | Download the following pre-trained models:
61 |
62 | | Generator | Clean | Noisy | config | Pre-trained |
63 | | :-------: | :---: | :---: | :-------: | :---------: |
64 | | C2N | SIDD | SIDD | C2N_DnCNN | [model](https://drive.google.com/file/d/1Cn0KptLHd8p6v4_72PMvjssZbzTmgN4Z/view?usp=sharing) |
65 | | C2N | SIDD | DND | C2N_DnCNN | [model](https://drive.google.com/file/d/1Ce2Z9Gz7YssiIFIgGmj86xwDDqjgl2-S/view?usp=sharing) |
66 |
67 | | Denoiser | Generator | Clean | Noisy | Clean (denoiser train) | config | Pre-trained |
68 | | :------: | :-------: | :---: | :---: | :--------------: | :-------: | :---------: |
69 | | DnCNN | C2N | SIDD | SIDD | SIDD | C2N_DnCNN | [model](https://drive.google.com/file/d/1wxuhXwhHYVLiAuUvwqIBiOX8NcFzdWmN/view?usp=sharing) |
70 | | DIDN | C2N | SIDD | SIDD | SIDD | C2N_DIDN | [model](https://drive.google.com/file/d/12Q5zZp3l_sH4pofXJraZbiEtAQ-kCtWD/view?usp=sharing) |
71 | | DIDN | C2N | SIDD | DND | SIDD | C2N_DIDN | [model](https://drive.google.com/file/d/1gZQ3mfhLlnN0FZD3lxZGiApY8nVJBPT-/view?usp=sharing) |
72 |
73 | ---
74 |
75 | ## Demo (Quick start)
76 |
77 | ### test_generate.py:
78 |
79 | - `config`: Name of the configuration.
80 | - `ckpt`: Name of the checkpoint to load. Choose between 'C2N-SIDD_to_SIDD' and 'C2N-DND_to_SIDD' depending on the noisy image set it was trained on.
81 | - `mode`: 'single' or 'dataset'.
82 | - `data`: Filename of a clean image if `mode` is 'single', dataset of clean images if `mode` is 'dataset' (see the folder layout note below).
83 | - `gpu`: GPU id. Currently this demo only supports a single GPU or the CPU.
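For `--mode dataset`, `--data` should name a folder under `./data` that holds the images to process; only `.png`, `.jpg`, and `.jpeg` files are picked up. For example, the bundled SIDD clean samples are laid out as:

```
data/
└── SIDD_clean_examples/
    ├── SIDD1.png
    ├── SIDD2.png
    └── SIDD3.png
```

Outputs are written to `./results/` (for 'dataset' mode, into a subfolder named after the input folder).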
84 |
85 | Examples:
86 |
87 | ```bash
88 | # Generate noise on a single clean image
89 | python test_generate.py --ckpt C2N-SIDD_to_SIDD.ckpt --mode single --data clean_ex1.png --gpu 0
90 | python test_generate.py --ckpt C2N-DND_to_SIDD.ckpt --mode single --data clean_ex2.png --gpu 0
91 |
92 | # Generate noise on clean images in a dataset
93 | python test_generate.py --ckpt C2N-SIDD_to_SIDD.ckpt --mode dataset --data SIDD_clean_examples --gpu 0
94 | python test_generate.py --ckpt C2N-DND_to_SIDD.ckpt --mode dataset --data SIDD_clean_examples --gpu 0
95 | ```
96 |
97 | ### test_denoise.py:
98 |
99 | - `config`: Name of the configuration. Choose between 'C2N_DnCNN' and 'C2N_DIDN' depending on the denoiser to be used.
100 | - `ckpt`: Name of the checkpoint to load.
101 | - Name format: '[denoiser model]-[C2N train clean set]_to_[C2N train noisy set]-on\_[denoiser train set]'
102 | - denoiser model: 'DnCNN' or 'DIDN'.
103 | - C2N train clean set: Dataset that the clean images for C2N training are sampled from.
104 | - C2N train noisy set: Dataset that the noisy images for C2N training are sampled from.
105 | - denoiser train set: Dataset of clean images from which the clean->noisy pairs are generated to train the denoiser.
106 | - `mode`: 'single' or 'dataset'.
107 | - `data`: Filename of a noisy/generated image if `mode` is 'single', dataset of noisy/generated images if `mode` is 'dataset'.
108 | - `gpu`: GPU id. Currently this demo only supports a single GPU or the CPU.
109 |
110 | Examples:
111 |
112 | ```bash
113 | # Denoise a single noisy image
114 | python test_denoise.py --config C2N_DnCNN --ckpt DnCNN-SIDD_to_SIDD-on_SIDD --mode single --data noisy_ex1_SIDD.png --gpu 0
115 | python test_denoise.py --config C2N_DIDN --ckpt DIDN-SIDD_to_SIDD-on_SIDD --mode single --data noisy_ex1_SIDD.png --gpu 0
116 | python test_denoise.py --config C2N_DIDN --ckpt DIDN-SIDD_to_DND-on_SIDD --mode single --data noisy_ex2_DND.png --gpu 0
117 |
118 | # Denoise noisy images in a dataset
119 | python test_denoise.py --config C2N_DnCNN --ckpt DnCNN-SIDD_to_SIDD-on_SIDD --mode dataset --data SIDD_benchmark --gpu 0
120 | python test_denoise.py --config C2N_DIDN --ckpt DIDN-SIDD_to_SIDD-on_SIDD --mode dataset --data SIDD_benchmark --gpu 0
121 | python test_denoise.py --config C2N_DIDN --ckpt DIDN-SIDD_to_DND-on_SIDD --mode dataset --data DND_benchmark --gpu 0
122 |
123 | # Denoise the generated images from C2N
124 | # First copy the generated images from `results/` into `data/`, for example:
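# (outputs of the single-image generation examples above)
cp results/clean_ex1_generated.png data/clean_ex1_generated.png
cp results/clean_ex2_generated.png data/clean_ex2_generated.png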
125 | python test_denoise.py --config C2N_DIDN --ckpt DIDN-SIDD_to_SIDD-on_SIDD --mode single --data clean_ex1_generated.png --gpu 0 126 | python test_denoise.py --config C2N_DIDN --ckpt DIDN-SIDD_to_DND-on_SIDD --mode single --data clean_ex2_generated.png --gpu 0 127 | ``` 128 | -------------------------------------------------------------------------------- /ckpt/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/ckpt/.gitkeep -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/data/.gitkeep -------------------------------------------------------------------------------- /data/SIDD_clean_examples/SIDD1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/data/SIDD_clean_examples/SIDD1.png -------------------------------------------------------------------------------- /data/SIDD_clean_examples/SIDD2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/data/SIDD_clean_examples/SIDD2.png -------------------------------------------------------------------------------- /data/SIDD_clean_examples/SIDD3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/data/SIDD_clean_examples/SIDD3.png -------------------------------------------------------------------------------- /data/clean_ex1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/data/clean_ex1.png -------------------------------------------------------------------------------- /data/clean_ex2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/data/clean_ex2.png -------------------------------------------------------------------------------- /data/noisy_ex1_SIDD.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/data/noisy_ex1_SIDD.png -------------------------------------------------------------------------------- /data/noisy_ex2_DND.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/data/noisy_ex2_DND.png -------------------------------------------------------------------------------- /imgs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/onwn/C2N/b3d13405f3eebf1a3ea8b769ec7bdf1c917d5636/imgs/architecture.png -------------------------------------------------------------------------------- /test_denoise.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from C2N.model import get_model 
9 | from C2N.util.config import ConfigParser 10 | 11 | 12 | def main(): 13 | # parsing configuration 14 | args = argparse.ArgumentParser() 15 | args.add_argument('--config', default=None, type=str) 16 | args.add_argument('--ckpt', default=None, type=str) 17 | args.add_argument('--mode', default='single', type=str) 18 | args.add_argument('--data', default=None, type=str) 19 | args.add_argument('--gpu', default=None, type=str) 20 | 21 | args = args.parse_args() 22 | 23 | assert args.config is not None, 'config file path is needed.' 24 | assert args.ckpt is not None, 'checkpoint epoch is needed.' 25 | assert args.data is not None, 'data path or filename is needed.' 26 | assert args.mode in ['single', 'dataset'], 'mode must be single or dataset.' 27 | 28 | # device setting 29 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) if args.gpu is not None else '' 30 | 31 | configs = ConfigParser(args) 32 | 33 | denoise(configs) 34 | 35 | 36 | @ torch.no_grad() 37 | def denoise(configs): 38 | # model load 39 | denoiser = get_model(configs['model']['Denoiser']) 40 | if configs['gpu'] is not None: 41 | denoiser = denoiser.cuda() 42 | fpath_ckpt = os.path.join( 43 | 'ckpt', f'{os.path.splitext(os.path.basename(configs["ckpt"]))[0]}.ckpt' 44 | ) 45 | ckpt = torch.load( 46 | fpath_ckpt, 47 | map_location=torch.device('cpu') 48 | if os.environ['CUDA_VISIBLE_DEVICES'] == '' else None 49 | ) 50 | denoiser.load_state_dict(ckpt) 51 | denoiser.eval() 52 | print('model loaded!') 53 | 54 | # make results folder 55 | os.makedirs('./results', exist_ok=True) 56 | 57 | # denoise 58 | if configs['mode'] == 'single': 59 | denoised = denoise_single_img( 60 | configs, denoiser, os.path.join('data', configs['data']) 61 | ) 62 | fname_data = os.path.basename(configs['data']) 63 | tag_data = os.path.splitext(fname_data)[0] 64 | fpath_output = f'./results/{tag_data}_denoised.png' 65 | cv2.imwrite(fpath_output, denoised) 66 | print('denoised to %s' % (fpath_output)) 67 | elif configs['mode'] == 'dataset': 68 | for (dirpath, _, filenames) in os.walk(os.path.join('data', configs['data'])): 69 | folder_name = os.path.basename(os.path.normpath(dirpath)) 70 | os.makedirs('./results/%s' % folder_name, exist_ok=True) 71 | 72 | for filename in filenames: 73 | if os.path.splitext(filename)[1] not in ['.png', '.jpg', '.jpeg']: 74 | continue 75 | denoised = denoise_single_img(configs, denoiser, 76 | os.path.join(dirpath, filename)) 77 | tag_data = os.path.splitext(filename)[0] 78 | fpath_output = f'./results/{folder_name}/{tag_data}_denoised.png' 79 | os.makedirs(f'./results/{folder_name}', exist_ok=True) 80 | cv2.imwrite(fpath_output, denoised) 81 | print('denoised to %s' % (fpath_output)) 82 | 83 | 84 | def denoise_single_img(configs, denoiser, img_path): 85 | img = cv2.imread(img_path, -1).astype(float) 86 | img = img / 255.0 87 | img = img.transpose(2, 0, 1) 88 | img = torch.from_numpy(img).float().unsqueeze(0) 89 | if configs['gpu'] is not None: 90 | img = img.cuda() 91 | 92 | denoised = denoiser(img) 93 | denoised = denoised.cpu().detach().squeeze(0).numpy() 94 | denoised = denoised.transpose(1, 2, 0) 95 | denoised = denoised * 255.0 96 | denoised += 0.5 97 | denoised = np.clip(denoised, 0., 255.) 
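# note: the `+ 0.5` added above makes the uint8 cast below round to the nearest integer instead of truncating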
98 | denoised = denoised.astype(np.uint8) 99 | 100 | return denoised 101 | 102 | 103 | if __name__ == '__main__': 104 | main() 105 | -------------------------------------------------------------------------------- /test_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | 8 | from C2N.model import get_model 9 | from C2N.util.config import ConfigParser 10 | 11 | 12 | def main(): 13 | # parsing configuration 14 | args = argparse.ArgumentParser() 15 | args.add_argument('--config', default='C2N_DnCNN', type=str) 16 | args.add_argument('--ckpt', default=None, type=str) 17 | args.add_argument('--mode', default='single', type=str) 18 | args.add_argument('--data', default=None, type=str) 19 | args.add_argument('--gpu', default=None, type=int) 20 | 21 | args = args.parse_args() 22 | 23 | assert args.config is not None, 'config file path is needed' 24 | assert args.ckpt is not None, 'checkpoint epoch is needed' 25 | assert args.data is not None, 'data path or filename is needed' 26 | assert args.mode in ['single', 'dataset'], 'mode must be single or dataset' 27 | 28 | # device setting 29 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) if args.gpu is not None else '' 30 | 31 | configs = ConfigParser(args) 32 | 33 | generate(configs) 34 | 35 | 36 | @ torch.no_grad() 37 | def generate(configs): 38 | # model load 39 | generator = get_model(configs['model']['Generator']) 40 | if configs['gpu'] is not None: 41 | generator = generator.cuda() 42 | fpath_ckpt = os.path.join( 43 | 'ckpt', f'{os.path.splitext(os.path.basename(configs["ckpt"]))[0]}.ckpt' 44 | ) 45 | ckpt = torch.load( 46 | fpath_ckpt, 47 | map_location=torch.device('cpu') 48 | if os.environ['CUDA_VISIBLE_DEVICES'] == '' else None 49 | ) 50 | generator.load_state_dict(ckpt) 51 | generator.eval() 52 | print('model loaded!') 53 | 54 | # make results folder 55 | os.makedirs('./results', exist_ok=True) 56 | 57 | # denoise 58 | if configs['mode'] == 'single': 59 | generated = generate_single_img( 60 | configs, generator, os.path.join('data', configs['data']) 61 | ) 62 | fname_data = os.path.basename(configs['data']) 63 | tag_data = os.path.splitext(fname_data)[0] 64 | fpath_output = f'./results/{tag_data}_generated.png' 65 | cv2.imwrite(fpath_output, generated) 66 | print('generated to %s' % (fpath_output)) 67 | elif configs['mode'] == 'dataset': 68 | for (dirpath, _, filenames) in os.walk(os.path.join('data', configs['data'])): 69 | folder_name = dirpath.split('/')[-1] 70 | os.makedirs(f'./results/{folder_name}', exist_ok=True) 71 | 72 | for filename in filenames: 73 | if os.path.splitext(filename)[1] not in ['.png', '.jpg', '.jpeg']: 74 | continue 75 | generated = generate_single_img(configs, generator, 76 | os.path.join(dirpath, filename)) 77 | tag_data = os.path.splitext(filename)[0] 78 | fpath_output = f'./results/{folder_name}/{tag_data}_generated.png' 79 | cv2.imwrite(fpath_output, generated) 80 | print('generated to %s' % (fpath_output)) 81 | 82 | 83 | def generate_single_img(configs, generator, img_path): 84 | img = cv2.imread(img_path, -1).astype(float) 85 | img = img / 255.0 86 | img = img.transpose(2, 0, 1) 87 | img = torch.from_numpy(img).float().unsqueeze(0) 88 | if configs['gpu'] is not None: 89 | img = img.cuda() 90 | 91 | generated = generator(img) 92 | generated = generated.cpu().detach().squeeze(0).numpy() 93 | generated = generated.transpose(1, 2, 0) 94 | generated = generated * 255.0 95 | 
generated += 0.5 96 | generated = np.clip(generated, 0., 255.) 97 | generated = generated.astype(np.uint8) 98 | 99 | return generated 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | --------------------------------------------------------------------------------
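The two demo scripts above generate noisy images and denoise them as separate, file-based steps. As a supplementary illustration (not part of the original scripts), the sketch below chains the C2N generator and a pre-trained denoiser in memory; it assumes the pre-trained checkpoints listed in the README have been downloaded into `./ckpt/` under the names used in the demo commands.

```python
import torch

from C2N.model import get_model

# build the models the same way the demo scripts do (see C2N/model/__init__.py)
generator = get_model('C2N_G')   # C2N generator: 3-channel in/out, r-vector length 32
denoiser = get_model('DIDN_8')   # DIDN denoiser with 8 down-up blocks

# the released .ckpt files are loaded as plain state dicts in test_generate.py / test_denoise.py
generator.load_state_dict(torch.load('ckpt/C2N-SIDD_to_SIDD.ckpt', map_location='cpu'))
denoiser.load_state_dict(torch.load('ckpt/DIDN-SIDD_to_SIDD-on_SIDD.ckpt', map_location='cpu'))
generator.eval()
denoiser.eval()

with torch.no_grad():
    clean = torch.rand(1, 3, 64, 64)   # stand-in for a clean image scaled to [0, 1]
    noisy = generator(clean)           # clean image + generated noise
    denoised = denoiser(noisy)

print(noisy.shape, denoised.shape)     # both (1, 3, 64, 64)
```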