├── .gitignore ├── LICENSE ├── README.md ├── checkpoint ├── all │ └── .gitkeep ├── real │ └── .gitkeep └── synthetic │ └── .gitkeep ├── imgs ├── CBDNet_v13.png ├── folder.png └── results.png ├── matdata ├── 201_CRF_data.mat ├── 201_CRF_iCRF_function.mat ├── dorfCurvesInv.mat └── intermosaic.mat ├── model.py ├── test.py ├── train_all.py ├── train_real.py ├── train_syn.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Add by user 2 | .vscode/ 3 | result/ 4 | dataset/ 5 | checkpoint/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # celery beat schedule file 90 | celerybeat-schedule 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # Environments 96 | .env 97 | .venv 98 | env/ 99 | venv/ 100 | ENV/ 101 | env.bak/ 102 | venv.bak/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | .dmypy.json 117 | dmypy.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 IDKiro 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission 
notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CBDNet-tensorflow 2 | 3 | An unofficial implementation of CBDNet in TensorFlow. 4 | 5 | [CBDNet in MATLAB](https://github.com/GuoShi28/CBDNet) 6 | 7 | [CBDNet in PyTorch](https://github.com/IDKiro/CBDNet-pytorch) 8 | 9 | ## Quick Start 10 | 11 | ### Data 12 | 13 | Download the dataset and pre-trained model: 14 | [[OneDrive](https://zjueducn-my.sharepoint.com/:f:/g/personal/3140103306_zju_edu_cn/EorD2T0_OHNEu_5rH6IpdzYB0l3SM9IfmyxWhHjyfVfFJA?e=YL4V99)] 15 | [[Baidu Pan](https://pan.baidu.com/s/1ObvekJcPhtK9RUOC86vmNA) (8ko0)] 16 | [[Mega](https://mega.nz/#F!uOZEVAYR!fbf-RCtnbUR7mlHZsgiL5g)] 17 | 18 | Extract the files into the `dataset` and `checkpoint` folders as follows: 19 | 20 | ![](imgs/folder.png) 21 | 22 | ### Train 23 | 24 | Train the model on synthetic noisy images: 25 | 26 | ``` 27 | python train_syn.py 28 | ``` 29 | 30 | Train the model on real noisy images: 31 | 32 | ``` 33 | python train_real.py 34 | ``` 35 | 36 | Train the model on both synthetic and real noisy images: 37 | 38 | ``` 39 | python train_all.py 40 | ``` 41 | 42 | **To reduce image-loading time, the training scripts keep all images in memory, which requires a large amount of RAM.** 43 | 44 | ### Test 45 | 46 | Test the trained
model on the DND dataset: 47 | 48 | ``` 49 | python test.py 50 | ``` 51 | 52 | Optional: 53 | 54 | ``` 55 | --ckpt {all,real,synthetic} checkpoint type 56 | --cpu [CPU] Use CPU 57 | ``` 58 | 59 | Example: 60 | 61 | ``` 62 | python test.py --ckpt synthetic --cpu 63 | ``` 64 | 65 | ## Network Structure 66 | 67 | ![Image of Network](imgs/CBDNet_v13.png) 68 | 69 | ## Realistic Noise Model 70 | Given a clean image `x`, the realistic noise model can be represented as: 71 | 72 | ![](http://latex.codecogs.com/gif.latex?\\textbf{y}=f(\\textbf{DM}(\\textbf{L}+n(\\textbf{L})))) 73 | 74 | ![](http://latex.codecogs.com/gif.latex?n(\\textbf{L})=n_s(\\textbf{L})+n_c) 75 | 76 | where `y` is the noisy image, `f(.)` is the camera response function (CRF), the irradiance is ![](http://latex.codecogs.com/gif.latex?\\textbf{L}=\\textbf{M}f^{-1}(\\textbf{x})) , `M(.)` is the function that converts an sRGB image to a Bayer image, and `DM(.)` is the demosaicing function. 77 | 78 | When denoising compressed images, the model becomes 79 | 80 | ![](http://latex.codecogs.com/gif.latex?\\textbf{y}=JPEG(f(\\textbf{DM}(\\textbf{L}+n(\\textbf{L}))))) 81 | 82 | ## Result 83 | 84 | ![](imgs/results.png) 85 | -------------------------------------------------------------------------------- /checkpoint/all/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/checkpoint/all/.gitkeep -------------------------------------------------------------------------------- /checkpoint/real/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/checkpoint/real/.gitkeep -------------------------------------------------------------------------------- /checkpoint/synthetic/.gitkeep: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/checkpoint/synthetic/.gitkeep -------------------------------------------------------------------------------- /imgs/CBDNet_v13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/imgs/CBDNet_v13.png -------------------------------------------------------------------------------- /imgs/folder.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/imgs/folder.png -------------------------------------------------------------------------------- /imgs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/imgs/results.png -------------------------------------------------------------------------------- /matdata/201_CRF_data.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/matdata/201_CRF_data.mat -------------------------------------------------------------------------------- /matdata/201_CRF_iCRF_function.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/matdata/201_CRF_iCRF_function.mat -------------------------------------------------------------------------------- /matdata/dorfCurvesInv.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/matdata/dorfCurvesInv.mat 
-------------------------------------------------------------------------------- /matdata/intermosaic.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDKiro/CBDNet-tensorflow/591431f62bff77c9c1dc3cb950b124663bc1e97c/matdata/intermosaic.mat -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib.slim as slim 3 | 4 | def upsample_and_sum(x1, x2, output_channels, in_channels, scope=None): 5 | pool_size = 2 6 | deconv_filter = tf.Variable(tf.truncated_normal([pool_size, pool_size, output_channels, in_channels], stddev=0.02), name=scope) 7 | deconv = tf.nn.conv2d_transpose(x1, deconv_filter, tf.shape(x2), strides=[1, pool_size, pool_size, 1]) 8 | 9 | deconv_output = deconv + x2 10 | deconv_output.set_shape([None, None, None, output_channels]) 11 | 12 | return deconv_output 13 | 14 | def FCN(input): 15 | with tf.variable_scope('fcn'): 16 | conv1 = slim.conv2d(input, 32, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv1') 17 | conv2 = slim.conv2d(conv1, 32, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv2') 18 | conv3 = slim.conv2d(conv2, 32, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv3') 19 | conv4 = slim.conv2d(conv3, 32, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv4') 20 | conv5 = slim.conv2d(conv4, 3, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv5') 21 | return conv5 22 | 23 | 24 | def UNet(input): 25 | with tf.variable_scope('unet'): 26 | conv1 = slim.conv2d(input, 64, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv1_1') 27 | conv1 = slim.conv2d(conv1, 64, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv1_2') 28 | 29 | pool1 = slim.avg_pool2d(conv1, [2, 2], padding='SAME') 30 | conv2 = slim.conv2d(pool1, 128, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv2_1') 31 | 
conv2 = slim.conv2d(conv2, 128, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv2_2') 32 | conv2 = slim.conv2d(conv2, 128, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv2_3') 33 | 34 | pool2 = slim.avg_pool2d(conv2, [2, 2], padding='SAME') 35 | conv3 = slim.conv2d(pool2, 256, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv3_1') 36 | conv3 = slim.conv2d(conv3, 256, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv3_2') 37 | conv3 = slim.conv2d(conv3, 256, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv3_3') 38 | conv3 = slim.conv2d(conv3, 256, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv3_4') 39 | conv3 = slim.conv2d(conv3, 256, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv3_5') 40 | conv3 = slim.conv2d(conv3, 256, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv3_6') 41 | 42 | up4 = upsample_and_sum(conv3, conv2, 128, 256, scope='deconv4') 43 | conv4 = slim.conv2d(up4, 128, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv4_1') 44 | conv4 = slim.conv2d(conv4, 128, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv4_2') 45 | conv4 = slim.conv2d(conv4, 128, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv4_3') 46 | 47 | up5 = upsample_and_sum(conv4, conv1, 64, 128, scope='deconv5') 48 | conv5 = slim.conv2d(up5, 64, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv5_1') 49 | conv5 = slim.conv2d(conv5, 64, [3, 3], rate=1, activation_fn=tf.nn.relu, scope='conv5_2') 50 | 51 | out = slim.conv2d(conv5, 3, [1, 1], rate=1, activation_fn=None, scope='conv6') 52 | 53 | return out 54 | 55 | def CBDNet(input): 56 | noise_level = FCN(input) 57 | 58 | concat_img = tf.concat([input, noise_level], 3) 59 | 60 | out = UNet(concat_img) + input 61 | 62 | return noise_level, out -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import 
print_function 3 | import os, time, scipy.io 4 | import argparse 5 | import tensorflow as tf 6 | from tensorflow.contrib.layers.python.layers import initializers 7 | import numpy as np 8 | import glob 9 | import re 10 | import cv2 11 | 12 | from model import * 13 | 14 | parser = argparse.ArgumentParser(description='Testing on DND dataset') 15 | parser.add_argument('--ckpt', type=str, default='all', 16 | choices=['all', 'real', 'synthetic'], help='checkpoint type') 17 | parser.add_argument('--cpu', nargs='?', const=1, help = 'Use CPU') 18 | args = parser.parse_args() 19 | 20 | if not args.cpu: 21 | print('Using GPU!') 22 | else: 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # Inference on CPU 24 | print('Using CPU!') 25 | 26 | input_dir = './dataset/test/' 27 | checkpoint_dir = './checkpoint/' + args.ckpt 28 | result_dir = './result/' 29 | 30 | test_fns = glob.glob(input_dir + '*.bmp') 31 | 32 | # model setting 33 | in_image = tf.placeholder(tf.float32, [None, None, None, 3]) 34 | _, out_image = CBDNet(in_image) 35 | 36 | # load model 37 | sess = tf.Session() 38 | sess.run(tf.global_variables_initializer()) 39 | 40 | saver = tf.train.Saver() 41 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 42 | if ckpt: 43 | print('loaded', checkpoint_dir) 44 | saver.restore(sess, ckpt.model_checkpoint_path) 45 | 46 | if not os.path.isdir(result_dir + 'test/'): 47 | os.makedirs(result_dir + 'test/') 48 | 49 | for ind, test_fn in enumerate(test_fns): 50 | print(test_fn) 51 | noisy_img = cv2.imread(test_fn) 52 | noisy_img = noisy_img[:,:,::-1] / 255.0 53 | noisy_img = np.array(noisy_img).astype('float32') 54 | temp_noisy_img = np.expand_dims(noisy_img, axis=0) 55 | 56 | output = sess.run(out_image, feed_dict={in_image:temp_noisy_img}) 57 | output = np.clip(output, 0, 1) 58 | 59 | temp = np.concatenate((temp_noisy_img[0, :, :, :], output[0, :, :, :]), axis=1) 60 | scipy.misc.toimage(temp*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + 'test/test_%d.jpg'%(ind)) 61 | 
62 | -------------------------------------------------------------------------------- /train_all.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | import os, time, scipy.io 4 | import tensorflow as tf 5 | from tensorflow.contrib.layers.python.layers import initializers 6 | import numpy as np 7 | import glob 8 | import re 9 | 10 | from utils import * 11 | from model import * 12 | 13 | 14 | def load_CRF(): 15 | CRF = scipy.io.loadmat('matdata/201_CRF_data.mat') 16 | iCRF = scipy.io.loadmat('matdata/dorfCurvesInv.mat') 17 | B_gl = CRF['B'] 18 | I_gl = CRF['I'] 19 | B_inv_gl = iCRF['invB'] 20 | I_inv_gl = iCRF['invI'] 21 | 22 | if os.path.exists('matdata/201_CRF_iCRF_function.mat')==0: 23 | CRF_para = np.array(CRF_function_transfer(I_gl, B_gl)) 24 | iCRF_para = 1. / CRF_para 25 | scipy.io.savemat('matdata/201_CRF_iCRF_function.mat', {'CRF':CRF_para, 'iCRF':iCRF_para}) 26 | else: 27 | Bundle = scipy.io.loadmat('matdata/201_CRF_iCRF_function.mat') 28 | CRF_para = Bundle['CRF'] 29 | iCRF_para = Bundle['iCRF'] 30 | 31 | return CRF_para, iCRF_para, I_gl, B_gl, I_inv_gl, B_inv_gl 32 | 33 | 34 | if __name__ == '__main__': 35 | input_syn_dir = './dataset/synthetic/' 36 | input_real_dir = './dataset/real/' 37 | checkpoint_dir = './checkpoint/all/' 38 | result_dir = './result/all/' 39 | 40 | PS = 512 # patch size, if your GPU memory is not enough, modify it 41 | REAPET = 10 42 | save_freq = 100 43 | 44 | CRF_para, iCRF_para, I_gl, B_gl, I_inv_gl, B_inv_gl = load_CRF() 45 | 46 | train_syn_fns = glob.glob(input_syn_dir + '*.bmp') 47 | train_real_fns = glob.glob(input_real_dir + 'Batch_*') 48 | 49 | origin_syn_imgs = [None] * len(train_syn_fns) 50 | noise_syn_imgs = [None] * len(train_syn_fns) 51 | noise_syn_levels = [None] * len(train_syn_fns) 52 | 53 | origin_real_imgs = [None] * len(train_real_fns) 54 | noise_real_imgs = [None] * len(train_real_fns) 55 | 56 | for 
i in range(len(train_syn_fns)): 57 | origin_syn_imgs[i] = [] 58 | noise_syn_imgs[i] = [] 59 | noise_syn_levels[i] = [] 60 | 61 | for i in range(len(train_real_fns)): 62 | origin_real_imgs[i] = [] 63 | noise_real_imgs[i] = [] 64 | 65 | # model setting 66 | in_image = tf.placeholder(tf.float32, [None, None, None, 3]) 67 | gt_image = tf.placeholder(tf.float32, [None, None, None, 3]) 68 | gt_noise = tf.placeholder(tf.float32, [None, None, None, 3]) 69 | 70 | est_noise, out_image = CBDNet(in_image) 71 | 72 | if_asym = tf.placeholder(tf.float32) 73 | G_loss = tf.losses.mean_squared_error(gt_image, out_image) + \ 74 | if_asym * 0.5 * tf.reduce_mean(tf.multiply(tf.abs(0.3 - tf.nn.relu(gt_noise - est_noise)), tf.square(est_noise - gt_noise))) + \ 75 | 0.05 * tf.reduce_mean(tf.square(tf.image.image_gradients(est_noise))) 76 | 77 | lr = tf.placeholder(tf.float32) 78 | 79 | G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss) 80 | 81 | # load model 82 | sess = tf.Session() 83 | sess.run(tf.global_variables_initializer()) 84 | 85 | save_vars = [v for v in tf.global_variables() if (v.name.split('/')[0] == 'fcn' or v.name.split('/')[0] == 'unet')] 86 | saver = tf.train.Saver(var_list=save_vars) 87 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 88 | if ckpt: 89 | print('loaded', checkpoint_dir) 90 | saver.restore(sess, ckpt.model_checkpoint_path) 91 | 92 | allpoint = glob.glob(checkpoint_dir+'epoch-*') 93 | lastepoch = 0 94 | for point in allpoint: 95 | cur_epoch = re.findall(r'epoch-(\d+)', point) 96 | lastepoch = np.maximum(lastepoch, int(cur_epoch[0])) 97 | 98 | learning_rate = 1e-4 99 | for epoch in range(lastepoch, 201): 100 | losses = AverageMeter() 101 | 102 | if os.path.isdir(result_dir+"%04d"%epoch): 103 | continue 104 | cnt=0 105 | 106 | if epoch > 100: 107 | learning_rate = 1e-5 108 | 109 | print('Training on synthetic noisy images...') 110 | for ind in np.random.permutation(len(train_syn_fns)): 111 | train_syn_fn = train_syn_fns[ind] 112 | 113 | if 
not len(origin_syn_imgs[ind]): 114 | origin_syn_img = ReadImg(train_syn_fn) 115 | origin_syn_imgs[ind] = np.expand_dims(origin_syn_img, axis=0) 116 | 117 | # re-add noise 118 | if epoch % save_freq == 0: 119 | noise_syn_imgs[ind] = [] 120 | noise_syn_levels[ind] = [] 121 | 122 | if len(noise_syn_imgs[ind]) < 1: 123 | noise_syn_img, noise_syn_level = AddRealNoise(origin_syn_imgs[ind][0, :, :, :], CRF_para, iCRF_para, I_gl, B_gl, I_inv_gl, B_inv_gl) 124 | noise_syn_imgs[ind].append(np.expand_dims(noise_syn_img, axis=0)) 125 | noise_syn_levels[ind].append(np.expand_dims(noise_syn_level, axis=0)) 126 | 127 | st = time.time() 128 | for nind in np.random.permutation(len(noise_syn_imgs[ind])): 129 | temp_origin_img = origin_syn_imgs[ind] 130 | temp_noise_img = noise_syn_imgs[ind][nind] 131 | temp_noise_level = noise_syn_levels[ind][nind] 132 | 133 | # data augmentation 134 | if np.random.randint(2, size=1)[0] == 1: 135 | temp_origin_img = np.flip(temp_origin_img, axis=1) 136 | temp_noise_img = np.flip(temp_noise_img, axis=1) 137 | temp_noise_level = np.flip(temp_noise_level, axis=1) 138 | if np.random.randint(2, size=1)[0] == 1: 139 | temp_origin_img = np.flip(temp_origin_img, axis=0) 140 | temp_noise_img = np.flip(temp_noise_img, axis=0) 141 | temp_noise_level = np.flip(temp_noise_level, axis=0) 142 | if np.random.randint(2, size=1)[0] == 1: 143 | temp_origin_img = np.transpose(temp_origin_img, (0, 2, 1, 3)) 144 | temp_noise_img = np.transpose(temp_noise_img, (0, 2, 1, 3)) 145 | temp_noise_level = np.transpose(temp_noise_level, (0, 2, 1, 3)) 146 | 147 | cnt += 1 148 | st = time.time() 149 | 150 | _, G_current, output = sess.run( 151 | [G_opt, G_loss, out_image], 152 | feed_dict={in_image:temp_noise_img, gt_image:temp_origin_img, gt_noise:temp_noise_level, lr:learning_rate, if_asym:1} 153 | ) 154 | output = np.clip(output, 0, 1) 155 | losses.update(G_current) 156 | 157 | print("%d %d Loss=%.4f Time=%.3f"%(epoch, cnt, losses.avg, time.time()-st)) 158 | 159 | if epoch % 
save_freq == 0: 160 | if not os.path.isdir(result_dir + '%04d'%epoch): 161 | os.makedirs(result_dir + '%04d'%epoch) 162 | 163 | temp = np.concatenate((temp_origin_img[0, :, :, :], temp_noise_img[0, :, :, :], output[0, :, :, :]), axis=1) 164 | scipy.misc.toimage(temp*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%04d/train_%d_%d.jpg'%(epoch, ind, nind)) 165 | 166 | print('Training on real noisy images...') 167 | for r in range(REAPET): 168 | for ind in np.random.permutation(len(train_real_fns)): 169 | train_real_fn = train_real_fns[ind] 170 | 171 | if not len(origin_real_imgs[ind]): 172 | train_real_origin_fns = glob.glob(train_real_fn + '/*Reference.bmp') 173 | train_real_noise_fns = glob.glob(train_real_fn + '/*Noisy.bmp') 174 | 175 | origin_real_img = ReadImg(train_real_origin_fns[0]) 176 | origin_real_imgs[ind] = np.expand_dims(origin_real_img, axis=0) 177 | 178 | for train_real_noise_fn in train_real_noise_fns: 179 | noise_real_img = ReadImg(train_real_noise_fn) 180 | noise_real_imgs[ind].append(np.expand_dims(noise_real_img, axis=0)) 181 | 182 | st = time.time() 183 | for nind in np.random.permutation(len(noise_real_imgs[ind])): 184 | H = origin_real_imgs[ind].shape[1] 185 | W = origin_real_imgs[ind].shape[2] 186 | 187 | ps_temp = min(H, W, PS) - 1 188 | 189 | xx = np.random.randint(0, W-ps_temp) 190 | yy = np.random.randint(0, H-ps_temp) 191 | 192 | temp_origin_img = origin_real_imgs[ind][:, yy:yy+ps_temp, xx:xx+ps_temp, :] 193 | temp_noise_img = noise_real_imgs[ind][nind][:, yy:yy+ps_temp, xx:xx+ps_temp, :] 194 | 195 | if np.random.randint(2, size=1)[0] == 1: 196 | temp_origin_img = np.flip(temp_origin_img, axis=1) 197 | temp_noise_img = np.flip(temp_noise_img, axis=1) 198 | if np.random.randint(2, size=1)[0] == 1: 199 | temp_origin_img = np.flip(temp_origin_img, axis=0) 200 | temp_noise_img = np.flip(temp_noise_img, axis=0) 201 | if np.random.randint(2, size=1)[0] == 1: 202 | temp_origin_img = np.transpose(temp_origin_img, (0, 2, 1, 3)) 203 | 
temp_noise_img = np.transpose(temp_noise_img, (0, 2, 1, 3)) 204 | 205 | cnt += 1 206 | st = time.time() 207 | 208 | _, G_current, output = sess.run( 209 | [G_opt, G_loss, out_image], 210 | feed_dict={in_image:temp_noise_img, gt_image:temp_origin_img, gt_noise:np.zeros_like(temp_origin_img), lr:learning_rate, if_asym:0} 211 | ) 212 | output = np.clip(output, 0, 1) 213 | losses.update(G_current) 214 | 215 | print("%d %d Loss=%.4f Time=%.3f"%(epoch, cnt, losses.avg, time.time()-st)) 216 | 217 | if epoch % save_freq == 0: 218 | if not os.path.isdir(result_dir + '%04d'%epoch): 219 | os.makedirs(result_dir + '%04d'%epoch) 220 | 221 | temp = np.concatenate((temp_origin_img[0, :, :, :], temp_noise_img[0, :, :, :], output[0, :, :, :]), axis=1) 222 | scipy.misc.toimage(temp*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%04d/train_%d_%d.jpg'%(epoch, ind + len(train_syn_fns) + r * len(train_real_fns), nind)) 223 | 224 | saver.save(sess, checkpoint_dir + 'model.ckpt') 225 | 226 | if not os.path.isdir(checkpoint_dir + 'epoch-' + str(epoch)): 227 | os.mkdir(checkpoint_dir + 'epoch-' + str(epoch)) 228 | 229 | if os.path.isdir(checkpoint_dir + 'epoch-' + str(epoch - 1)): 230 | os.rmdir(checkpoint_dir + 'epoch-' + str(epoch - 1)) 231 | -------------------------------------------------------------------------------- /train_real.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | import os, time, scipy.io 4 | import tensorflow as tf 5 | from tensorflow.contrib.layers.python.layers import initializers 6 | import numpy as np 7 | import glob 8 | import re 9 | 10 | from utils import * 11 | from model import * 12 | 13 | 14 | if __name__ == '__main__': 15 | input_dir = './dataset/real/' 16 | checkpoint_dir = './checkpoint/real/' 17 | result_dir = './result/real/' 18 | 19 | PS = 512 # patch size, if your GPU memory is not enough, modify it 20 | save_freq = 100 21 | 
22 | train_fns = glob.glob(input_dir + 'Batch_*') 23 | 24 | origin_imgs = [None] * len(train_fns) 25 | noise_imgs = [None] * len(train_fns) 26 | 27 | for i in range(len(train_fns)): 28 | origin_imgs[i] = [] 29 | noise_imgs[i] = [] 30 | 31 | # model setting 32 | in_image = tf.placeholder(tf.float32, [None, None, None, 3]) 33 | gt_image = tf.placeholder(tf.float32, [None, None, None, 3]) 34 | 35 | est_noise, out_image = CBDNet(in_image) 36 | 37 | G_loss = tf.losses.mean_squared_error(gt_image, out_image) + \ 38 | 0.05 * tf.reduce_mean(tf.square(tf.image.image_gradients(est_noise))) 39 | 40 | lr = tf.placeholder(tf.float32) 41 | G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss) 42 | 43 | # load model 44 | sess = tf.Session() 45 | sess.run(tf.global_variables_initializer()) 46 | 47 | saver = tf.train.Saver() 48 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 49 | if ckpt: 50 | print('loaded', checkpoint_dir) 51 | saver.restore(sess, ckpt.model_checkpoint_path) 52 | 53 | allpoint = glob.glob(checkpoint_dir + 'epoch-*') 54 | lastepoch = 0 55 | for point in allpoint: 56 | cur_epoch = re.findall(r'epoch-(\d+)', point) 57 | lastepoch = np.maximum(lastepoch, int(cur_epoch[0])) 58 | 59 | learning_rate = 1e-4 60 | for epoch in range(lastepoch, 2001): 61 | losses = AverageMeter() 62 | cnt=0 63 | 64 | if epoch > 1000: 65 | learning_rate = 1e-5 66 | 67 | for ind in np.random.permutation(len(train_fns)): 68 | train_fn = train_fns[ind] 69 | 70 | if not len(origin_imgs[ind]): 71 | train_origin_fns = glob.glob(train_fn + '/*Reference.bmp') 72 | train_noise_fns = glob.glob(train_fn + '/*Noisy.bmp') 73 | 74 | origin_img = ReadImg(train_origin_fns[0]) 75 | origin_imgs[ind] = np.expand_dims(origin_img, axis=0) 76 | 77 | for train_noise_fn in train_noise_fns: 78 | noise_img = ReadImg(train_noise_fn) 79 | noise_imgs[ind].append(np.expand_dims(noise_img, axis=0)) 80 | 81 | st = time.time() 82 | for nind in np.random.permutation(len(noise_imgs[ind])): 83 | H = 
origin_imgs[ind].shape[1] 84 | W = origin_imgs[ind].shape[2] 85 | 86 | ps_temp = min(H, W, PS) - 1 87 | 88 | xx = np.random.randint(0, W-ps_temp) 89 | yy = np.random.randint(0, H-ps_temp) 90 | 91 | temp_origin_img = origin_imgs[ind][:, yy:yy+ps_temp, xx:xx+ps_temp, :] 92 | temp_noise_img = noise_imgs[ind][nind][:, yy:yy+ps_temp, xx:xx+ps_temp, :] 93 | 94 | if np.random.randint(2, size=1)[0] == 1: 95 | temp_origin_img = np.flip(temp_origin_img, axis=1) 96 | temp_noise_img = np.flip(temp_noise_img, axis=1) 97 | if np.random.randint(2, size=1)[0] == 1: 98 | temp_origin_img = np.flip(temp_origin_img, axis=0) 99 | temp_noise_img = np.flip(temp_noise_img, axis=0) 100 | if np.random.randint(2, size=1)[0] == 1: 101 | temp_origin_img = np.transpose(temp_origin_img, (0, 2, 1, 3)) 102 | temp_noise_img = np.transpose(temp_noise_img, (0, 2, 1, 3)) 103 | 104 | cnt += 1 105 | st = time.time() 106 | 107 | _, G_current, output = sess.run( 108 | [G_opt, G_loss, out_image], 109 | feed_dict={in_image:temp_noise_img, gt_image:temp_origin_img, lr:learning_rate} 110 | ) 111 | output = np.clip(output, 0, 1) 112 | losses.update(G_current) 113 | 114 | print("%d %d Loss=%.4f Time=%.3f"%(epoch, cnt, losses.avg, time.time()-st)) 115 | 116 | if epoch % save_freq == 0: 117 | if not os.path.isdir(result_dir + '%04d'%epoch): 118 | os.makedirs(result_dir + '%04d'%epoch) 119 | 120 | temp = np.concatenate((temp_origin_img[0, :, :, :], temp_noise_img[0, :, :, :], output[0, :, :, :]), axis=1) 121 | scipy.misc.toimage(temp*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%04d/train_%d_%d.jpg'%(epoch, ind, nind)) 122 | 123 | saver.save(sess, checkpoint_dir + 'model.ckpt') 124 | 125 | if not os.path.isdir(checkpoint_dir + 'epoch-' + str(epoch)): 126 | os.mkdir(checkpoint_dir + 'epoch-' + str(epoch)) 127 | 128 | if os.path.isdir(checkpoint_dir + 'epoch-' + str(epoch - 1)): 129 | os.rmdir(checkpoint_dir + 'epoch-' + str(epoch - 1)) 130 | 
-------------------------------------------------------------------------------- /train_syn.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | import os, time, scipy.io 4 | import tensorflow as tf 5 | from tensorflow.contrib.layers.python.layers import initializers 6 | import numpy as np 7 | import glob 8 | import re 9 | 10 | from utils import * 11 | from model import * 12 | 13 | 14 | def load_CRF(): 15 | CRF = scipy.io.loadmat('matdata/201_CRF_data.mat') 16 | iCRF = scipy.io.loadmat('matdata/dorfCurvesInv.mat') 17 | B_gl = CRF['B'] 18 | I_gl = CRF['I'] 19 | B_inv_gl = iCRF['invB'] 20 | I_inv_gl = iCRF['invI'] 21 | 22 | if os.path.exists('matdata/201_CRF_iCRF_function.mat')==0: 23 | CRF_para = np.array(CRF_function_transfer(I_gl, B_gl)) 24 | iCRF_para = 1. / CRF_para 25 | scipy.io.savemat('matdata/201_CRF_iCRF_function.mat', {'CRF':CRF_para, 'iCRF':iCRF_para}) 26 | else: 27 | Bundle = scipy.io.loadmat('matdata/201_CRF_iCRF_function.mat') 28 | CRF_para = Bundle['CRF'] 29 | iCRF_para = Bundle['iCRF'] 30 | 31 | return CRF_para, iCRF_para, I_gl, B_gl, I_inv_gl, B_inv_gl 32 | 33 | 34 | if __name__ == '__main__': 35 | input_dir = './dataset/synthetic/' 36 | checkpoint_dir = './checkpoint/synthetic/' 37 | result_dir = './result/synthetic/' 38 | 39 | save_freq = 100 40 | 41 | CRF_para, iCRF_para, I_gl, B_gl, I_inv_gl, B_inv_gl = load_CRF() 42 | 43 | train_fns = glob.glob(input_dir + '*.bmp') 44 | 45 | origin_imgs = [None] * len(train_fns) 46 | noise_imgs = [None] * len(train_fns) 47 | noise_levels = [None] * len(train_fns) 48 | 49 | for i in range(len(train_fns)): 50 | origin_imgs[i] = [] 51 | noise_imgs[i] = [] 52 | noise_levels[i] = [] 53 | 54 | # model setting 55 | in_image = tf.placeholder(tf.float32, [None, None, None, 3]) 56 | gt_image = tf.placeholder(tf.float32, [None, None, None, 3]) 57 | gt_noise = tf.placeholder(tf.float32, [None, None, None, 3]) 
58 | 59 | est_noise, out_image = CBDNet(in_image) 60 | 61 | G_loss = tf.losses.mean_squared_error(gt_image, out_image) + \ 62 | 0.5 * tf.reduce_mean(tf.multiply(tf.abs(0.3 - tf.nn.relu(gt_noise - est_noise)), tf.square(est_noise - gt_noise))) + \ 63 | 0.05 * tf.reduce_mean(tf.square(tf.image.image_gradients(est_noise))) 64 | 65 | lr = tf.placeholder(tf.float32) 66 | G_opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss) 67 | 68 | # load model 69 | sess = tf.Session() 70 | sess.run(tf.global_variables_initializer()) 71 | 72 | save_vars = [v for v in tf.global_variables() if (v.name.split('/')[0] == 'fcn' or v.name.split('/')[0] == 'unet')] 73 | saver = tf.train.Saver() 74 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 75 | if ckpt: 76 | print('loaded', checkpoint_dir) 77 | saver.restore(sess, ckpt.model_checkpoint_path) 78 | 79 | allpoint = glob.glob(checkpoint_dir + 'epoch-*') 80 | lastepoch = 0 81 | for point in allpoint: 82 | cur_epoch = re.findall(r'epoch-(\d+)', point) 83 | lastepoch = np.maximum(lastepoch, int(cur_epoch[0])) 84 | 85 | learning_rate = 1e-4 86 | for epoch in range(lastepoch, 201): 87 | losses = AverageMeter() 88 | 89 | if os.path.isdir(result_dir+"%04d"%epoch): 90 | continue 91 | cnt=0 92 | 93 | if epoch > 100: 94 | learning_rate = 1e-5 95 | 96 | for ind in np.random.permutation(len(train_fns)): 97 | train_fn = train_fns[ind] 98 | 99 | if not len(origin_imgs[ind]): 100 | origin_img = ReadImg(train_fn) 101 | origin_imgs[ind] = np.expand_dims(origin_img, axis=0) 102 | 103 | # re-add noise 104 | if epoch % save_freq == 0: 105 | noise_imgs[ind] = [] 106 | noise_levels[ind] = [] 107 | 108 | if len(noise_imgs[ind]) < 1: 109 | noise_img, noise_level = AddRealNoise(origin_imgs[ind][0, :, :, :], CRF_para, iCRF_para, I_gl, B_gl, I_inv_gl, B_inv_gl) 110 | noise_imgs[ind].append(np.expand_dims(noise_img, axis=0)) 111 | noise_levels[ind].append(np.expand_dims(noise_level, axis=0)) 112 | 113 | st = time.time() 114 | for nind in 
np.random.permutation(len(noise_imgs[ind])): 115 | temp_origin_img = origin_imgs[ind] 116 | temp_noise_img = noise_imgs[ind][nind] 117 | temp_noise_level = noise_levels[ind][nind] 118 | 119 | # data augmentation 120 | if np.random.randint(2, size=1)[0] == 1: 121 | temp_origin_img = np.flip(temp_origin_img, axis=1) 122 | temp_noise_img = np.flip(temp_noise_img, axis=1) 123 | temp_noise_level = np.flip(temp_noise_level, axis=1) 124 | if np.random.randint(2, size=1)[0] == 1: 125 | temp_origin_img = np.flip(temp_origin_img, axis=0) 126 | temp_noise_img = np.flip(temp_noise_img, axis=0) 127 | temp_noise_level = np.flip(temp_noise_level, axis=0) 128 | if np.random.randint(2, size=1)[0] == 1: 129 | temp_origin_img = np.transpose(temp_origin_img, (0, 2, 1, 3)) 130 | temp_noise_img = np.transpose(temp_noise_img, (0, 2, 1, 3)) 131 | temp_noise_level = np.transpose(temp_noise_level, (0, 2, 1, 3)) 132 | 133 | cnt += 1 134 | st = time.time() 135 | 136 | _, G_current, output = sess.run( 137 | [G_opt, G_loss, out_image], 138 | feed_dict={in_image:temp_noise_img, gt_image:temp_origin_img, gt_noise:temp_noise_level, lr:learning_rate} 139 | ) 140 | output = np.clip(output, 0, 1) 141 | losses.update(G_current) 142 | 143 | print("%d %d Loss=%.4f Time=%.3f"%(epoch, cnt, losses.avg, time.time()-st)) 144 | 145 | if epoch % save_freq == 0: 146 | if not os.path.isdir(result_dir + '%04d'%epoch): 147 | os.makedirs(result_dir + '%04d'%epoch) 148 | 149 | temp = np.concatenate((temp_origin_img[0, :, :, :], temp_noise_img[0, :, :, :], output[0, :, :, :]), axis=1) 150 | scipy.misc.toimage(temp*255, high=255, low=0, cmin=0, cmax=255).save(result_dir + '%04d/train_%d_%d.jpg'%(epoch, ind, nind)) 151 | 152 | saver.save(sess, checkpoint_dir + 'model.ckpt') 153 | 154 | if not os.path.isdir(checkpoint_dir + 'epoch-' + str(epoch)): 155 | os.mkdir(checkpoint_dir + 'epoch-' + str(epoch)) 156 | 157 | if os.path.isdir(checkpoint_dir + 'epoch-' + str(epoch - 1)): 158 | os.rmdir(checkpoint_dir + 'epoch-' + 
str(epoch - 1)) 159 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import os 4 | import math 5 | import scipy.io as sio 6 | import matplotlib.pyplot as plt 7 | from scipy.optimize import curve_fit 8 | 9 | class AverageMeter(object): 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | def ReadImg(filename): 26 | img = cv2.imread(filename) 27 | img = img[:,:,::-1] / 255.0 28 | img = np.array(img).astype('float32') 29 | 30 | return img 31 | 32 | #################################################### 33 | #################### noise model ################### 34 | #################################################### 35 | 36 | def func(x, a): 37 | return np.power(x, a) 38 | 39 | def CRF_curve_fit(I, B): 40 | popt, pcov = curve_fit(func, I, B) 41 | return popt 42 | 43 | def CRF_function_transfer(x, y): 44 | para = [] 45 | for crf in range(201): 46 | temp_x = np.array(x[crf, :]) 47 | temp_y = np.array(y[crf, :]) 48 | para.append(CRF_curve_fit(temp_x, temp_y)) 49 | return para 50 | 51 | def mosaic_bayer(rgb, pattern, noiselevel): 52 | 53 | w, h, c = rgb.shape 54 | if pattern == 1: 55 | num = [1, 2, 0, 1] 56 | elif pattern == 2: 57 | num = [1, 0, 2, 1] 58 | elif pattern == 3: 59 | num = [2, 1, 1, 0] 60 | elif pattern == 4: 61 | num = [0, 1, 1, 2] 62 | elif pattern == 5: 63 | return rgb 64 | 65 | mosaic = np.zeros((w, h, 3)) 66 | mask = np.zeros((w, h, 3)) 67 | B = np.zeros((w, h)) 68 | 69 | B[0:w:2, 0:h:2] = rgb[0:w:2, 0:h:2, num[0]] 70 | B[0:w:2, 1:h:2] = rgb[0:w:2, 1:h:2, num[1]] 71 | B[1:w:2, 0:h:2] = rgb[1:w:2, 0:h:2, num[2]] 72 | B[1:w:2, 1:h:2] = rgb[1:w:2, 1:h:2, 
num[3]] 73 | 74 | gauss = np.random.normal(0, noiselevel/255.,(w, h)) 75 | gauss = gauss.reshape(w, h) 76 | B = B + gauss 77 | 78 | return (B, mask, mosaic) 79 | 80 | def ICRF_Map(Img, I, B): 81 | w, h, c = Img.shape 82 | output_Img = Img.copy() 83 | prebin = I.shape[0] 84 | tiny_bin = 9.7656e-04 85 | min_tiny_bin = 0.0039 86 | for i in range(w): 87 | for j in range(h): 88 | for k in range(c): 89 | temp = output_Img[i, j, k] 90 | start_bin = 0 91 | if temp > min_tiny_bin: 92 | start_bin = math.floor(temp/tiny_bin - 1) - 1 93 | for b in range(start_bin, prebin): 94 | tempB = B[b] 95 | if tempB >= temp: 96 | index = b 97 | if index > 0: 98 | comp1 = tempB - temp 99 | comp2 = temp - B[index-1] 100 | if comp2 < comp1: 101 | index = index - 1 102 | output_Img[i, j, k] = I[index] 103 | break 104 | 105 | return output_Img 106 | 107 | def CRF_Map(Img, I, B): 108 | w, h, c = Img.shape 109 | output_Img = Img.copy() 110 | prebin = I.shape[0] 111 | tiny_bin = 9.7656e-04 112 | min_tiny_bin = 0.0039 113 | for i in range(w): 114 | for j in range(h): 115 | for k in range(c): 116 | temp = output_Img[i, j, k] 117 | 118 | if temp < 0: 119 | temp = 0 120 | Img[i, j, k] = 0 121 | elif temp > 1: 122 | temp = 1 123 | Img[i, j, k] = 1 124 | start_bin = 0 125 | if temp > min_tiny_bin: 126 | start_bin = math.floor(temp/tiny_bin - 1) - 1 127 | 128 | for b in range(start_bin, prebin): 129 | tempB = I[b] 130 | if tempB >= temp: 131 | index = b 132 | if index > 0: 133 | comp1 = tempB - temp 134 | comp2 = temp - B[index-1] 135 | if comp2 < comp1: 136 | index = index - 1 137 | output_Img[i, j, k] = B[index] 138 | break 139 | return output_Img 140 | 141 | def CRF_Map_opt(Img, popt): 142 | w, h, c = Img.shape 143 | output_Img = Img.copy() 144 | 145 | output_Img = func(output_Img, *popt) 146 | return output_Img 147 | 148 | def Demosaic(B_b, pattern): 149 | 150 | B_b = B_b * 255 151 | B_b = B_b.astype(np.uint16) 152 | 153 | if pattern == 1: 154 | lin_rgb = cv2.demosaicing(B_b, cv2.COLOR_BayerGB2BGR) 
155 | elif pattern == 2: 156 | lin_rgb = cv2.demosaicing(B_b, cv2.COLOR_BayerGR2BGR) 157 | elif pattern == 3: 158 | lin_rgb = cv2.demosaicing(B_b, cv2.COLOR_BayerBG2BGR) 159 | elif pattern == 4: 160 | lin_rgb = cv2.demosaicing(B_b, cv2.COLOR_BayerRG2BGR) 161 | elif pattern == 5: 162 | lin_rgb = B_b 163 | 164 | lin_rgb = lin_rgb[:,:,::-1] / 255. 165 | return lin_rgb 166 | 167 | def AddNoiseMosai(x, CRF_para, iCRF_para, I, B, Iinv, Binv, sigma_s, sigma_c, crf_index, pattern, opt = 1): 168 | w, h, c = x.shape 169 | temp_x = CRF_Map_opt(x, iCRF_para[crf_index] ) 170 | 171 | sigma_s = np.reshape(sigma_s, (1, 1, c)) 172 | noise_s_map = np.multiply(sigma_s, temp_x) 173 | noise_s = np.random.randn(w, h, c) * noise_s_map 174 | temp_x_n = temp_x + noise_s 175 | 176 | noise_c = np.zeros((w, h, c)) 177 | for chn in range(3): 178 | noise_c [:, :, chn] = np.random.normal(0, sigma_c[chn], (w, h)) 179 | 180 | temp_x_n = temp_x_n + noise_c 181 | temp_x_n = np.clip(temp_x_n, 0.0, 1.0) 182 | temp_x_n = CRF_Map_opt(temp_x_n, CRF_para[crf_index]) 183 | 184 | if opt == 1: 185 | temp_x = CRF_Map_opt(temp_x, CRF_para[crf_index]) 186 | 187 | B_b_n = mosaic_bayer(temp_x_n[:,:,::-1], pattern, 0)[0] 188 | lin_rgb_n = Demosaic(B_b_n, pattern) 189 | result = lin_rgb_n 190 | if opt == 1: 191 | B_b = mosaic_bayer(temp_x[:,:,::-1], pattern, 0)[0] 192 | lin_rgb = Demosaic(B_b, pattern) 193 | diff = lin_rgb_n - lin_rgb 194 | result = x + diff 195 | 196 | return result 197 | 198 | def AddRealNoise(image, CRF_para, iCRF_para, I_gl, B_gl, I_inv_gl, B_inv_gl): 199 | sigma_s = np.random.uniform(0.0, 0.16, (3,)) 200 | sigma_c = np.random.uniform(0.0, 0.06, (3,)) 201 | CRF_index = np.random.choice(201) 202 | pattern = np.random.choice(4) + 1 203 | noise_img = AddNoiseMosai(image, CRF_para, iCRF_para, I_gl, B_gl, I_inv_gl, B_inv_gl, sigma_s, sigma_c, CRF_index, pattern, 0) 204 | noise_level = sigma_s * np.power(image, 0.5) + sigma_c 205 | 206 | return noise_img, noise_level 207 | 208 | 
--------------------------------------------------------------------------------