├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   └── README.md
├── model.py
├── synthesize.py
├── test.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
.*/
_*/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Zhang Yi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Rethinking Noise Synthesis and Modeling in Raw Denoising (ICCV 2021)
---
[Yi Zhang](https://zhangyi-3.github.io/)<sup>1</sup>,
[Hongwei Qin](https://scholar.google.com/citations?user=ZGM7HfgAAAAJ&hl=en)<sup>2</sup>,
[Xiaogang Wang](https://scholar.google.com/citations?user=-B5JgjsAAAAJ&hl=zh-CN)<sup>1</sup>,
[Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/)<sup>1</sup>

<sup>1</sup>CUHK-SenseTime Joint Lab, <sup>2</sup>SenseTime Research



### Abstract

> The lack of large-scale real raw image denoising datasets gives rise to challenges in synthesizing
realistic raw image noise for training denoising models. However, real raw image noise is
contributed by many noise sources and varies greatly among different sensors.
Existing methods are unable to model all noise sources accurately, and building a noise model
for each sensor is also laborious. In this paper, we introduce a new perspective: synthesizing
noise by directly sampling from the sensor's real noise. This inherently generates accurate raw image
noise for different camera sensors. Two efficient and generic techniques, pattern-aligned patch
sampling and high-bit reconstruction, enable accurate synthesis of spatially correlated noise and high-bit noise, respectively. We conduct systematic experiments on the SIDD and ELD datasets.
The results show that (1) our method outperforms existing methods and generalizes well
across different sensors and lighting conditions, and (2) recent conclusions derived from
DNN-based noise modeling methods are actually based on inaccurate noise parameters;
the DNN-based methods still cannot outperform physics-based statistical methods.

### Testing
The code has been tested with the following environment:
```
pytorch == 1.5.0
scikit-image == 0.16.2
scipy == 1.3.1
h5py == 2.10.0
```

- Prepare the [SIDD-Medium](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) dataset.
- Download the [pretrained models](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155135732_link_cuhk_edu_hk/Egb3x2YO-qBBgQ41N8WiCIUBRQuxb4gWsV_Ml1yLfDti9w?e=wRwC7e) and put them into the checkpoints folder.
- Modify the default root of the SIDD dataset and run the command with the specific camera name (s6 | ip | gp):
```
python -u test.py --root SIDD_Medium/Data --camera s6
```


### Noise parameters
The calibrated noise parameters for the SIDD and ELD datasets can be found in ``synthesize.py``.
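
A minimal sketch of how these parameters are consumed (the clean patch below is a
hypothetical stand-in for real raw data normalized to [0, 1]):

``` python
import torch

from synthesize import pg_noise_demo

clean = torch.rand(4, 48, 48)              # hypothetical clean RGGB patch in [0, 1]
noisy = pg_noise_demo(clean, camera='s6')  # draws a random ISO, adds shot + read noise
print(noisy.shape, noisy.mean())
```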


### Citation
``` bibtex
@InProceedings{zhang2021rethinking,
    author    = {Zhang, Yi and Qin, Hongwei and Wang, Xiaogang and Li, Hongsheng},
    title     = {Rethinking Noise Synthesis and Modeling in Raw Denoising},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {4593-4601}
}
```

### Contact
Feel free to contact zhangyi@link.cuhk.edu.hk if you have any questions.

### Acknowledgments
* [ELD](https://github.com/Vandermode/ELD)
* [CA-NoiseGAN](https://github.com/arcchang1236/CA-NoiseGAN)
* [simple-camera-pipeline](https://github.com/AbdoKamel/simple-camera-pipeline)
--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
Put the downloaded pretrained models here; ``test.py`` loads them as ``./checkpoints/<camera>.pth`` (e.g. ``s6.pth``).
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class UNetSeeInDark(nn.Module):
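    """U-Net denoiser in the style of SID ("Learning to See in the Dark"):
    takes a 4-channel packed-RGGB raw image and predicts its 4-channel
    denoised counterpart at the same resolution, using four down/up-sampling
    levels with skip connections and LeakyReLU(0.2) activations.
    """
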
    def __init__(self, in_channels=4, out_channels=4):
        super(UNetSeeInDark, self).__init__()

        # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

        self.conv10_1 = nn.Conv2d(32, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        # encoder: two 3x3 convs per scale, then 2x max-pooling
        conv1 = self.lrelu(self.conv1_1(x))
        conv1 = self.lrelu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        conv2 = self.lrelu(self.conv2_1(pool1))
        conv2 = self.lrelu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)

        conv3 = self.lrelu(self.conv3_1(pool2))
        conv3 = self.lrelu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.lrelu(self.conv4_1(pool3))
        conv4 = self.lrelu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.lrelu(self.conv5_1(pool4))
        conv5 = self.lrelu(self.conv5_2(conv5))

        # decoder: transposed-conv upsampling, concatenated with encoder skips
        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.lrelu(self.conv6_1(up6))
        conv6 = self.lrelu(self.conv6_2(conv6))

        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.lrelu(self.conv7_1(up7))
        conv7 = self.lrelu(self.conv7_2(conv7))

        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.lrelu(self.conv8_1(up8))
        conv8 = self.lrelu(self.conv8_2(conv8))

        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.lrelu(self.conv9_1(up9))
        conv9 = self.lrelu(self.conv9_2(conv9))

        conv10 = self.conv10_1(conv9)
        # out = nn.functional.pixel_shuffle(conv10, 2)
        out = conv10
        return out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.normal_(0.0, 0.02)
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.02)

    def lrelu(self, x):
        # LeakyReLU with negative slope 0.2: max(0.2 * x, x)
        outt = torch.max(0.2 * x, x)
        return outt
--------------------------------------------------------------------------------
/synthesize.py:
--------------------------------------------------------------------------------
import torch

import numpy as np


def generate_poisson_(y, k=1):
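    """Shot noise: scale the signal to photon counts by 1 / k, draw a Poisson
    sample, and scale back by the gain k."""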
    y = torch.poisson(y / k) * k
    return y


def generate_read_noise(shape, noise_type, scale, loc=0):
    noise_type = noise_type.lower()
    if noise_type == 'norm':
        read = torch.FloatTensor(shape).normal_(loc, scale)
    else:
        raise NotImplementedError('Read noise type error.')
    return read


def noise_profiles(camera):
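    """Calibrated per-ISO noise parameters for a camera: the ISO settings,
    the shot-noise scales (cshot), and the read-noise scales (cread)."""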
    camera = camera.lower()
    if camera == 'ip':  # iPhone
        iso_set = [100, 200, 400, 800, 1600, 2000]
        cshot = [0.00093595, 0.00104404, 0.00116461, 0.00129911, 0.00144915, 0.00150104]
        cread = [4.697713410870357e-07, 6.904488905478659e-07, 6.739473744228789e-07,
                 6.776787431555864e-07, 6.781983208034481e-07, 6.783184262356993e-07]
    elif camera == 's6':  # Samsung S6 Edge
        iso_set = [100, 200, 400, 800, 1600, 3200]
        cshot = [0.00162521, 0.00256175, 0.00403799, 0.00636492, 0.01003277, 0.01581424]
        cread = [1.1792188420255036e-06, 1.607602896683437e-06, 2.9872611575167216e-06,
                 5.19157563906707e-06, 1.0011034196248119e-05, 2.0652668477786836e-05]
    elif camera == 'gp':  # Google Pixel
        iso_set = [100, 200, 400, 800, 1600, 3200, 6400]
        cshot = [0.00024718, 0.00048489, 0.00095121, 0.001866, 0.00366055, 0.00718092, 0.01408686]
        cread = [1.6819349659429324e-06, 2.0556981890860545e-06, 2.703070976302046e-06,
                 4.116405515789963e-06, 7.569256436438246e-06, 1.5199001098203388e-05, 5.331422827048082e-05]
    elif camera == 'sony':  # Sony A7S II
        iso_set = [800, 1600, 3200]
        cshot = [1.0028880020069384, 1.804521362114003, 3.246920234173119]
        cread = [4.053034401667052, 6.692229120425673, 4.283115294604881]
    elif camera == 'nikon':  # Nikon D850
        iso_set = [800, 1600, 3200]
        cshot = [3.355988883536526, 6.688199969242411, 13.32901281288985]
        cread = [4.4959735547955635, 8.360429952584846, 15.684213053647735]
    else:
        raise NotImplementedError('Unknown camera: %s' % camera)
    return iso_set, cshot, cread


def pg_noise_demo(clean_tensor, camera='IP'):
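    """Poisson-Gaussian noise demo: draw a random calibrated ISO for the given
    camera, then apply shot noise and additive Gaussian read noise."""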
    iso_set, k_set, read_scale_set = noise_profiles(camera)

    # sample randomly
    i = np.random.choice(len(k_set))
    k, read_scale = k_set[i], read_scale_set[i]

    noisy_shot = generate_poisson_(clean_tensor, k)
    read_noise = generate_read_noise(clean_tensor.shape, noise_type='norm', scale=read_scale)
    noisy = noisy_shot + read_noise
    return noisy


if __name__ == '__main__':
    clean = torch.randn(48, 48).clamp(0, 1)
    noisy = pg_noise_demo(clean, camera='ip')
    print(noisy.shape, noisy.mean())
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch

import numpy as np
import torch.nn.functional as F
import scipy.io as sio

from skimage.metrics import peak_signal_noise_ratio, structural_similarity

import utils
from model import UNetSeeInDark


def forward_patches(model, noisy, patch_size=256 * 3, pad=32):
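    """Denoise a full raw image with overlapping sliding-window patches.

    The raw is packed to 4-channel RGGB, reflect-padded by `pad` on each side,
    and processed in `patch_size` windows advancing by `patch_size - 2 * pad`;
    only the central region of each output patch is kept, discarding the
    borders.
    """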
    shift = patch_size - pad * 2

    noisy = torch.FloatTensor(noisy).cuda()
    noisy = utils.raw2stack(noisy).unsqueeze(0)
    noisy = F.pad(noisy, (pad, pad, pad, pad), mode='reflect')
    denoised = torch.zeros_like(noisy)

    _, _, H, W = noisy.shape
    for i in np.arange(0, H, shift):
        for j in np.arange(0, W, shift):
            h_end, w_end = min(i + patch_size, H), min(j + patch_size, W)
            h_start, w_start = h_end - patch_size, w_end - patch_size

            input_var = noisy[..., h_start: h_end, w_start: w_end]
            with torch.no_grad():
                out_var = model(input_var)
            denoised[..., h_start + pad: h_end - pad, w_start + pad: w_end - pad] = \
                out_var[..., pad:-pad, pad:-pad]

    denoised = denoised[..., pad:-pad, pad:-pad]
    denoised = utils.stack2raw(denoised[0]).detach().cpu().numpy()

    denoised = denoised.clip(0, 1)
    return denoised


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', default='/mnt/lustre/zhangyi3/data/SIDD_Medium/Data/')
    parser.add_argument('--camera', choices=['s6', 'gp', 'ip'], required=True, help='camera name')
    args = parser.parse_args()

    camera = args.camera
    root = args.root

    # save_dir = './results/' + camera
    # if not os.path.exists(save_dir):
    #     os.makedirs(save_dir)
    print('test', camera, 'root', root)

    # select the test folders for the chosen camera (scene numbers 2, 3, and 5)
    test_data_list = [item for item in os.listdir(root)
                      if int(item.split('_')[1]) in [2, 3, 5] and camera in item.lower()]

    # build model
    model = UNetSeeInDark()
    model = model.cuda()
    model = torch.nn.DataParallel(model)

    model_path = './checkpoints/%s.pth' % camera.lower()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    psnr_list = []
    for idx, item in enumerate(test_data_list):
        head = item[:4]
        for tail in ['GT_RAW_010', 'GT_RAW_011']:
            print('processing', idx, item, tail, end=' ')
            mat = utils.open_hdf5(os.path.join(root, item, '%s_%s.MAT' % (head, tail)))
            gt = np.array(mat['x'], dtype=np.float32)
            mat = utils.open_hdf5(os.path.join(root, item, '%s_%s.MAT' % (head, tail.replace('GT', 'NOISY'))))
            noisy = np.array(mat['x'], dtype=np.float32)

            meta = sio.loadmat(os.path.join(root, item, '%s_%s.MAT' % (head, tail.replace('GT', 'METADATA'))))
            meta = meta['metadata'][0][0]

            # transform to rggb pattern
            py_meta = utils.extract_metainfo(
                os.path.join(root, item, '%s_%s.MAT' % (head, tail.replace('GT', 'METADATA'))))
            pattern = py_meta['pattern']
            noisy = utils.transform_to_rggb(noisy, pattern)
            gt = utils.transform_to_rggb(gt, pattern)

            denoised = forward_patches(model, noisy)

            psnr = peak_signal_noise_ratio(gt, denoised, data_range=1)
            psnr_list.append(psnr)
            print('psnr %.2f' % psnr)

    print('Camera %s, average PSNR %.2f' % (camera, np.mean(psnr_list)))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import h5py
import time
import torch

import scipy.io as sio

import numpy as np


def open_hdf5(filename):
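    """Open an HDF5 file read-only, retrying on OSError until it succeeds."""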
    while True:
        try:
            hdf5_file = h5py.File(filename, 'r')
            return hdf5_file
        except OSError:
            print(filename, ' waiting')
            time.sleep(3)  # wait a bit before retrying


def extract_metainfo(path='0151_METADATA_RAW_010.MAT'):
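    """Parse a SIDD metadata .MAT file into a dict with the device name, Bayer
    pattern, ISO, noise parameters, exposure time, white-balance gains, and the
    first color correction matrix (ccm1)."""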
    meta = sio.loadmat(path)['metadata']
    mat_vals = meta[0][0]
    mat_keys = mat_vals.dtype.descr

    keys = []
    for item in mat_keys:
        keys.append(item[0])

    py_dict = {}
    for key in keys:
        py_dict[key] = mat_vals[key]

    device = py_dict['Model'][0].lower()
    bitDepth = py_dict['BitDepth'][0][0]
    if 'iphone' in device or bitDepth != 16:
        noise = py_dict['UnknownTags'][-2][0][-1][0][:2]
        iso = py_dict['DigitalCamera'][0, 0]['ISOSpeedRatings'][0][0]
        pattern = py_dict['SubIFDs'][0][0]['UnknownTags'][0][0][1][0][-1][0]
        time = py_dict['DigitalCamera'][0, 0]['ExposureTime'][0][0]
    else:
        noise = py_dict['UnknownTags'][-1][0][-1][0][:2]
        iso = py_dict['ISOSpeedRatings'][0][0]
        pattern = py_dict['UnknownTags'][1][0][-1][0]
        time = py_dict['ExposureTime'][0][0]  # scalar at row 0, column 0

    rgb = ['R', 'G', 'B']
    pattern = ''.join([rgb[i] for i in pattern])

    asShotNeutral = py_dict['AsShotNeutral'][0]
    b_gain, _, r_gain = asShotNeutral

    # only load ccm1
    ccm = py_dict['ColorMatrix1'][0].astype(float).reshape((3, 3))

    return {'device': device,
            'pattern': pattern,
            'iso': iso,
            'noise': noise,
            'time': time,
            'wb': np.array([r_gain, 1, b_gain]),
            'ccm': ccm, }


def transform_to_rggb(img, pattern):
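    """Roll a single-channel Bayer image so that its CFA pattern becomes RGGB."""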
    assert len(img.shape) == 2 and type(img) == np.ndarray

    if pattern.lower() == 'bggr':
        img = np.roll(np.roll(img, 1, axis=1), 1, axis=0)
    elif pattern.lower() == 'rggb':
        pass
    elif pattern.lower() == 'grbg':
        img = np.roll(img, 1, axis=1)
    elif pattern.lower() == 'gbrg':
        img = np.roll(img, 1, axis=0)
    else:
        raise NotImplementedError('Unsupported pattern: %s' % pattern)

    return img


def raw2stack(var):
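    """Pack an (H, W) Bayer raw tensor into a (4, H // 2, W // 2) channel stack."""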
    h, w = var.shape
    if var.is_cuda:
        res = torch.cuda.FloatTensor(4, h // 2, w // 2).fill_(0)
    else:
        res = torch.FloatTensor(4, h // 2, w // 2).fill_(0)
    res[0] = var[0::2, 0::2]
    res[1] = var[0::2, 1::2]
    res[2] = var[1::2, 0::2]
    res[3] = var[1::2, 1::2]
    return res


def stack2raw(var):
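    """Inverse of raw2stack: unpack a (4, H, W) stack into a (2H, 2W) Bayer raw."""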
    _, h, w = var.shape
    if var.is_cuda:
        res = torch.cuda.FloatTensor(h * 2, w * 2)
    else:
        res = torch.FloatTensor(h * 2, w * 2)
    res[0::2, 0::2] = var[0]
    res[0::2, 1::2] = var[1]
    res[1::2, 0::2] = var[2]
    res[1::2, 1::2] = var[3]
    return res
--------------------------------------------------------------------------------