├── .gitignore
├── LICENSE
├── README.md
├── checkpoints
│   └── README.md
├── model.py
├── synthesize.py
├── test.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.*/
_*/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Zhang Yi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Rethinking Noise Synthesis and Modeling in Raw Denoising (ICCV 2021)
---
[Yi Zhang](https://zhangyi-3.github.io/)<sup>1</sup>,
[Hongwei Qin](https://scholar.google.com/citations?user=ZGM7HfgAAAAJ&hl=en)<sup>2</sup>,
[Xiaogang Wang](https://scholar.google.com/citations?user=-B5JgjsAAAAJ&hl=zh-CN)<sup>1</sup>,
[Hongsheng Li](https://www.ee.cuhk.edu.hk/~hsli/)<sup>1</sup>

<sup>1</sup>CUHK-SenseTime Joint Lab, <sup>2</sup>SenseTime Research

### Abstract

> The lack of a large-scale real raw image denoising dataset makes it challenging to
synthesize realistic raw image noise for training denoising models. Real raw image noise is
contributed by many noise sources and varies greatly among different sensors.
Existing methods are unable to model all noise sources accurately, and building a noise model
for each sensor is laborious. In this paper, we introduce a new perspective: synthesizing
noise by directly sampling from the sensor's real noise. This inherently generates accurate
raw image noise for different camera sensors. Two efficient and generic techniques,
pattern-aligned patch sampling and high-bit reconstruction, enable accurate synthesis of
spatially correlated noise and high-bit noise, respectively. We conduct systematic experiments
on the SIDD and ELD datasets. The results show that (1) our method outperforms existing
methods and generalizes well across different sensors and lighting conditions, and (2) recent
conclusions drawn from DNN-based noise modeling methods are actually based on inaccurate
noise parameters; the DNN-based methods still cannot outperform physics-based statistical
methods.
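
The noise-sampling pipeline itself is not part of this release (the repo ships the evaluation code and the calibrated Poisson-Gaussian parameters). For intuition only, a minimal sketch of pattern-aligned patch sampling — one plausible reading of the idea above, not the paper's released implementation; `noise_map` and `patch` are placeholder names — could look like:

```python
import numpy as np

def sample_pattern_aligned_patch(noise_map, patch=64):
    """Crop a noise patch at an even (row, col) offset, so the 2x2 Bayer
    phase (e.g. RGGB) of the patch matches any even-aligned target location.
    `noise_map` stands for a hypothetical 2D array of real sensor noise."""
    h, w = noise_map.shape
    # offsets restricted to multiples of 2 preserve the CFA phase
    i = 2 * np.random.randint(0, (h - patch) // 2 + 1)
    j = 2 * np.random.randint(0, (w - patch) // 2 + 1)
    return noise_map[i:i + patch, j:j + patch]
```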

### Testing
The code has been tested with the following environment:
```
pytorch == 1.5.0
scikit-image == 0.16.2
scipy == 1.3.1
h5py == 2.10.0
```

- Prepare the [SIDD-Medium](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) dataset.
- Download the [pretrained models](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155135732_link_cuhk_edu_hk/Egb3x2YO-qBBgQ41N8WiCIUBRQuxb4gWsV_Ml1yLfDti9w?e=wRwC7e) and put them into the checkpoints folder.
- Point `--root` at the SIDD data directory (or edit the default in `test.py`) and run the command with a specific camera name (`s6` | `ip` | `gp`):
```
python -u test.py --root SIDD_Medium/Data --camera s6
```

### Noise parameters
The calibrated noise parameters for the SIDD dataset and the ELD dataset can be found in ``synthesize.py``.
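They are exposed through `noise_profiles`, which returns the per-ISO shot-noise and read-noise parameters for a given camera, e.g.:

```python
from synthesize import noise_profiles

iso_set, cshot, cread = noise_profiles('s6')  # Samsung S6 Edge
print(iso_set[0], cshot[0], cread[0])
# 100 0.00162521 1.1792188420255036e-06
```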

### Citation
``` bibtex
@InProceedings{zhang2021rethinking,
    author    = {Zhang, Yi and Qin, Hongwei and Wang, Xiaogang and Li, Hongsheng},
    title     = {Rethinking Noise Synthesis and Modeling in Raw Denoising},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {4593-4601}
}
```

### Contact
Feel free to contact zhangyi@link.cuhk.edu.hk if you have any questions.

### Acknowledgments
* [ELD](https://github.com/Vandermode/ELD)
* [CA-NoiseGAN](https://github.com/arcchang1236/CA-NoiseGAN)
* [simple-camera-pipeline](https://github.com/AbdoKamel/simple-camera-pipeline)

--------------------------------------------------------------------------------
/checkpoints/README.md:
--------------------------------------------------------------------------------
Put the pretrained models here.

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class UNetSeeInDark(nn.Module):
    def __init__(self, in_channels=4, out_channels=4):
        super(UNetSeeInDark, self).__init__()

        # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # encoder: five conv blocks, the first four followed by 2x max pooling
        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        # decoder: transposed-conv upsampling with skip connections
        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

        self.conv10_1 = nn.Conv2d(32, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        conv1 = self.lrelu(self.conv1_1(x))
        conv1 = self.lrelu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        conv2 = self.lrelu(self.conv2_1(pool1))
        conv2 = self.lrelu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)

        conv3 = self.lrelu(self.conv3_1(pool2))
        conv3 = self.lrelu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.lrelu(self.conv4_1(pool3))
        conv4 = self.lrelu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.lrelu(self.conv5_1(pool4))
        conv5 = self.lrelu(self.conv5_2(conv5))

        up6 = self.upv6(conv5)
        up6 = torch.cat([up6, conv4], 1)
        conv6 = self.lrelu(self.conv6_1(up6))
        conv6 = self.lrelu(self.conv6_2(conv6))

        up7 = self.upv7(conv6)
        up7 = torch.cat([up7, conv3], 1)
        conv7 = self.lrelu(self.conv7_1(up7))
        conv7 = self.lrelu(self.conv7_2(conv7))

        up8 = self.upv8(conv7)
        up8 = torch.cat([up8, conv2], 1)
        conv8 = self.lrelu(self.conv8_1(up8))
        conv8 = self.lrelu(self.conv8_2(conv8))

        up9 = self.upv9(conv8)
        up9 = torch.cat([up9, conv1], 1)
        conv9 = self.lrelu(self.conv9_1(up9))
        conv9 = self.lrelu(self.conv9_2(conv9))

        conv10 = self.conv10_1(conv9)
        # out = nn.functional.pixel_shuffle(conv10, 2)
        out = conv10
        return out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0.0, 0.02)
                if m.bias is not None:
                    m.bias.data.normal_(0.0, 0.02)
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.data.normal_(0.0, 0.02)

    def lrelu(self, x):
        # leaky ReLU with negative slope 0.2
        outt = torch.max(0.2 * x, x)
        return outt
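

if __name__ == '__main__':
    # Quick shape check (a sketch added for illustration, not part of the
    # original file): the network maps a 4-channel packed Bayer stack to a
    # 4-channel output of the same spatial size. With four 2x poolings, the
    # input height and width should be divisible by 16.
    net = UNetSeeInDark()
    x = torch.zeros(1, 4, 64, 64)
    print(net(x).shape)  # torch.Size([1, 4, 64, 64])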

--------------------------------------------------------------------------------
/synthesize.py:
--------------------------------------------------------------------------------
import torch

import numpy as np


def generate_poisson_(y, k=1):
    # shot noise: Poisson sampling with system gain k, so that var = k * y
    y = torch.poisson(y / k) * k
    return y


def generate_read_noise(shape, noise_type, scale, loc=0):
    # read noise: zero-mean Gaussian with standard deviation `scale`
    noise_type = noise_type.lower()
    if noise_type == 'norm':
        read = torch.FloatTensor(shape).normal_(loc, scale)
    else:
        raise NotImplementedError('Read noise type error.')
    return read


def noise_profiles(camera):
    camera = camera.lower()
    if camera == 'ip':  # iPhone
        iso_set = [100, 200, 400, 800, 1600, 2000]
        cshot = [0.00093595, 0.00104404, 0.00116461, 0.00129911, 0.00144915, 0.00150104]
        cread = [4.697713410870357e-07, 6.904488905478659e-07, 6.739473744228789e-07,
                 6.776787431555864e-07, 6.781983208034481e-07, 6.783184262356993e-07]
    elif camera == 's6':  # Samsung S6 Edge
        iso_set = [100, 200, 400, 800, 1600, 3200]
        cshot = [0.00162521, 0.00256175, 0.00403799, 0.00636492, 0.01003277, 0.01581424]
        cread = [1.1792188420255036e-06, 1.607602896683437e-06, 2.9872611575167216e-06,
                 5.19157563906707e-06, 1.0011034196248119e-05, 2.0652668477786836e-05]
    elif camera == 'gp':  # Google Pixel
        iso_set = [100, 200, 400, 800, 1600, 3200, 6400]
        cshot = [0.00024718, 0.00048489, 0.00095121, 0.001866, 0.00366055, 0.00718092, 0.01408686]
        cread = [1.6819349659429324e-06, 2.0556981890860545e-06, 2.703070976302046e-06,
                 4.116405515789963e-06, 7.569256436438246e-06, 1.5199001098203388e-05, 5.331422827048082e-05]
    elif camera == 'sony':  # Sony A7S2
        iso_set = [800, 1600, 3200]
        cshot = [1.0028880020069384, 1.804521362114003, 3.246920234173119]
        cread = [4.053034401667052, 6.692229120425673, 4.283115294604881]
    elif camera == 'nikon':  # Nikon D850
        iso_set = [800, 1600, 3200]
        cshot = [3.355988883536526, 6.688199969242411, 13.32901281288985]
        cread = [4.4959735547955635, 8.360429952584846, 15.684213053647735]
    else:
        raise NotImplementedError('camera not supported: %s' % camera)
    return iso_set, cshot, cread


def pg_noise_demo(clean_tensor, camera='ip'):
    iso_set, k_set, read_scale_set = noise_profiles(camera)

    # sample an ISO level randomly
    i = np.random.choice(len(k_set))
    k, read_scale = k_set[i], read_scale_set[i]

    noisy_shot = generate_poisson_(clean_tensor, k)
    read_noise = generate_read_noise(clean_tensor.shape, noise_type='norm', scale=read_scale)
    noisy = noisy_shot + read_noise
    return noisy
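

def check_noise_variance(camera='ip', level=0.5, n=100000):
    # Sanity-check sketch added for illustration (not part of the original
    # release): for a constant clean level, the noise synthesized above
    # should follow the Poisson-Gaussian model
    #   var(noisy - clean) ~= cshot * level + cread ** 2.
    iso_set, cshot, cread = noise_profiles(camera)
    k, read_scale = cshot[0], cread[0]
    clean = torch.full((n,), level)
    noisy = generate_poisson_(clean, k) + \
        generate_read_noise(clean.shape, noise_type='norm', scale=read_scale)
    print('empirical var %.3e, model var %.3e' %
          ((noisy - clean).var().item(), k * level + read_scale ** 2))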


if __name__ == '__main__':
    clean = torch.randn(48, 48).clamp(0, 1)
    noisy = pg_noise_demo(clean, camera='ip')
    print(noisy.shape, noisy.mean())

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch

import numpy as np
import torch.nn.functional as F

from skimage.metrics import peak_signal_noise_ratio

import utils
from model import UNetSeeInDark


def forward_patches(model, noisy, patch_size=256 * 3, pad=32):
    # Denoise a full raw frame in overlapping tiles: the frame is
    # reflect-padded by `pad`, tiles of size `patch_size` are taken every
    # `shift = patch_size - 2 * pad` pixels, and only the central region of
    # each output tile is kept, which avoids seams at tile borders.
    shift = patch_size - pad * 2

    noisy = torch.FloatTensor(noisy).cuda()
    noisy = utils.raw2stack(noisy).unsqueeze(0)
    noisy = F.pad(noisy, (pad, pad, pad, pad), mode='reflect')
    denoised = torch.zeros_like(noisy)

    _, _, H, W = noisy.shape
    for i in np.arange(0, H, shift):
        for j in np.arange(0, W, shift):
            h_end, w_end = min(i + patch_size, H), min(j + patch_size, W)
            h_start, w_start = h_end - patch_size, w_end - patch_size

            input_var = noisy[..., h_start: h_end, w_start: w_end]
            with torch.no_grad():
                out_var = model(input_var)
            denoised[..., h_start + pad: h_end - pad, w_start + pad: w_end - pad] = \
                out_var[..., pad:-pad, pad:-pad]

    denoised = denoised[..., pad:-pad, pad:-pad]
    denoised = utils.stack2raw(denoised[0]).detach().cpu().numpy()

    denoised = denoised.clip(0, 1)
    return denoised


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', default='/mnt/lustre/zhangyi3/data/SIDD_Medium/Data/')
    parser.add_argument('--camera', choices=['s6', 'gp', 'ip'], required=True, help='camera name')
    args = parser.parse_args()

    camera = args.camera
    root = args.root

    # save_dir = './results/' + camera
    # if not os.path.exists(save_dir):
    #     os.makedirs(save_dir)
    print('test', camera, 'root', root)

    # keep scenes 2, 3 and 5 of the selected camera for testing
    test_data_list = [item for item in os.listdir(root)
                      if int(item.split('_')[1]) in [2, 3, 5] and camera in item.lower()]

    # build model
    model = UNetSeeInDark()
    model = model.cuda()
    model = torch.nn.DataParallel(model)

    model_path = './checkpoints/%s.pth' % camera.lower()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    psnr_list = []
    for idx, item in enumerate(test_data_list):
        head = item[:4]
        for tail in ['GT_RAW_010', 'GT_RAW_011']:
            print('processing', idx, item, tail, end=' ')
            mat = utils.open_hdf5(os.path.join(root, item, '%s_%s.MAT' % (head, tail)))
            gt = np.array(mat['x'], dtype=np.float32)
            mat = utils.open_hdf5(os.path.join(root, item, '%s_%s.MAT' % (head, tail.replace('GT', 'NOISY'))))
            noisy = np.array(mat['x'], dtype=np.float32)

            # transform both frames to the RGGB pattern
            py_meta = utils.extract_metainfo(
                os.path.join(root, item, '%s_%s.MAT' % (head, tail.replace('GT', 'METADATA'))))
            pattern = py_meta['pattern']
            noisy = utils.transform_to_rggb(noisy, pattern)
            gt = utils.transform_to_rggb(gt, pattern)

            denoised = forward_patches(model, noisy)

            psnr = peak_signal_noise_ratio(gt, denoised, data_range=1)
            psnr_list.append(psnr)
            print('psnr %.2f' % psnr)

    print('Camera %s, average PSNR %.2f' % (camera, np.mean(psnr_list)))

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import h5py
import time
import torch

import scipy.io as sio

import numpy as np


def open_hdf5(filename):
    # retry until the file is readable (it may be locked by another process)
    while True:
        try:
            hdf5_file = h5py.File(filename, 'r')
            return hdf5_file
        except OSError:
            print(filename, ' waiting')
            time.sleep(3)  # wait a bit and retry


def extract_metainfo(path='0151_METADATA_RAW_010.MAT'):
    meta = sio.loadmat(path)['metadata']
    mat_vals = meta[0][0]
    mat_keys = mat_vals.dtype.descr

    keys = []
    for item in mat_keys:
        keys.append(item[0])

    py_dict = {}
    for key in keys:
        py_dict[key] = mat_vals[key]

    device = py_dict['Model'][0].lower()
    bitDepth = py_dict['BitDepth'][0][0]
    # the metadata layout differs between smartphone captures and 16-bit ones
    if 'iphone' in device or bitDepth != 16:
        noise = py_dict['UnknownTags'][-2][0][-1][0][:2]
        iso = py_dict['DigitalCamera'][0, 0]['ISOSpeedRatings'][0][0]
        pattern = py_dict['SubIFDs'][0][0]['UnknownTags'][0][0][1][0][-1][0]
        exposure_time = py_dict['DigitalCamera'][0, 0]['ExposureTime'][0][0]
    else:
        noise = py_dict['UnknownTags'][-1][0][-1][0][:2]
        iso = py_dict['ISOSpeedRatings'][0][0]
        pattern = py_dict['UnknownTags'][1][0][-1][0]
        exposure_time = py_dict['ExposureTime'][0][0]

    rgb = ['R', 'G', 'B']
    pattern = ''.join([rgb[i] for i in pattern])

    asShotNeutral = py_dict['AsShotNeutral'][0]
    b_gain, _, r_gain = asShotNeutral

    # only load ccm1
    ccm = py_dict['ColorMatrix1'][0].astype(float).reshape((3, 3))

    return {'device': device,
            'pattern': pattern,
            'iso': iso,
            'noise': noise,
            'time': exposure_time,
            'wb': np.array([r_gain, 1, b_gain]),
            'ccm': ccm, }


def transform_to_rggb(img, pattern):
    # roll the Bayer mosaic so that its top-left 2x2 block reads R, G / G, B
    assert len(img.shape) == 2 and type(img) == np.ndarray

    if pattern.lower() == 'bggr':
        img = np.roll(np.roll(img, 1, axis=1), 1, axis=0)
    elif pattern.lower() == 'rggb':
        pass  # already aligned
    elif pattern.lower() == 'grbg':
        img = np.roll(img, 1, axis=1)
    elif pattern.lower() == 'gbrg':
        img = np.roll(img, 1, axis=0)
    else:
        raise NotImplementedError('pattern not supported: %s' % pattern)

    return img


def raw2stack(var):
    # pack an (H, W) Bayer raw into a (4, H/2, W/2) stack of its four phases
    h, w = var.shape
    if var.is_cuda:
        res = torch.cuda.FloatTensor(4, h // 2, w // 2).fill_(0)
    else:
        res = torch.FloatTensor(4, h // 2, w // 2).fill_(0)
    res[0] = var[0::2, 0::2]
    res[1] = var[0::2, 1::2]
    res[2] = var[1::2, 0::2]
    res[3] = var[1::2, 1::2]
    return res


def stack2raw(var):
    # inverse of raw2stack: unpack a (4, H/2, W/2) stack into an (H, W) raw
    _, h, w = var.shape
    if var.is_cuda:
        res = torch.cuda.FloatTensor(h * 2, w * 2)
    else:
        res = torch.FloatTensor(h * 2, w * 2)
    res[0::2, 0::2] = var[0]
    res[0::2, 1::2] = var[1]
    res[1::2, 0::2] = var[2]
    res[1::2, 1::2] = var[3]
    return res
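

if __name__ == '__main__':
    # Round-trip sanity check added for illustration (not part of the
    # original file): packing a raw frame into a 4-channel stack and
    # unpacking it again should be lossless.
    raw = torch.rand(8, 8)
    assert torch.equal(stack2raw(raw2stack(raw)), raw)
    print('raw2stack/stack2raw round trip OK')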

--------------------------------------------------------------------------------