├── README.md
├── images
│   ├── ScaDecArch.jpg
│   ├── convergence.jpg
│   ├── expExamples.jpg
│   └── visualExamples.jpg
├── scadec
│   ├── __init__.py
│   ├── image_util.py
│   ├── layers.py
│   ├── nets.py
│   ├── train.py
│   ├── unet_bn.py
│   └── util.py
├── test.py
└── train.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# [Efficient and accurate inversion of multiple scattering with deep learning](https://www.osapublishing.org/oe/abstract.cfm?uri=oe-26-11-14678&origin=search)

This is the Python implementation of the deep learning model [ScaDec](https://www.osapublishing.org/oe/abstract.cfm?uri=oe-26-11-14678&origin=search) for inverting multiple light scattering in a supervised manner. The [paper](https://www.osapublishing.org/oe/abstract.cfm?uri=oe-26-11-14678&origin=search) was originally published in [Optics Express](https://www.osapublishing.org/oe/home.cfm). The arXiv version of the paper is available [here](https://arxiv.org/abs/1803.06594).

## Abstract
Image reconstruction under multiple light scattering is crucial in a number of applications such as diffraction tomography. The reconstruction problem is often formulated as a nonconvex optimization, where a nonlinear measurement model is used to account for multiple scattering and regularization is used to enforce prior constraints on the object. In this paper, we propose a powerful alternative to this optimization-based view of image reconstruction by designing and training a deep convolutional neural network that can invert multiple scattered measurements to produce a high-quality image of the refractive index. Our results on both simulated and experimental datasets show that the proposed approach is substantially faster and achieves higher imaging quality compared to the state-of-the-art methods based on optimization.

## Experimental Results
The following two figures show the visual results of ScaDec on simulated and experimental datasets ([CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) & [Fresnel](http://iopscience.iop.org/article/10.1088/0266-5611/21/6/S09/meta)). ScaDec substantially outperforms the optimization-based methods.

![visualExamples](images/visualExamples.jpg "Visual illustration of reconstructed images of ScaDec")

![expExamples](images/expExamples.jpg "Visual Example of Fresnel2D dataset")

## Training & Testing
The scripts train.py and test.py are for training and testing the model ScaDec. If you want to train your own ScaDec, please check out these two files. We also share the [**pre-trained model and code here**](https://wustl.box.com/s/kjnpwgg9ktauebolqid41dozyso0q9kz). Please cite the paper if you find our work useful for your research.

The MATLAB code does the following:

(1) generates measurements by solving the 2D Lippmann–Schwinger equation

(2) generates synthesized Gaussian piecewise-smooth ellipses

(3) generates synthesized circles for the Fresnel 2D dataset

(4) performs backpropagation

To train/test a ScaDec, simply type:

    python train.py
    python test.py
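If you prefer to drive the model from Python directly, the sketch below mirrors the steps in `train.py`. The `.mat` file names, array keys, and shapes are the ones used in this repository, so substitute your own data as needed; for brevity the sketch reuses the training provider for validation, whereas `train.py` builds a separate provider from the validation files:

```python
import scipy.io as spio
from scadec.unet_bn import Unet_bn
from scadec.train import Trainer_bn
from scadec import image_util

# inputs have 2 channels (real & imaginary parts), truths have 1
data = spio.loadmat('train_np/obhatGausWeak128_40.mat', squeeze_me=True)['obhatGausWeak128']
truths = spio.loadmat('train_np/obGausWeak128_40.mat', squeeze_me=True)['obGausWeak128']
data_provider = image_util.SimpleDataProvider(data.reshape((-1, 128, 128, 2)),
                                              truths.reshape((-1, 128, 128, 1)))

net = Unet_bn(img_channels=2, truth_channels=1, cost="mean_squared_error",
              layers=5, conv_times=2, features_root=64, filter_size=3, pool_size=2)
trainer = Trainer_bn(net, batch_size=24, optimizer="adam")
trainer.train(data_provider, 'models', data_provider, valid_size=24,
              training_iters=100, epochs=1000)
```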
## Environment Requirement
```
Tensorflow v1.4
PIL
Python 3.6
Scipy
```

## Citation
If you find the paper useful in your research, please cite the paper:

    @article{Sun:18,
    Author = {Yu Sun and Zhihao Xia and Ulugbek S. Kamilov},
    Doi = {10.1364/OE.26.014678},
    Journal = {Opt. Express},
    Keywords = {Image reconstruction techniques; Inverse problems; Tomographic image processing; Inverse scattering},
    Month = {May},
    Number = {11},
    Pages = {14678--14688},
    Publisher = {OSA},
    Title = {Efficient and accurate inversion of multiple scattering with deep learning},
    Url = {http://www.opticsexpress.org/abstract.cfm?URI=oe-26-11-14678},
    Volume = {26},
    Year = {2018},
    Bdsk-Url-1 = {http://www.opticsexpress.org/abstract.cfm?URI=oe-26-11-14678},
    Bdsk-Url-2 = {https://doi.org/10.1364/OE.26.014678}}

## Erratum
In the paper, the feature dimension of ScaDec is mistyped in Fig. 3. The corrected feature figure is shown here.

![ScaDecArch](images/ScaDecArch.jpg "Visual illustration of ScaDec")

Thanks to Jiaming for pointing out this typo. If you have any concerns about this paper, feel free to contact us: sun.yu@wustl.edu

--------------------------------------------------------------------------------
/images/ScaDecArch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunyumark/ScaDec-deep-learning-diffractive-tomography/ebd952f914ff7bb0109679b39cbf588b5e74c105/images/ScaDecArch.jpg

--------------------------------------------------------------------------------
/images/convergence.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunyumark/ScaDec-deep-learning-diffractive-tomography/ebd952f914ff7bb0109679b39cbf588b5e74c105/images/convergence.jpg

--------------------------------------------------------------------------------
/images/expExamples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunyumark/ScaDec-deep-learning-diffractive-tomography/ebd952f914ff7bb0109679b39cbf588b5e74c105/images/expExamples.jpg

--------------------------------------------------------------------------------
/images/visualExamples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunyumark/ScaDec-deep-learning-diffractive-tomography/ebd952f914ff7bb0109679b39cbf588b5e74c105/images/visualExamples.jpg

--------------------------------------------------------------------------------
/scadec/__init__.py:
--------------------------------------------------------------------------------
__author__ = 'Yu Sun'
__version__ = '0.1.1'
__credits__ = 'Wash U, Department of Computer Science & Engineering'

--------------------------------------------------------------------------------
/scadec/image_util.py:
--------------------------------------------------------------------------------
# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.
'''
Modified on Feb, 2018 based on the work of jakeret

author: yusun
'''
from __future__ import print_function, division, absolute_import, unicode_literals

#import cv2
import glob
import numpy as np
from PIL import Image


class BaseDataProvider(object):

    def __init__(self, a_min=None, a_max=None):
        self.a_min = a_min if a_min is not None else -np.inf
        self.a_max = a_max if a_max is not None else np.inf

    def __call__(self, n, fix=False):
        if type(n) == int and not fix:
            # X and Y are the images and truths
            train_data, truths = self._next_batch(n)
        elif type(n) == int and fix:
            train_data, truths = self._fix_batch(n)
        elif type(n) == str and n == 'full':
            train_data, truths = self._full_batch()
        else:
            raise ValueError("Invalid batch_size: %s" % n)

        return train_data, truths

    def _next_batch(self, n):
        pass

    def _full_batch(self):
        pass


class SimpleDataProvider(BaseDataProvider):

    def __init__(self, data, truths):
        super(SimpleDataProvider, self).__init__()
        self.data = np.float64(data)
        self.truths = np.float64(truths)
        self.img_channels = self.data[0].shape[2]
        self.truth_channels = self.truths[0].shape[2]
        self.file_count = data.shape[0]

    def _next_batch(self, n):
        idx = np.random.choice(self.file_count, n, replace=False)
        img = self.data[idx[0]]
        nx = img.shape[0]
        ny = img.shape[1]
        X = np.zeros((n, nx, ny, self.img_channels))
        Y = np.zeros((n, nx, ny, self.truth_channels))
        for i in range(n):
            X[i] = self._process_data(self.data[idx[i]])
            Y[i] = self._process_truths(self.truths[idx[i]])
        return X, Y

    def _fix_batch(self, n):
        # first n data
        img = self.data[0]
        nx = img.shape[0]
        ny = img.shape[1]
        X = np.zeros((n, nx, ny, self.img_channels))
        Y = np.zeros((n, nx, ny, self.truth_channels))
        for i in range(n):
            X[i] = self._process_data(self.data[i])
            Y[i] = self._process_truths(self.truths[i])
        return X, Y

    def _full_batch(self):
        return self.data, self.truths

    def _process_truths(self, truth):
        # normalization by channels
        truth = np.clip(np.fabs(truth), self.a_min, self.a_max)
        for channel in range(self.truth_channels):
            truth[:,:,channel] -= np.amin(truth[:,:,channel])
            truth[:,:,channel] /= np.amax(truth[:,:,channel])
        return truth

    def _process_data(self, data):
        # normalization by channels
        data = np.clip(np.fabs(data), self.a_min, self.a_max)
        for channel in range(self.img_channels):
            data[:,:,channel] -= np.amin(data[:,:,channel])
            data[:,:,channel] /= np.amax(data[:,:,channel])
        return data
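
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file).
# Shapes follow the convention used throughout this repo: data is
# (N, nx, ny, channels) and truths is (N, nx, ny, truth_channels).
# The random arrays below are placeholders standing in for real datasets.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    demo_data = np.random.rand(10, 64, 64, 2)
    demo_truths = np.random.rand(10, 64, 64, 1)
    provider = SimpleDataProvider(demo_data, demo_truths)
    batch_x, batch_y = provider(4)             # random mini-batch of 4
    fixed_x, fixed_y = provider(4, fix=True)   # the first 4 samples
    full_x, full_y = provider('full')          # the whole dataset
    print(batch_x.shape, batch_y.shape)        # (4, 64, 64, 2) (4, 64, 64, 1)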

--------------------------------------------------------------------------------
/scadec/layers.py:
--------------------------------------------------------------------------------
# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.


'''
Modified on Feb, 2018 based on the work of jakeret

author: yusun
'''
from __future__ import print_function, division, absolute_import, unicode_literals

import tensorflow as tf

def log(x, base):
    numerator = tf.log(x)
    denominator = tf.log(tf.constant(base, dtype=numerator.dtype))
    return numerator / denominator

def weight_variable(shape, stddev=0.1):
    initial = tf.truncated_normal(shape, stddev=stddev)
    return tf.Variable(initial)

def rescale(array_x): # convert to [0,1]
    amax = tf.reduce_max(array_x, axis=1, keep_dims=True)
    amin = tf.reduce_min(array_x, axis=1, keep_dims=True)
    rescaled = array_x - amin
    rescaled = rescaled / (amax - amin)  # divide by the range so the maximum maps to 1
    return rescaled

# receives an array of images and returns the mse per image.
# size ~ num of pixels in the img
def mse_array(array_x, array_y, size):
    rescale_x = array_x
    rescale_y = array_y
    se = tf.reduce_sum(tf.squared_difference(rescale_x, rescale_y), 1)
    inv_size = tf.to_float(1/size)
    return tf.scalar_mul(inv_size, se)

def conv2d_bn_relu(x, w_size, num_outputs, keep_prob_, phase, scope): # output size should be the same.
    conv_2d = tf.contrib.layers.conv2d(x, num_outputs, w_size,
                                       activation_fn=tf.nn.relu, # elu is an alternative
                                       normalizer_fn=tf.layers.batch_normalization,
                                       normalizer_params={'training': phase},
                                       scope=scope)

    return tf.nn.dropout(conv_2d, keep_prob_)

def deconv2d_bn_relu(x, w_size, num_outputs, stride, keep_prob_, phase, scope):
    conv_2d = tf.contrib.layers.conv2d_transpose(x, num_outputs, w_size,
                                                 stride=stride,
                                                 activation_fn=tf.nn.relu, # elu is an alternative
                                                 normalizer_fn=tf.layers.batch_normalization,
                                                 normalizer_params={'training': phase},
                                                 scope=scope)

    return tf.nn.dropout(conv_2d, keep_prob_)

def conv2d_bn(x, w_size, num_outputs, keep_prob_, phase, scope):
    conv_2d = tf.contrib.layers.conv2d(x, num_outputs, w_size,
                                       activation_fn=None,
                                       normalizer_fn=tf.layers.batch_normalization,
                                       normalizer_params={'training': phase},
                                       scope=scope)
    return conv_2d

def conv2d(x, w_size, num_outputs, keep_prob_, scope):
    conv_2d = tf.contrib.layers.conv2d(x, num_outputs, w_size,
                                       activation_fn=None,
                                       normalizer_fn=None,
                                       scope=scope)
    return conv_2d

def max_pool(x,n):
    return tf.nn.max_pool(x, ksize=[1, n, n, 1], strides=[1, n, n, 1], padding='SAME')

def concat(x1,x2):
    return tf.concat([x1, x2], 3)
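
# ---------------------------------------------------------------------------
# Shape sketch (added for illustration; not part of the original file).
# conv2d_bn_relu keeps the spatial size (contrib.layers.conv2d defaults to
# 'SAME' padding and stride 1), while max_pool divides it by n.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    x = tf.placeholder(tf.float32, [None, 128, 128, 2])
    phase = tf.placeholder(tf.bool)
    keep_prob = tf.placeholder(tf.float32)
    h = conv2d_bn_relu(x, 3, 16, keep_prob, phase, 'demo_conv')  # -> (?, 128, 128, 16)
    p = max_pool(h, 2)                                           # -> (?, 64, 64, 16)
    print(h.get_shape(), p.get_shape())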

--------------------------------------------------------------------------------
/scadec/nets.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division, absolute_import, unicode_literals

import os
import shutil
import numpy as np
from collections import OrderedDict
import logging

import tensorflow as tf
import tensorflow.contrib as contrib

from scadec import util
from scadec.layers import *


def unet_decoder(x, keep_prob, phase, img_channels, truth_channels, layers=3, conv_times=3, features_root=16, filter_size=3, pool_size=2, summaries=True):
    """
    Creates a new convolutional unet for the given parametrization.

    :param x: input tensor, shape [?,nx,ny,img_channels]
    :param keep_prob: dropout probability tensor
    :param phase: training-phase tensor for batch normalization
    :param img_channels: number of channels in the input image
    :param truth_channels: number of channels in the output image
    :param layers: number of layers in the net
    :param conv_times: number of convolutions per resolution level
    :param features_root: number of features in the first layer
    :param filter_size: size of the convolution filter
    :param pool_size: size of the max pooling operation
    :param summaries: Flag if summaries should be created
    """

    logging.info("Layers {layers}, features {features}, filter size {filter_size}x{filter_size}, pool size: {pool_size}x{pool_size}".format(layers=layers,
                                                                                                                                            features=features_root,
                                                                                                                                            filter_size=filter_size,
                                                                                                                                            pool_size=pool_size))

    # Placeholder for the input image
    nx = tf.shape(x)[1]
    ny = tf.shape(x)[2]
    x_image = tf.reshape(x, tf.stack([-1,nx,ny,img_channels]))
    batch_size = tf.shape(x_image)[0]

    pools = OrderedDict()      # pooling layers
    deconvs = OrderedDict()    # deconvolution layers
    dw_h_convs = OrderedDict() # down-side convs
    up_h_convs = OrderedDict() # up-side convs

    # conv the input image to desired feature maps
    in_node = conv2d_bn_relu(x_image, filter_size, features_root, keep_prob, phase, 'conv2feature_roots')

    # Down layers
    for layer in range(0, layers):
        features = 2**layer*features_root
        with tf.variable_scope('down_layer_' + str(layer)):
            for conv_iter in range(0, conv_times):
                scope = 'conv_bn_relu_{}'.format(conv_iter)
                conv = conv2d_bn_relu(in_node, filter_size, features, keep_prob, phase, scope)
                in_node = conv

            # store the intermediate result per layer
            dw_h_convs[layer] = in_node

            # down sampling
            if layer < layers-1:
                with tf.variable_scope('pooling'):
                    pools[layer] = max_pool(dw_h_convs[layer], pool_size)
                    in_node = pools[layer]

    in_node = dw_h_convs[layers-1]

    # Up layers
    for layer in range(layers-2, -1, -1):
        features = 2**(layer+1)*features_root
        with tf.variable_scope('up_layer_' + str(layer)):
            with tf.variable_scope('unsample_concat_layer'):
                # number of features = lower layer's number of features
                h_deconv = deconv2d_bn_relu(in_node, filter_size, features//2, pool_size, keep_prob, phase, 'unsample_layer')
                h_deconv_concat = concat(dw_h_convs[layer], h_deconv)
                deconvs[layer] = h_deconv_concat
                in_node = h_deconv_concat

            for conv_iter in range(0, conv_times):
                scope = 'conv_bn_relu_{}'.format(conv_iter)
                conv = conv2d_bn_relu(in_node, filter_size, features//2, keep_prob, phase, scope)
                in_node = conv

            up_h_convs[layer] = in_node

    in_node = up_h_convs[0]

    # Output with residual
    with tf.variable_scope("conv2d_1by1"):
        output = conv2d(in_node, 1, truth_channels, keep_prob, 'conv2truth_channels')
        up_h_convs["out"] = output

    if summaries:
        # for i, (c1, c2) in enumerate(convs):
        #     tf.summary.image('summary_conv_%02d_01'%i, get_image_summary(c1))
        #     tf.summary.image('summary_conv_%02d_02'%i, get_image_summary(c2))

        for k in pools.keys():
            tf.summary.image('summary_pool_%02d'%k, get_image_summary(pools[k]))

        for k in deconvs.keys():
            tf.summary.image('summary_deconv_concat_%02d'%k, get_image_summary(deconvs[k]))

        for k in dw_h_convs.keys():
            tf.summary.histogram("dw_convolution_%02d"%k + '/activations', dw_h_convs[k])

        for k in up_h_convs.keys():
            tf.summary.histogram("up_convolution_%s"%k + '/activations', up_h_convs[k])

    return output

def get_image_summary(img, idx=0):
    """
    Make an image summary for 4d tensor image with index idx
    """

    V = tf.slice(img, (0, 0, 0, idx), (1, -1, -1, 1))
    V -= tf.reduce_min(V)
    V /= tf.reduce_max(V)
    V *= 255

    img_w = tf.shape(img)[1]
    img_h = tf.shape(img)[2]
    V = tf.reshape(V, tf.stack((img_w, img_h, 1)))
    V = tf.transpose(V, (2, 0, 1))
    V = tf.reshape(V, tf.stack((-1, img_w, img_h, 1)))
    return V
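
# ---------------------------------------------------------------------------
# Construction sketch (added for illustration; not part of the original file).
# This builds the decoder graph the same way Unet_bn does in scadec/unet_bn.py;
# the spatial dimensions can stay dynamic (None).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    x = tf.placeholder(tf.float32, [None, None, None, 2])
    keep_prob = tf.placeholder(tf.float32)
    phase = tf.placeholder(tf.bool)
    recons = unet_decoder(x, keep_prob, phase, img_channels=2, truth_channels=1,
                          layers=3, conv_times=2, features_root=16)
    print(recons)  # Tensor of shape (?, nx, ny, 1)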

--------------------------------------------------------------------------------
/scadec/train.py:
--------------------------------------------------------------------------------
# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.


'''
Modified on Feb, 2018 based on the work of jakeret

author: yusun
'''

from __future__ import print_function, division, absolute_import, unicode_literals

import os
import shutil
import numpy as np
from collections import OrderedDict
import logging

import tensorflow as tf
from scadec import util

class Trainer_bn(object):
    """
    Trains a unet instance

    :param net: the unet instance to train
    :param batch_size: size of training batch
    :param optimizer: (optional) name of the optimizer to use (momentum or adam)
    :param opt_kwargs: (optional) kwargs passed to the learning rate (momentum opt) and to the optimizer

    the phase of the unet is True (training mode) by default
    """

    def __init__(self, net, batch_size=1, optimizer="adam", opt_kwargs={}):
        self.net = net
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.opt_kwargs = opt_kwargs

    def _get_optimizer(self, training_iters, global_step):
        if self.optimizer == "momentum":
            learning_rate = self.opt_kwargs.pop("learning_rate", 0.2)
            decay_rate = self.opt_kwargs.pop("decay_rate", 0.95)
            momentum = self.opt_kwargs.pop("momentum", 0.2)

            self.learning_rate_node = tf.train.exponential_decay(learning_rate=learning_rate,
                                                                 global_step=global_step,
                                                                 decay_steps=training_iters,
                                                                 decay_rate=decay_rate,
                                                                 staircase=True)

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate_node, momentum=momentum,
                                                       **self.opt_kwargs).minimize(self.net.loss,
                                                                                   global_step=global_step)
        elif self.optimizer == "adam":
            learning_rate = self.opt_kwargs.pop("learning_rate", 0.001)
            self.learning_rate_node = tf.Variable(learning_rate)

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_node,
                                                   **self.opt_kwargs).minimize(self.net.loss,
                                                                               global_step=global_step)

        return optimizer

    def _initialize(self, training_iters, output_path, restore, prediction_path):
        global_step = tf.Variable(0)
        logging.getLogger().setLevel(logging.INFO)

        # get optimizer
        self.optimizer = self._get_optimizer(training_iters, global_step)
        init = tf.global_variables_initializer()

        # get validation_path
        self.prediction_path = prediction_path
        abs_prediction_path = os.path.abspath(self.prediction_path)
        output_path = os.path.abspath(output_path)

        if not restore:
            logging.info("Removing '{:}'".format(abs_prediction_path))
            shutil.rmtree(abs_prediction_path, ignore_errors=True)
            logging.info("Removing '{:}'".format(output_path))
            shutil.rmtree(output_path, ignore_errors=True)

        if not os.path.exists(abs_prediction_path):
            logging.info("Allocating '{:}'".format(abs_prediction_path))
            os.makedirs(abs_prediction_path)

        if not os.path.exists(output_path):
            logging.info("Allocating '{:}'".format(output_path))
            os.makedirs(output_path)

        return init

    def train(self, data_provider, output_path, valid_provider, valid_size, training_iters=100, epochs=1000, dropout=0.75, display_step=1, save_epoch=50, restore=False, write_graph=False, prediction_path='validation'):
        """
        Launches the training process

        :param data_provider: callable returning training and verification data
        :param output_path: path where to store checkpoints
        :param valid_provider: data provider for the validation dataset
        :param valid_size: batch size for validation provider
        :param training_iters: number of training mini-batch iterations per epoch
        :param epochs: number of epochs
        :param dropout: dropout probability
        :param display_step: number of steps till outputting stats
        :param save_epoch: number of epochs between checkpoint saves
        :param restore: Flag if previous model should be restored
        :param write_graph: Flag if the computation graph should be written as protobuf file to the output path
        :param prediction_path: path where to save predictions on each epoch
        """

        # initialize the training process.
        init = self._initialize(training_iters, output_path, restore, prediction_path)

        # create output path
        directory = os.path.join(output_path, "final/")
        if not os.path.exists(directory):
            os.makedirs(directory)

        save_path = os.path.join(directory, "model.cpkt")
        if epochs == 0:
            return save_path

        with tf.Session() as sess:
            if write_graph:
                tf.train.write_graph(sess.graph_def, output_path, "graph.pb", False)

            sess.run(init)

            if restore:
                ckpt = tf.train.get_checkpoint_state(output_path)
                if ckpt and ckpt.model_checkpoint_path:
                    self.net.restore(sess, ckpt.model_checkpoint_path)


            summary_writer = tf.summary.FileWriter(output_path, graph=sess.graph)
            logging.info("Start optimization")

            # select validation dataset
            valid_x, valid_y = valid_provider(valid_size, fix=True)
            util.save_mat(valid_y, "%s/%s.mat"%(self.prediction_path, 'origin_y'))
            util.save_mat(valid_x, "%s/%s.mat"%(self.prediction_path, 'origin_x'))

            for epoch in range(epochs):
                total_loss = 0
                # batch_x, batch_y = data_provider(self.batch_size)
                for step in range((epoch*training_iters), ((epoch+1)*training_iters)):
                    batch_x, batch_y = data_provider(self.batch_size)
                    # Run optimization op (backprop)
                    _, loss, lr, avg_psnr = sess.run([self.optimizer,
                                                      self.net.loss,
                                                      self.learning_rate_node,
                                                      self.net.avg_psnr],
                                                     feed_dict={self.net.x: batch_x,
                                                                self.net.y: batch_y,
                                                                self.net.keep_prob: dropout,
                                                                self.net.phase: True})

                    if step % display_step == 0:
                        logging.info("Iter {:} (before training on the batch) Minibatch MSE= {:.4f}, Minibatch Avg PSNR= {:.4f}".format(step, loss, avg_psnr))
                        self.output_minibatch_stats(sess, summary_writer, step, batch_x, batch_y)

                    total_loss += loss

                    self.record_summary(summary_writer, 'training_loss', loss, step)
                    self.record_summary(summary_writer, 'training_avg_psnr', avg_psnr, step)

                # output statistics for epoch
                self.output_epoch_stats(epoch, total_loss, training_iters, lr)
                self.output_valstats(sess, summary_writer, step, valid_x, valid_y, "epoch_%s"%epoch, store_img=True)

                if epoch % save_epoch == 0:
                    directory = os.path.join(output_path, "{}_cpkt/".format(step))
                    if not os.path.exists(directory):
                        os.makedirs(directory)
                    path = os.path.join(directory, "model.cpkt")
                    self.net.save(sess, path)

                save_path = self.net.save(sess, save_path)

            logging.info("Optimization Finished!")

            return save_path

    def output_epoch_stats(self, epoch, total_loss, training_iters, lr):
        logging.info("Epoch {:}, Average MSE: {:.4f}, learning rate: {:.4f}".format(epoch, (total_loss / training_iters), lr))

    def output_minibatch_stats(self, sess, summary_writer, step, batch_x, batch_y):
        # Calculate batch loss and accuracy
        loss, predictions, avg_psnr = sess.run([self.net.loss,
                                                self.net.recons,
                                                self.net.avg_psnr],
                                               feed_dict={self.net.x: batch_x,
                                                          self.net.y: batch_y,
                                                          self.net.keep_prob: 1.,
                                                          self.net.phase: False})

        self.record_summary(summary_writer, 'minibatch_loss', loss, step)
        self.record_summary(summary_writer, 'minibatch_avg_psnr', avg_psnr, step)

        logging.info("Iter {:} (After training on the batch) Minibatch MSE= {:.4f}, Minibatch Avg PSNR= {:.4f}".format(step, loss, avg_psnr))

    def output_valstats(self, sess, summary_writer, step, batch_x, batch_y, name, store_img=True):
        prediction, loss, avg_psnr = sess.run([self.net.recons,
                                               self.net.valid_loss,
                                               self.net.valid_avg_psnr],
                                              feed_dict={self.net.x: batch_x,
                                                         self.net.y: batch_y,
                                                         self.net.keep_prob: 1.,
                                                         self.net.phase: False})

        self.record_summary(summary_writer, 'valid_loss', loss, step)
        self.record_summary(summary_writer, 'valid_avg_psnr', avg_psnr, step)

        logging.info("Validation Statistics, validation loss= {:.4f}, Avg PSNR= {:.4f}".format(loss, avg_psnr))

        util.save_mat(prediction, "%s/%s.mat"%(self.prediction_path, name))

        if store_img:
            util.save_img(prediction[0,...], "%s/%s_img.tif"%(self.prediction_path, name))

    def record_summary(self, writer, name, value, step):
        summary = tf.Summary()
        summary.value.add(tag=name, simple_value=value)
        writer.add_summary(summary, step)
        writer.flush()
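
# ---------------------------------------------------------------------------
# Optimizer sketch (added for illustration; not part of the original file).
# Besides the Adam setup used in the root train.py, Trainer_bn also supports
# SGD with momentum plus exponential learning-rate decay; the kwargs below are
# the defaults read by _get_optimizer ('net' stands for a built Unet_bn):
#
#   trainer = Trainer_bn(net, batch_size=24, optimizer="momentum",
#                        opt_kwargs={'learning_rate': 0.2,
#                                    'decay_rate': 0.95,
#                                    'momentum': 0.2})
# ---------------------------------------------------------------------------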

--------------------------------------------------------------------------------
/scadec/unet_bn.py:
--------------------------------------------------------------------------------
# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.


'''
Modified on Feb, 2018 based on the work of jakeret

author: yusun
'''

from __future__ import print_function, division, absolute_import, unicode_literals

import os
import shutil
import math
import numpy as np
from collections import OrderedDict
import logging

import tensorflow as tf

from scadec import util
from scadec.layers import *
from scadec.nets import *


class Unet_bn(object):
    """
    A unet implementation

    :param img_channels: (optional) number of channels in the input image
    :param truth_channels: (optional) number of channels in the output image
    :param cost: (optional) name of the cost function. Default is 'mean_squared_error'
    :param cost_kwargs: (optional) kwargs passed to the cost function. See Unet._get_cost for more options
    :param kwargs: args passed to the create_net function.
    """

    def __init__(self, img_channels=3, truth_channels=3, cost="mean_squared_error", cost_kwargs={}, **kwargs):
        tf.reset_default_graph()

        # basic variables
        self.summaries = kwargs.get("summaries", True)
        self.img_channels = img_channels
        self.truth_channels = truth_channels

        # placeholders for input x and y
        self.x = tf.placeholder("float", shape=[None, None, None, img_channels])
        self.y = tf.placeholder("float", shape=[None, None, None, truth_channels])
        self.phase = tf.placeholder(tf.bool, name='phase')
        self.keep_prob = tf.placeholder(tf.float32) # dropout (keep probability)

        # reused variables
        self.nx = tf.shape(self.x)[1]
        self.ny = tf.shape(self.x)[2]
        self.num_examples = tf.shape(self.x)[0]

        # variables need to be calculated
        self.recons = unet_decoder(self.x, self.keep_prob, self.phase, self.img_channels, self.truth_channels, **kwargs)
        self.loss = self._get_cost(cost, cost_kwargs)
        self.valid_loss = self._get_cost(cost, cost_kwargs)
        self.avg_psnr = self._get_measure('avg_psnr')
        self.valid_avg_psnr = self._get_measure('avg_psnr')

    def _get_measure(self, measure):
        total_pixels = self.nx * self.ny * self.truth_channels
        dtype = self.x.dtype
        flat_recons = tf.reshape(self.recons, [-1, total_pixels])
        flat_truths = tf.reshape(self.y, [-1, total_pixels])

        if measure == 'psnr':
            # mse are of the same length as the truths
            mse = mse_array(flat_recons, flat_truths, total_pixels)
            term1 = log(tf.constant(1, dtype), 10.)
            term2 = log(mse, 10.)
            psnr = tf.scalar_mul(20., term1) - tf.scalar_mul(10., term2)
            result = psnr

        elif measure == 'avg_psnr':
            # mse are of the same length as the truths
            mse = mse_array(flat_recons, flat_truths, total_pixels)
            term1 = log(tf.constant(1, dtype), 10.)
            term2 = log(mse, 10.)
            psnr = tf.scalar_mul(20., term1) - tf.scalar_mul(10., term2)
            avg_psnr = tf.reduce_mean(psnr)
            result = avg_psnr

        else:
            raise ValueError("Unknown measure: %s" % measure)

        return result

    def _get_cost(self, cost_name, cost_kwargs):
        """
        Constructs the cost function.

        """

        total_pixels = self.nx * self.ny * self.truth_channels
        flat_recons = tf.reshape(self.recons, [-1, total_pixels])
        flat_truths = tf.reshape(self.y, [-1, total_pixels])
        if cost_name == "mean_squared_error":
            loss = tf.losses.mean_squared_error(flat_recons, flat_truths)
            # the mean_squared_error is equal to the following code
            # se = tf.squared_difference(flat_recons, flat_truths)
            # loss = tf.reduce_mean(se, 1)

        # add new loss function here
        else:
            raise ValueError("Unknown cost function: %s" % cost_name)

        return loss

    # predict
    def predict(self, model_path, x_test, keep_prob, phase):
        """
        Uses the model to create a prediction for the given data

        :param model_path: path to the model checkpoint to restore
        :param x_test: Data to predict on. Shape [n, nx, ny, channels]
        :returns prediction: The unet prediction. Shape [n, nx, ny, truth_channels]
        """

        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            # Initialize variables
            sess.run(init)

            # Restore model weights from previously saved model
            self.restore(sess, model_path)

            prediction = sess.run(self.recons, feed_dict={self.x: x_test,
                                                          self.keep_prob: keep_prob,
                                                          self.phase: phase}) # set phase to False for every prediction
            # define operation
            return prediction

    def save(self, sess, model_path):
        """
        Saves the current session to a checkpoint

        :param sess: current session
        :param model_path: path to file system location
        """

        saver = tf.train.Saver()
        save_path = saver.save(sess, model_path)
        return save_path

    def restore(self, sess, model_path):
        """
        Restores a session from a checkpoint

        :param sess: current session instance
        :param model_path: path to file system checkpoint location
        """

        saver = tf.train.Saver()
        saver.restore(sess, model_path)
        logging.info("Model restored from file: %s" % model_path)
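
# ---------------------------------------------------------------------------
# Inference sketch (added for illustration; not part of the original file).
# Restores a checkpoint and reconstructs a batch; the checkpoint path and the
# input array 'x_test' are placeholders. test.py shows the full pipeline.
#
#   net = Unet_bn(img_channels=2, truth_channels=1, layers=5, conv_times=2,
#                 features_root=64)
#   recons = net.predict('gpu0/models/final/model.cpkt', x_test,
#                        keep_prob=1., phase=False)
# ---------------------------------------------------------------------------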

--------------------------------------------------------------------------------
/scadec/util.py:
--------------------------------------------------------------------------------
# tf_unet is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# tf_unet is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with tf_unet. If not, see <http://www.gnu.org/licenses/>.


'''
Modified on Feb, 2018 based on the work of jakeret

author: yusun
'''
from __future__ import print_function, division, absolute_import, unicode_literals
import numpy as np
import scipy.io as sio
import scipy.misc as smisc


def to_rgb(img):
    """
    Converts the given array into a RGB image. If the number of channels is not
    3 the array is tiled such that it has 3 channels. Finally, the values are
    rescaled to [0,255)

    :param img: the array to convert [nx, ny, channels]

    :returns img: the rgb image [nx, ny, 3]
    """
    img = np.atleast_3d(img)
    channels = img.shape[2]
    if channels < 3:
        img = np.tile(img, 3)

    img[np.isnan(img)] = 0
    img -= np.amin(img)
    img /= np.amax(img)
    img *= 255
    return img

def to_double(img):
    img = np.atleast_3d(img)
    channels = img.shape[2]
    if channels < 3:
        img = np.tile(img, 3)

    img[np.isnan(img)] = 0
    img -= np.amin(img)
    img /= np.amax(img)
    return img

def save_mat(img, path):
    """
    Writes the image to disk

    :param img: the rgb image to save
    :param path: the target path
    """

    sio.savemat(path, {'img':img})


def save_img(img, path):
    """
    Writes the image to disk

    :param img: the rgb image to save
    :param path: the target path
    """
    img = to_rgb(img)
    smisc.imsave(path, img.round().astype(np.uint8))
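
# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original file).
# to_rgb stretches a single-channel array of arbitrary range to a 3-channel
# image in [0, 255]; the random array below is a placeholder.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    demo = np.random.rand(32, 32, 1) * 7 - 3    # arbitrary value range
    rgb = to_rgb(demo)
    print(rgb.shape, rgb.min(), rgb.max())      # (32, 32, 3) 0.0 255.0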

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
from scadec.unet_bn import Unet_bn
from scadec.train import Trainer_bn

from scadec import image_util
from scadec import util

import scipy.io as spio
import numpy as np
import os

####################################################
####             PREPARE WORKSPACE              ####
####################################################

# here indicate the GPU you want to use. if you don't have a GPU, just leave it.
gpu_vis = '3'
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_vis  # 0,1,2,3

# here specify the path of the model you want to load
gpu_ind = '0'
model_path = 'gpu' + gpu_ind + '/models/60099_cpkt/models/final/model.cpkt'

data_channels = 2
truth_channels = 1

# NOTE: 'level' tags the noise level of the test set in the .mat file names
# below. It was left undefined in the original script; the value here is a
# placeholder -- set it to match your own files.
level = 10

####################################################
####                 FUNCTIONS                  ####
####################################################

# make the data a 4D vector
def preprocess(data, channels):
    nx = data.shape[1]
    ny = data.shape[2]
    return data.reshape((-1, nx, ny, channels))

####################################################
####                LOAD MODEL                  ####
####################################################

# set up args for the unet, should be exactly the same as the loaded model
kwargs = {
    "layers": 5,
    "conv_times": 2,
    "features_root": 64,
    "filter_size": 3,
    "pool_size": 2,
    "summaries": True
}

net = Unet_bn(img_channels=data_channels, truth_channels=truth_channels, cost="mean_squared_error", **kwargs)


####################################################
####                LOAD TRAIN                  ####
####################################################

# preparing training data
data_mat = spio.loadmat('train_np/obhatGausWeak128_40.mat', squeeze_me=True)
truths_mat = spio.loadmat('train_np/obGausWeak128_40.mat', squeeze_me=True)

data = data_mat['obhatGausWeak128']
data = preprocess(data, data_channels) # 4 dimension -> 3 dimension if you do data[:,:,:,1]
truths = preprocess(truths_mat['obGausWeak128'], truth_channels)

data_provider = image_util.SimpleDataProvider(data, truths)


####################################################
####                 LOAD TEST                  ####
####################################################

vdata_mat = spio.loadmat('test_np_noise/obhatGausWeak{}Noise128.mat'.format(level), squeeze_me=True)
vtruths_mat = spio.loadmat('valid_np/obGausN1S128val.mat', squeeze_me=True)

vdata = vdata_mat['obhatGausWeak128']
vdata = preprocess(vdata, data_channels)
vtruths = preprocess(vtruths_mat['obGausN1S128val'], truth_channels)

valid_provider = image_util.SimpleDataProvider(vdata, vtruths)

####################################################
####                  PREDICT                   ####
####################################################

predicts = []

valid_x, valid_y = valid_provider('full')
num = valid_x.shape[0]

for i in range(num):

    print('')
    print('')
    print('************* {} *************'.format(i))
    print('')
    print('')

    # batch normalization needs meaningful batch statistics, so each test
    # image is concatenated with a batch of training samples and the network
    # is run with phase=True; only the first (test) reconstruction is kept
    x_train, y_train = data_provider(23)
    x_input = valid_x[i:i+1,:,:,:]
    x_input = np.concatenate((x_input, x_train), axis=0)
    predict = net.predict(model_path, x_input, 1, True)
    predicts.append(predict[0:1,:,:])

predicts = np.concatenate(predicts, axis=0)
util.save_mat(predicts, 'test{}Noise.mat'.format(level))

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
from scadec.unet_bn import Unet_bn
from scadec.train import Trainer_bn

from scadec import image_util

import scipy.io as spio
import numpy as np
import os

####################################################
####                 FUNCTIONS                  ####
####################################################

# make the data a 4D vector
def preprocess(data, channels):
    nx = data.shape[1]
    ny = data.shape[2]
    return data.reshape((-1, nx, ny, channels))

####################################################
####             HYPER-PARAMETERS               ####
####################################################

# here indicate the GPU you want to use. if you don't have a GPU, just leave it.
gpu_ind = '2'
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ind  # 0,1,2,3


# Because we have real & imaginary parts of our input, data_channels is set to 2
data_channels = 2
truth_channels = 1

####################################################
####               DATA LOADING                 ####
####################################################

"""
here we load all the data needed for training and validation.

"""

#-- Training Data --#
data_mat = spio.loadmat('train_np/obhatGausWeak128_40.mat', squeeze_me=True)
truths_mat = spio.loadmat('train_np/obGausWeak128_40.mat', squeeze_me=True)

data = data_mat['obhatGausWeak128']
data = preprocess(data, data_channels) # 4 dimension -> 3 dimension if you do data[:,:,:,1]
truths = preprocess(truths_mat['obGausWeak128'], truth_channels)

data_provider = image_util.SimpleDataProvider(data, truths)

#-- Validating Data --#
vdata_mat = spio.loadmat('valid_np/obhatGausWeak128val_40.mat', squeeze_me=True)
vtruths_mat = spio.loadmat('valid_np/obGausWeak128val_40.mat', squeeze_me=True)

vdata = vdata_mat['obhatGausWeak128val']
vdata = preprocess(vdata, data_channels)
vtruths = preprocess(vtruths_mat['obGausWeak128val'], truth_channels)

valid_provider = image_util.SimpleDataProvider(vdata, vtruths)


####################################################
####                  NETWORK                   ####
####################################################

"""
here we specify the neural network.

"""

#-- Network Setup --#
# set up args for the unet
kwargs = {
    "layers": 5,           # how many resolution levels we want to have
    "conv_times": 2,       # how many times we want to convolve in each level
    "features_root": 64,   # how many feature maps we want to have as root (each following level doubles this, e.g. 64, 128, 256)
    "filter_size": 3,      # filter size used in convolution
    "pool_size": 2,        # pooling size used in max-pooling
    "summaries": True
}

net = Unet_bn(img_channels=data_channels, truth_channels=truth_channels, cost="mean_squared_error", **kwargs)


####################################################
####                  TRAINING                  ####
####################################################

# args for training
batch_size = 24     # batch size for training
valid_size = 24     # batch size for validating
optimizer = "adam"  # optimizer we want to use, 'adam' or 'momentum'

# output paths for results
output_path = 'gpu' + gpu_ind + '/models'
prediction_path = 'gpu' + gpu_ind + '/validation'
# restore_path = 'gpu001/models/50099_cpkt'

# optional args
opt_kwargs = {
    'learning_rate': 0.001
}

# make a trainer for scadec
trainer = Trainer_bn(net, batch_size=batch_size, optimizer=optimizer, opt_kwargs=opt_kwargs)
path = trainer.train(data_provider, output_path, valid_provider, valid_size, training_iters=100, epochs=1000, display_step=20, save_epoch=100, prediction_path=prediction_path)
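
# ---------------------------------------------------------------------------
# Resuming sketch (added for illustration; not part of the original file).
# Trainer_bn.train can pick up from the latest checkpoint in output_path by
# passing restore=True, which also skips the removal of previous results:
#
#   path = trainer.train(data_provider, output_path, valid_provider, valid_size,
#                        training_iters=100, epochs=1000, display_step=20,
#                        save_epoch=100, prediction_path=prediction_path,
#                        restore=True)
# ---------------------------------------------------------------------------

--------------------------------------------------------------------------------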