├── README.md
├── dataset.py
├── dnd_denoise.py
├── estimator.py
├── network.py
├── process.py
├── requirements.txt
├── train.py
└── unprocess.py

/README.md:
--------------------------------------------------------------------------------
# Unprocessing Images for Learned Raw Denoising

Reference code for the paper [Unprocessing Images for Learned Raw Denoising](http://timothybrooks.com/tech/unprocessing).
Tim Brooks, Ben Mildenhall, Tianfan Xue, Jiawen Chen, Dillon Sharlet, Jonathan T. Barron
CVPR 2019

If you use this code, please cite our paper:

```
@inproceedings{brooks2019unprocessing,
  title={Unprocessing Images for Learned Raw Denoising},
  author={Brooks, Tim and Mildenhall, Ben and Xue, Tianfan and Chen, Jiawen and Sharlet, Dillon and Barron, Jonathan T},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2019},
}
```

The code is implemented in TensorFlow and the required packages are listed in `requirements.txt`.

## Evaluation on Darmstadt Noise Dataset

In our paper, we evaluate on the [Darmstadt Noise Dataset](https://noise.visinf.tu-darmstadt.de/). Here are our [Darmstadt results](http://timothybrooks.com/tech/unprocessing/darmstadt-supp/). We highly recommend this dataset for measuring denoising performance on real photographs: it contains real noisy images, which, after denoising and submission to the Darmstadt website, are compared against real clean ground truth. Here are instructions to [download this dataset](https://noise.visinf.tu-darmstadt.de/downloads). You'll also need to download [our trained models](https://drive.google.com/file/d/1MTFr-uaIKv5aWe7nXlhTaHBestLUiDLZ/view?usp=sharing) and unzip them into `./models/`. Once downloaded, replace the provided `dnd_denoise.py` file with the version in this repository and follow the instructions below to run an unprocessing denoiser on this data.

```
/google-research$ python -m unprocessing.dnd_denoise \
  --model_ckpt=/path/to/models/unprocessing_srgb_loss/model.ckpt-3516383 \
  --data_dir=/path/to/darmstadt/data \
  --output_dir=/path/to/darmstadt/outputs
```

Then follow the instructions in the Darmstadt README file, including running `bundle_submissions.py` to prepare for submission to the Darmstadt Noise Dataset online benchmark.

## Training on MIRFlickr Images

In our paper, we train on source images from [MIRFlickr](https://press.liacs.nl/mirflickr/). We used the full MIRFLICKR-1M dataset, which includes 1 million source images, although the smaller MIRFLICKR-25000, which contains 25 thousand source images, can be used as well and is easier to download and store. Both versions are freely available; here are instructions to [download this dataset](http://press.liacs.nl/mirflickr/mirdownload.html). Once downloaded, split the images into `mirflickr/train` and `mirflickr/test` subdirectories. Then run the command below to train on MIRFlickr. Note that you may wish to downsample the source images prior to training to help remove JPEG compression artifacts, which we found slightly helpful in our paper.

```
/google-research$ python -m unprocessing.train \
  --model_dir=/path/to/models/unprocessing_mirflickr \
  --train_pattern=/path/to/mirflickr/train \
  --test_pattern=/path/to/mirflickr/test
```

## Training for Different Cameras

Our models are trained to work best on the Darmstadt Noise Dataset, which contains 4 cameras of various types and low to moderate amounts of noise. They generalize well to other images, such as those from the HDR+ Burst Photography Dataset. However, if you would like to train a denoiser that works best on images from different cameras, you may wish to modify `random_ccm()`, `random_gains()` and `random_noise_levels()` in `unprocess.py` to best match the distribution of image metadata from your cameras, as in the sketch below. See our paper for details of how we modeled such metadata for Darmstadt. If your cameras have a Bayer pattern outside of those supported, you will also need to modify `mosaic()` and `demosaic()` to match.
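
For example, here is a minimal sketch of what a camera-specific replacement for `random_gains()` might look like. The gain ranges below are hypothetical placeholders rather than values from our paper; measure the actual distribution from your own camera's metadata before training:

```
import tensorflow as tf


def random_gains_for_my_camera():
  """Generates random gains matching a hypothetical camera's metadata."""
  # RGB gain represents brightening, as in the original random_gains().
  rgb_gain = 1.0 / tf.random_normal((), mean=0.8, stddev=0.1)

  # Hypothetical white balance gain ranges; replace these bounds with
  # statistics measured from your camera's metadata.
  red_gain = tf.random_uniform((), 1.5, 2.8)
  blue_gain = tf.random_uniform((), 1.2, 2.1)
  return rgb_gain, red_gain, blue_gain
```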

## Evaluation on Different Real Data

Running evaluation on real raw photographs outside of the Darmstadt and HDR+ datasets requires loading and running the trained model similarly to `dnd_denoise.py`, packing each input raw image into a 4-channel image of Bayer planes, and estimating the variance from the shot and read noise levels. Note that these models are only designed to work with *raw* images (not processed sRGB images).

Shot and read noise levels are sometimes included in image metadata, and may go by different names, so we recommend referring to the specification for your camera's metadata. The shot noise level measures how much variance is proportional to the input signal, and the read noise level measures how much variance is independent of the image. We calculate an approximation of the variance using the input noisy image as `variance = shot_noise * noisy_img + read_noise`, and pass it as an additional input both during training and evaluation. Both steps are illustrated in the sketch at the end of this section.

If shot and read noise levels are not provided in your camera's metadata, it is possible to measure these noise levels empirically by calibrating your sensor or by inferring them from a single image. Here is a great overview of [shot and read noise](http://people.csail.mit.edu/hasinoff/pubs/hasinoff-photon-2012-preprint.pdf), and here is an example [noise calibration script](https://android.googlesource.com/platform/cts/+/806b430/apps/CameraITS/tests/dng_noise_model/dng_noise_model.py).
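
As a concrete illustration, here is a minimal NumPy sketch of the packing and variance steps, assuming `raw` is an RGGB Bayer image already normalized to [0, 1]; the noise levels below are placeholders that should instead come from your camera's metadata:

```
import numpy as np


def pack_rggb_planes(raw):
  """Packs an H x W RGGB Bayer image into an (H/2) x (W/2) x 4 image."""
  return np.stack([raw[0::2, 0::2],   # Red.
                   raw[0::2, 1::2],   # Green, on red rows.
                   raw[1::2, 0::2],   # Green, on blue rows.
                   raw[1::2, 1::2]],  # Blue.
                  axis=-1)


# Placeholder noise levels and a random stand-in image, for illustration only.
shot_noise, read_noise = 0.005, 0.0002
raw = np.random.uniform(size=(256, 256)).astype(np.float32)

noisy_img = pack_rggb_planes(raw)
# Variance is approximated from the noisy image itself, since the clean image
# is unavailable at evaluation time.
variance = shot_noise * noisy_img + read_noise
```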
89 | """ 90 | if height % 16 != 0 or width % 16 != 0: 91 | raise ValueError('`height` and `width` must be multiples of 16.') 92 | 93 | def dataset_fn_(): 94 | """Creates a Dataset for unprocessing training.""" 95 | autotune = tf.data.experimental.AUTOTUNE 96 | 97 | filenames = tf.data.Dataset.list_files(dir_pattern, shuffle=True).repeat() 98 | images = filenames.map(read_jpg, num_parallel_calls=autotune) 99 | images = images.filter(lambda x: is_large_enough(x, height, width)) 100 | images = images.map( 101 | lambda x: augment(x, height, width), num_parallel_calls=autotune) 102 | examples = images.map(create_example, num_parallel_calls=autotune) 103 | return examples.batch(batch_size).prefetch(autotune) 104 | 105 | return dataset_fn_ 106 | -------------------------------------------------------------------------------- /dnd_denoise.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Unprocessing evaluation on the Darmstadt Noise Dataset. 17 | 18 | Unprocessing Images for Learned Raw Denoising 19 | http://timothybrooks.com/tech/unprocessing 20 | 21 | This file denoises images from the Darmstadt Noise Dataset using the 22 | unprocessing neural networks. The full Darmstadt code and data should be 23 | downloaded from https://noise.visinf.tu-darmstadt.de/downloads and this file 24 | should replace the dnd_denoise.py file that is provided. 25 | 26 | This file is modified from the original version by Tobias Plotz, TU Darmstadt 27 | (tobias.ploetz@visinf.tu-darmstadt.de), and is part of the implementation as 28 | described in the CVPR 2017 paper: Benchmarking Denoising Algorithms with Real 29 | Photographs, Tobias Plotz and Stefan Roth. Modified by Tim Brooks of Google in 30 | 2019. The original license is below. 31 | 32 | Copyright (c) 2017, Technische Universitat Darmstadt 33 | All rights reserved. 34 | 35 | Redistribution and use in source and binary forms, with or without modification, 36 | are permitted provided that the following conditions are met: 37 | 38 | 1. Redistributions of source code must retain the above copyright notice, this 39 | list of conditions and the following disclaimer. 40 | 41 | 2. Redistributions in binary form must reproduce the above copyright notice, 42 | this list of conditions and the following disclaimer in the documentation and/or 43 | other materials provided with the distribution. 44 | 45 | 3. Any redistribution, use, or modification is done solely for non-commercial 46 | purposes. Examples of non-commercial uses are teaching, academic research, 47 | public demonstrations and personal experimentation. 48 | 49 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 50 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 51 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 52 | DISCLAIMED. 

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Creates a Dataset of unprocessed images for denoising.

Unprocessing Images for Learned Raw Denoising
http://timothybrooks.com/tech/unprocessing
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from unprocessing import unprocess


def read_jpg(filename):
  """Reads an 8-bit JPG file from disk and normalizes to [0, 1]."""
  image_file = tf.read_file(filename)
  image = tf.image.decode_jpeg(image_file, channels=3)
  white_level = 255.0
  return tf.cast(image, tf.float32) / white_level


def is_large_enough(image, height, width):
  """Checks if `image` is at least as large as `height` by `width`."""
  image.shape.assert_has_rank(3)
  shape = tf.shape(image)
  image_height = shape[0]
  image_width = shape[1]
  return tf.logical_and(
      tf.greater_equal(image_height, height),
      tf.greater_equal(image_width, width))


def augment(image, height, width):
  """Randomly flips and crops `image` to `height` by `width`."""
  size = [height, width, tf.shape(image)[-1]]
  image = tf.random_crop(image, size)
  image = tf.image.random_flip_left_right(image)
  image = tf.image.random_flip_up_down(image)
  return image


def create_example(image):
  """Creates a training example of inputs and labels from `image`."""
  image.shape.assert_is_compatible_with([None, None, 3])
  image, metadata = unprocess.unprocess(image)
  shot_noise, read_noise = unprocess.random_noise_levels()
  noisy_img = unprocess.add_noise(image, shot_noise, read_noise)
  # The approximation of variance is calculated using the noisy image (rather
  # than the clean image), since that is what will be available during
  # evaluation.
  variance = shot_noise * noisy_img + read_noise

  inputs = {
      'noisy_img': noisy_img,
      'variance': variance,
  }
  inputs.update(metadata)
  labels = image
  return inputs, labels


def create_dataset_fn(dir_pattern, height, width, batch_size):
  """Wrapper for creating a dataset function for unprocessing.

  Args:
    dir_pattern: A string representing source data directory glob.
    height: Height to crop images.
    width: Width to crop images.
    batch_size: Number of training examples per batch.

  Returns:
    Nullary function that returns a Dataset.
  """
  if height % 16 != 0 or width % 16 != 0:
    raise ValueError('`height` and `width` must be multiples of 16.')

  def dataset_fn_():
    """Creates a Dataset for unprocessing training."""
    autotune = tf.data.experimental.AUTOTUNE

    filenames = tf.data.Dataset.list_files(dir_pattern, shuffle=True).repeat()
    images = filenames.map(read_jpg, num_parallel_calls=autotune)
    images = images.filter(lambda x: is_large_enough(x, height, width))
    images = images.map(
        lambda x: augment(x, height, width), num_parallel_calls=autotune)
    examples = images.map(create_example, num_parallel_calls=autotune)
    return examples.batch(batch_size).prefetch(autotune)

  return dataset_fn_

--------------------------------------------------------------------------------
/dnd_denoise.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unprocessing evaluation on the Darmstadt Noise Dataset.

Unprocessing Images for Learned Raw Denoising
http://timothybrooks.com/tech/unprocessing

This file denoises images from the Darmstadt Noise Dataset using the
unprocessing neural networks. The full Darmstadt code and data should be
downloaded from https://noise.visinf.tu-darmstadt.de/downloads and this file
should replace the dnd_denoise.py file that is provided.

This file is modified from the original version by Tobias Plotz, TU Darmstadt
(tobias.ploetz@visinf.tu-darmstadt.de), and is part of the implementation as
described in the CVPR 2017 paper: Benchmarking Denoising Algorithms with Real
Photographs, Tobias Plotz and Stefan Roth. Modified by Tim Brooks of Google in
2019. The original license is below.

Copyright (c) 2017, Technische Universitat Darmstadt
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

3. Any redistribution, use, or modification is done solely for non-commercial
purposes. Examples of non-commercial uses are teaching, academic research,
public demonstrations and personal experimentation.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from absl import app
from absl import flags
import h5py
import numpy as np
import scipy.io as sio
import tensorflow as tf

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'model_ckpt',
    None,
    'Path to checkpoint of a trained unprocessing model. For example: '
    '/path/to/models/unprocessing_srgb_loss/model.ckpt-3516383')

flags.DEFINE_string(
    'data_dir',
    None,
    'Location from which to load input noisy images. This should correspond '
    'with the \'data\' directory downloaded as part of the Darmstadt Noise '
    'Dataset.')

flags.DEFINE_string(
    'output_dir',
    None,
    'Location at which to save output denoised images.')


def denoise_raw(denoiser, data_dir, output_dir):
  """Denoises all bounding boxes in all raw images from the DND dataset.

  The resulting denoised images are saved to disk.

  Args:
    denoiser: Function handle called as:
      denoised_img = denoiser(noisy_img, shot_noise, read_noise).
    data_dir: Folder where the DND dataset resides.
    output_dir: Folder where denoised output should be written to.

  Returns:
    None
  """
  # Loads image information and bounding boxes.
  info = h5py.File(os.path.join(data_dir, 'info.mat'), 'r')['info']
  bb = info['boundingboxes']

  # Denoises each image.
  for i in range(50):
    # Loads the noisy image.
    filename = os.path.join(data_dir, 'images_raw', '%04d.mat' % (i + 1))
    img = h5py.File(filename, 'r')
    noisy = np.float32(np.array(img['Inoisy']).T)

    # Loads the raw Bayer color pattern.
    bayer_pattern = np.asarray(info[info['camera'][0][i]]['pattern']).tolist()

    # Denoises each bounding box in this image.
    boxes = np.array(info[bb[0][i]]).T
    for k in range(20):
      # Crops the image to this bounding box.
      idx = [
          int(boxes[k, 0] - 1),
          int(boxes[k, 2]),
          int(boxes[k, 1] - 1),
          int(boxes[k, 3])
      ]
      noisy_crop = noisy[idx[0]:idx[1], idx[2]:idx[3]].copy()

      # Flips the raw image to ensure an RGGB Bayer color pattern.
      if bayer_pattern == [[1, 2], [2, 3]]:
        pass
      elif bayer_pattern == [[2, 1], [3, 2]]:
        noisy_crop = np.fliplr(noisy_crop)
      elif bayer_pattern == [[2, 3], [1, 2]]:
        noisy_crop = np.flipud(noisy_crop)
      else:
        print('Warning: assuming unknown Bayer pattern is RGGB.')

      # Loads shot and read noise factors.
      nlf_h5 = info[info['nlf'][0][i]]
      shot_noise = nlf_h5['a'][0][0]
      read_noise = nlf_h5['b'][0][0]

      # Extracts each Bayer image plane.
      denoised_crop = noisy_crop.copy()
      height, width = noisy_crop.shape
      channels = []
      for yy in range(2):
        for xx in range(2):
          noisy_crop_c = noisy_crop[yy:height:2, xx:width:2].copy()
          channels.append(noisy_crop_c)
      channels = np.stack(channels, axis=-1)

      # Denoises this crop of the image.
      output = denoiser(channels, shot_noise, read_noise)

      # Copies denoised results to the output denoised array.
      for yy in range(2):
        for xx in range(2):
          denoised_crop[yy:height:2, xx:width:2] = output[:, :, 2 * yy + xx]

      # Flips the denoised image back to the original Bayer color pattern.
      if bayer_pattern == [[1, 2], [2, 3]]:
        pass
      elif bayer_pattern == [[2, 1], [3, 2]]:
        denoised_crop = np.fliplr(denoised_crop)
      elif bayer_pattern == [[2, 3], [1, 2]]:
        denoised_crop = np.flipud(denoised_crop)

      # Saves the denoised image crop.
      denoised_crop = np.clip(np.float32(denoised_crop), 0.0, 1.0)
      save_file = os.path.join(output_dir, '%04d_%02d.mat' % (i + 1, k + 1))
      sio.savemat(save_file, {'denoised_crop': denoised_crop})


def main(_):
  with tf.Graph().as_default() as graph:
    with tf.Session(graph=graph) as sess:
      saver = tf.train.import_meta_graph(FLAGS.model_ckpt + '.meta')
      saver.restore(sess, FLAGS.model_ckpt)

      def denoiser(noisy_img, shot_noise, read_noise):
        """Unprocessing denoiser."""
        denoised_img_tensor = graph.get_tensor_by_name('denoised_img:0')
        noisy_img_tensor = graph.get_tensor_by_name('noisy_img:0')
        shot_noise_tensor = graph.get_tensor_by_name('stddev/shot_noise:0')
        read_noise_tensor = graph.get_tensor_by_name('stddev/read_noise:0')
        feed_dict = {
            noisy_img_tensor: noisy_img[np.newaxis, :, :, :],
            shot_noise_tensor: np.asarray([shot_noise]),
            read_noise_tensor: np.asarray([read_noise])
        }
        return sess.run(denoised_img_tensor, feed_dict=feed_dict)[0, :, :, :]

      denoise_raw(denoiser, FLAGS.data_dir, FLAGS.output_dir)


if __name__ == '__main__':
  app.run(main)

--------------------------------------------------------------------------------
/estimator.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unprocessing model function and train and eval specs for Estimator.

Unprocessing Images for Learned Raw Denoising
http://timothybrooks.com/tech/unprocessing
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from unprocessing import process


def psnr(labels, predictions):
  """Computes the average peak signal-to-noise ratio of `predictions`.

  Here PSNR is defined with respect to the maximum value of 1. All image
  tensors must be within the range [0, 1].

  Args:
    labels: Tensor of shape [B, H, W, N].
    predictions: Tensor of shape [B, H, W, N].

  Returns:
    Tuple of (psnr, update_op) as returned by tf.metrics.mean.
  """
  predictions.shape.assert_is_compatible_with(labels.shape)
  with tf.control_dependencies([tf.assert_greater_equal(labels, 0.0),
                                tf.assert_less_equal(labels, 1.0)]):
    psnrs = tf.image.psnr(labels, predictions, max_val=1.0)
    psnrs = tf.boolean_mask(psnrs, tf.logical_not(tf.is_inf(psnrs)))
    return tf.metrics.mean(psnrs, name='psnr')


def create_model_fn(inference_fn, hparams):
  """Creates a model function for Estimator.

  Args:
    inference_fn: Model inference function with specification:
      Args -
        noisy_img - Tensor of shape [B, H, W, 4].
        variance - Tensor of shape [B, H, W, 4].
      Returns -
        Tensor of shape [B, H, W, 4].
    hparams: Hyperparameters for model as a tf.contrib.training.HParams object.

  Returns:
    `_model_fn`.
  """
  def _model_fn(features, labels, mode, params):
    """Constructs the model function.

    Args:
      features: Dictionary of input features.
      labels: Tensor of labels if mode is `TRAIN` or `EVAL`, otherwise `None`.
      mode: ModeKey object (`TRAIN` or `EVAL`).
      params: Parameter dictionary passed from the Estimator object.

    Returns:
      An EstimatorSpec object that encapsulates the model and its serving
      configurations.
    """
    del params  # Unused.

    def process_images(images):
      """Closure for processing images with fixed metadata."""
      return process.process(images, features['red_gain'],
                             features['blue_gain'], features['cam2rgb'])

    denoised_img = inference_fn(features['noisy_img'], features['variance'])

    noisy_img = process_images(features['noisy_img'])
    denoised_img = process_images(denoised_img)
    truth_img = process_images(labels)

    if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
      loss = tf.losses.absolute_difference(truth_img, denoised_img)
    else:
      loss = None

    if mode == tf.estimator.ModeKeys.TRAIN:
      optimizer = tf.train.AdamOptimizer(learning_rate=hparams.learning_rate)
      train_op = tf.contrib.layers.optimize_loss(
          loss=loss,
          global_step=tf.train.get_global_step(),
          learning_rate=None,
          optimizer=optimizer,
          name='')  # Prevents scope prefix.
    else:
      train_op = None

    if mode == tf.estimator.ModeKeys.EVAL:
      eval_metric_ops = {'PSNR': psnr(truth_img, denoised_img)}

      def summary(images, name):
        """As a hack, saves image summaries by adding to `eval_metric_ops`."""
        images = tf.saturate_cast(images * 255 + 0.5, tf.uint8)
        eval_metric_ops[name] = (tf.summary.image(name, images, max_outputs=2),
                                 tf.no_op())

      summary(noisy_img, 'Noisy')
      summary(denoised_img, 'Denoised')
      summary(truth_img, 'Truth')

      diffs = (denoised_img - truth_img + 1.0) / 2.0
      summary(diffs, 'Diffs')

    else:
      eval_metric_ops = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)

  return _model_fn


def create_train_and_eval_specs(train_dataset_fn,
                                eval_dataset_fn,
                                eval_steps=250):
  """Creates a TrainSpec and EvalSpec.

  Args:
    train_dataset_fn: Function returning a Dataset of training data.
    eval_dataset_fn: Function returning a Dataset of evaluation data.
    eval_steps: Number of steps for evaluating model.

  Returns:
    Tuple of (TrainSpec, EvalSpec).
  """
  train_spec = tf.estimator.TrainSpec(input_fn=train_dataset_fn,
                                      max_steps=None)

  eval_spec = tf.estimator.EvalSpec(
      input_fn=eval_dataset_fn, steps=eval_steps, name='')

  return train_spec, eval_spec

--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unprocessing neural network architecture.

Unprocessing Images for Learned Raw Denoising
http://timothybrooks.com/tech/unprocessing
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def conv(features, num_channels, activation=tf.nn.leaky_relu):
  """Applies a 3x3 conv layer."""
  return tf.layers.conv2d(features, num_channels, 3, padding='same',
                          activation=activation)


def conv_block(features, num_channels):
  """Applies 3x conv layers."""
  with tf.name_scope(None, 'conv_block'):
    features = conv(features, num_channels)
    features = conv(features, num_channels)
    features = conv(features, num_channels)
  return features


def downsample_2x(features):
  """Applies a 2x spatial downsample via max pooling."""
  with tf.name_scope(None, 'downsample_2x'):
    return tf.layers.max_pooling2d(features, 2, 2, padding='same')


def upsample_2x(features):
  """Applies a 2x spatial upsample via bilinear interpolation."""
  with tf.name_scope(None, 'upsample_2x'):
    shape = tf.shape(features)
    shape = [shape[1] * 2, shape[2] * 2]
    features = tf.image.resize_bilinear(features, shape)
  return features


def inference(noisy_img, variance):
  """Residual U-Net with skip connections.

  Expects four input channels for the Bayer color filter planes (e.g. RGGB).
  This is the format of real raw images before they are processed, and an
  effective stage at which to denoise images in an image processing pipeline.

  Args:
    noisy_img: Tensor of shape [B, H, W, 4].
    variance: Tensor of shape [B, H, W, 4].

  Returns:
    Denoised image in Tensor of shape [B, H, W, 4].
  """
  noisy_img = tf.identity(noisy_img, 'noisy_img')
  noisy_img.set_shape([None, None, None, 4])
  variance = tf.identity(variance, 'variance')
  variance.shape.assert_is_compatible_with(noisy_img.shape)
  variance.set_shape([None, None, None, 4])

  features = tf.concat([noisy_img, variance], axis=-1)
  skip_connections = []

  with tf.name_scope(None, 'encoder'):
    for num_channels in (32, 64, 128, 256):
      features = conv_block(features, num_channels)
      skip_connections.append(features)
      features = downsample_2x(features)
    features = conv_block(features, 512)

  with tf.name_scope(None, 'decoder'):
    for num_channels in (256, 128, 64, 32):
      features = upsample_2x(features)
      with tf.name_scope(None, 'skip_connection'):
        features = tf.concat([features, skip_connections.pop()], axis=-1)
      features = conv_block(features, num_channels)

  residual = conv(features, 4, None)
  return tf.identity(noisy_img + residual, 'denoised_img')

--------------------------------------------------------------------------------
/process.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forward processing of raw data to sRGB images.

Unprocessing Images for Learned Raw Denoising
http://timothybrooks.com/tech/unprocessing
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def apply_gains(bayer_images, red_gains, blue_gains):
  """Applies white balance gains to a batch of Bayer images."""
  bayer_images.shape.assert_is_compatible_with((None, None, None, 4))
  green_gains = tf.ones_like(red_gains)
  gains = tf.stack([red_gains, green_gains, green_gains, blue_gains], axis=-1)
  gains = gains[:, tf.newaxis, tf.newaxis, :]
  return bayer_images * gains


def demosaic(bayer_images):
  """Bilinearly demosaics a batch of RGGB Bayer images."""
  bayer_images.shape.assert_is_compatible_with((None, None, None, 4))

  # This implementation exploits how edges are aligned when upsampling with
  # tf.image.resize_bilinear().

  with tf.name_scope(None, 'demosaic'):
    shape = tf.shape(bayer_images)
    shape = [shape[1] * 2, shape[2] * 2]

    red = bayer_images[..., 0:1]
    red = tf.image.resize_bilinear(red, shape)

    green_red = bayer_images[..., 1:2]
    green_red = tf.image.flip_left_right(green_red)
    green_red = tf.image.resize_bilinear(green_red, shape)
    green_red = tf.image.flip_left_right(green_red)
    green_red = tf.space_to_depth(green_red, 2)

    green_blue = bayer_images[..., 2:3]
    green_blue = tf.image.flip_up_down(green_blue)
    green_blue = tf.image.resize_bilinear(green_blue, shape)
    green_blue = tf.image.flip_up_down(green_blue)
    green_blue = tf.space_to_depth(green_blue, 2)

    green_at_red = (green_red[..., 0] + green_blue[..., 0]) / 2
    green_at_green_red = green_red[..., 1]
    green_at_green_blue = green_blue[..., 2]
    green_at_blue = (green_red[..., 3] + green_blue[..., 3]) / 2

    green_planes = [
        green_at_red, green_at_green_red, green_at_green_blue, green_at_blue
    ]
    green = tf.depth_to_space(tf.stack(green_planes, axis=-1), 2)

    blue = bayer_images[..., 3:4]
    blue = tf.image.flip_up_down(tf.image.flip_left_right(blue))
    blue = tf.image.resize_bilinear(blue, shape)
    blue = tf.image.flip_up_down(tf.image.flip_left_right(blue))

    rgb_images = tf.concat([red, green, blue], axis=-1)
    return rgb_images


def apply_ccms(images, ccms):
  """Applies color correction matrices."""
  images.shape.assert_has_rank(4)
  images = images[:, :, :, tf.newaxis, :]
  ccms = ccms[:, tf.newaxis, tf.newaxis, :, :]
  return tf.reduce_sum(images * ccms, axis=-1)


def gamma_compression(images, gamma=2.2):
  """Converts from linear to gamma space."""
  # Clamps to prevent numerical instability of gradients near zero.
  return tf.maximum(images, 1e-8) ** (1.0 / gamma)


def process(bayer_images, red_gains, blue_gains, cam2rgbs):
  """Processes a batch of Bayer RGGB images into sRGB images."""
  bayer_images.shape.assert_is_compatible_with((None, None, None, 4))
  with tf.name_scope(None, 'process'):
    # White balance.
    bayer_images = apply_gains(bayer_images, red_gains, blue_gains)
    # Demosaic.
    bayer_images = tf.clip_by_value(bayer_images, 0.0, 1.0)
    images = demosaic(bayer_images)
    # Color correction.
    images = apply_ccms(images, cam2rgbs)
    # Gamma compression.
    images = tf.clip_by_value(images, 0.0, 1.0)
    images = gamma_compression(images)
  return images

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
# tensorflow>=1.13.0
absl-py>=0.1.9
h5py
numpy
scipy

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trains and evaluates unprocessing neural network.

Unprocessing Images for Learned Raw Denoising
http://timothybrooks.com/tech/unprocessing
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags
import tensorflow as tf

from unprocessing import dataset
from unprocessing import estimator
from unprocessing import network

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'model_dir',
    None,
    'Location at which to save model logs and checkpoints.')

flags.DEFINE_string(
    'train_pattern',
    None,
    'Pattern for directory containing source JPG images for training.')

flags.DEFINE_string(
    'test_pattern',
    None,
    'Pattern for directory containing source JPG images for testing.')

flags.DEFINE_integer(
    'image_size',
    256,
    'Width and height to crop training and testing frames. '
    'Must be a multiple of 16.',
    lower_bound=16)

flags.DEFINE_integer(
    'batch_size',
    16,
    'Training batch size.',
    lower_bound=1)

flags.DEFINE_float(
    'learning_rate',
    2e-5,
    'Learning rate for Adam optimization.',
    lower_bound=0.0)

flags.register_validator(
    'image_size',
    lambda image_size: image_size % 16 == 0,
    message='\'image_size\' must be a multiple of 16.')

flags.mark_flag_as_required('model_dir')
flags.mark_flag_as_required('train_pattern')
flags.mark_flag_as_required('test_pattern')


def main(_):
  inference_fn = network.inference
  hparams = tf.contrib.training.HParams(learning_rate=FLAGS.learning_rate)
  model_fn = estimator.create_model_fn(inference_fn, hparams)
  config = tf.estimator.RunConfig(FLAGS.model_dir)
  tf_estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)

  train_dataset_fn = dataset.create_dataset_fn(
      FLAGS.train_pattern,
      height=FLAGS.image_size,
      width=FLAGS.image_size,
      batch_size=FLAGS.batch_size)

  eval_dataset_fn = dataset.create_dataset_fn(
      FLAGS.test_pattern,
      height=FLAGS.image_size,
      width=FLAGS.image_size,
      batch_size=FLAGS.batch_size)

  train_spec, eval_spec = estimator.create_train_and_eval_specs(
      train_dataset_fn, eval_dataset_fn)

  tf.logging.set_verbosity(tf.logging.INFO)
  tf.estimator.train_and_evaluate(tf_estimator, train_spec, eval_spec)


if __name__ == '__main__':
  tf.app.run(main)

--------------------------------------------------------------------------------
/unprocess.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unprocesses sRGB images into realistic raw data.

Unprocessing Images for Learned Raw Denoising
http://timothybrooks.com/tech/unprocessing
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf


def random_ccm():
  """Generates random RGB -> Camera color correction matrices."""
  # Takes a random convex combination of XYZ -> Camera CCMs.
  xyz2cams = [[[1.0234, -0.2969, -0.2266],
               [-0.5625, 1.6328, -0.0469],
               [-0.0703, 0.2188, 0.6406]],
              [[0.4913, -0.0541, -0.0202],
               [-0.613, 1.3513, 0.2906],
               [-0.1564, 0.2151, 0.7183]],
              [[0.838, -0.263, -0.0639],
               [-0.2887, 1.0725, 0.2496],
               [-0.0627, 0.1427, 0.5438]],
              [[0.6596, -0.2079, -0.0562],
               [-0.4782, 1.3016, 0.1933],
               [-0.097, 0.1581, 0.5181]]]
  num_ccms = len(xyz2cams)
  xyz2cams = tf.constant(xyz2cams)
  weights = tf.random_uniform((num_ccms, 1, 1), 1e-8, 1e8)
  weights_sum = tf.reduce_sum(weights, axis=0)
  xyz2cam = tf.reduce_sum(xyz2cams * weights, axis=0) / weights_sum

  # Multiplies with RGB -> XYZ to get RGB -> Camera CCM.
  rgb2xyz = tf.to_float([[0.4124564, 0.3575761, 0.1804375],
                         [0.2126729, 0.7151522, 0.0721750],
                         [0.0193339, 0.1191920, 0.9503041]])
  rgb2cam = tf.matmul(xyz2cam, rgb2xyz)

  # Normalizes each row.
  rgb2cam = rgb2cam / tf.reduce_sum(rgb2cam, axis=-1, keepdims=True)
  return rgb2cam


def random_gains():
  """Generates random gains for brightening and white balance."""
  # RGB gain represents brightening.
  rgb_gain = 1.0 / tf.random_normal((), mean=0.8, stddev=0.1)

  # Red and blue gains represent white balance.
  red_gain = tf.random_uniform((), 1.9, 2.4)
  blue_gain = tf.random_uniform((), 1.5, 1.9)
  return rgb_gain, red_gain, blue_gain


def inverse_smoothstep(image):
  """Approximately inverts a global tone mapping curve."""
  image = tf.clip_by_value(image, 0.0, 1.0)
  return 0.5 - tf.sin(tf.asin(1.0 - 2.0 * image) / 3.0)


def gamma_expansion(image):
  """Converts from gamma to linear space."""
  # Clamps to prevent numerical instability of gradients near zero.
  return tf.maximum(image, 1e-8) ** 2.2


def apply_ccm(image, ccm):
  """Applies a color correction matrix."""
  shape = tf.shape(image)
  image = tf.reshape(image, [-1, 3])
  image = tf.tensordot(image, ccm, axes=[[-1], [-1]])
  return tf.reshape(image, shape)


def safe_invert_gains(image, rgb_gain, red_gain, blue_gain):
  """Inverts gains while safely handling saturated pixels."""
  gains = tf.stack([1.0 / red_gain, 1.0, 1.0 / blue_gain]) / rgb_gain
  gains = gains[tf.newaxis, tf.newaxis, :]

  # Prevents dimming of saturated pixels by smoothly masking gains near white.
  gray = tf.reduce_mean(image, axis=-1, keepdims=True)
  inflection = 0.9
  mask = (tf.maximum(gray - inflection, 0.0) / (1.0 - inflection)) ** 2.0
  safe_gains = tf.maximum(mask + (1.0 - mask) * gains, gains)
  return image * safe_gains


def mosaic(image):
  """Extracts RGGB Bayer planes from an RGB image."""
  image.shape.assert_is_compatible_with((None, None, 3))
  shape = tf.shape(image)
  red = image[0::2, 0::2, 0]
  green_red = image[0::2, 1::2, 1]
  green_blue = image[1::2, 0::2, 1]
  blue = image[1::2, 1::2, 2]
  image = tf.stack((red, green_red, green_blue, blue), axis=-1)
  image = tf.reshape(image, (shape[0] // 2, shape[1] // 2, 4))
  return image


def unprocess(image):
  """Unprocesses an image from sRGB to realistic raw data."""
  with tf.name_scope(None, 'unprocess'):
    image.shape.assert_is_compatible_with([None, None, 3])

    # Randomly creates image metadata.
    rgb2cam = random_ccm()
    cam2rgb = tf.matrix_inverse(rgb2cam)
    rgb_gain, red_gain, blue_gain = random_gains()

    # Approximately inverts global tone mapping.
    image = inverse_smoothstep(image)
    # Inverts gamma compression.
    image = gamma_expansion(image)
    # Inverts color correction.
    image = apply_ccm(image, rgb2cam)
    # Approximately inverts white balance and brightening.
    image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain)
    # Clips saturated pixels.
    image = tf.clip_by_value(image, 0.0, 1.0)
    # Applies a Bayer mosaic.
    image = mosaic(image)

    metadata = {
        'cam2rgb': cam2rgb,
        'rgb_gain': rgb_gain,
        'red_gain': red_gain,
        'blue_gain': blue_gain,
    }
    return image, metadata


def random_noise_levels():
  """Generates random noise levels from a log-log linear distribution."""
  log_min_shot_noise = tf.log(0.0001)
  log_max_shot_noise = tf.log(0.012)
  log_shot_noise = tf.random_uniform((), log_min_shot_noise,
                                     log_max_shot_noise)
  shot_noise = tf.exp(log_shot_noise)

  line = lambda x: 2.18 * x + 1.20
  log_read_noise = line(log_shot_noise) + tf.random_normal((), stddev=0.26)
  read_noise = tf.exp(log_read_noise)
  return shot_noise, read_noise


def add_noise(image, shot_noise=0.01, read_noise=0.0005):
  """Adds random shot (proportional to image) and read (independent) noise."""
  variance = image * shot_noise + read_noise
  noise = tf.random_normal(tf.shape(image), stddev=tf.sqrt(variance))
  return image + noise

--------------------------------------------------------------------------------