├── README.md
├── images
│   ├── ScaDecArch.jpg
│   ├── convergence.jpg
│   ├── expExamples.jpg
│   └── visualExamples.jpg
├── scadec
│   ├── __init__.py
│   ├── image_util.py
│   ├── layers.py
│   ├── nets.py
│   ├── train.py
│   ├── unet_bn.py
│   └── util.py
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # [Efficient and accurate inversion of multiple scattering with deep learning](https://www.osapublishing.org/oe/abstract.cfm?uri=oe-26-11-14678&origin=search)
2 |
3 | This is the Python implementation of the deep learning model [ScaDec](https://www.osapublishing.org/oe/abstract.cfm?uri=oe-26-11-14678&origin=search) for inverting multiple light scattering in a supervised manner. The [paper](https://www.osapublishing.org/oe/abstract.cfm?uri=oe-26-11-14678&origin=search) was originally published in [Optics Express](https://www.osapublishing.org/oe/home.cfm). The arXiv version of the paper is available [here](https://arxiv.org/abs/1803.06594).
4 |
5 | ## Abstract
6 | Image reconstruction under multiple light scattering is crucial in a number of applications such as diffraction tomography. The reconstruction problem is often formulated as a nonconvex optimization, where a nonlinear measurement model is used to account for multiple scattering and regularization is used to enforce prior constraints on the object. In this paper, we propose a powerful alternative to this optimization-based view of image reconstruction by designing and training a deep convolutional neural network that can invert multiple scattered measurements to produce a high-quality image of the refractive index. Our results on both simulated and experimental datasets show that the proposed approach is substantially faster and achieves higher imaging quality compared to the state-of-the-art methods based on optimization.
7 |
8 | ## Experimental Results
9 | The following two figures show the visual results of ScaDec on simulated and experimental datasets ([CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) & [Fresnel](http://iopscience.iop.org/article/10.1088/0266-5611/21/6/S09/meta)). The ScaDec substantially outperforms the optimization-based methods.
10 |
11 | 
12 |
13 | 
14 |
15 | ## Training & Testing
16 | The scripts train.py and test.py are for training and testing the ScaDec model. If you want to train your own ScaDec, please check out these two files. We also share the [**pre-trained model and code here**](https://wustl.box.com/s/kjnpwgg9ktauebolqid41dozyso0q9kz). Please cite the paper if you find our work useful for your research.
17 | The MATLAB code there covers:
18 |
19 | (1) generating measurements by solving the 2D Lippmann-Schwinger equation
20 |
21 | (2) generating synthesized piecewise-smooth Gaussian ellipses (an illustrative sketch follows this list)
22 |
23 | (3) generating synthesized circles for the 2D Fresnel dataset
24 |
25 | (4) backpropagation
26 |
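The MATLAB implementations live in the Box link above. Purely as an illustration of what item (2) might look like, here is a hypothetical NumPy sketch of a piecewise-smooth ellipse phantom generator; it is not the authors' generator, and all parameter ranges are made up:

```python
import numpy as np
from scipy.ndimage import gaussian_filter

def ellipse_phantom(n=128, num_ellipses=3, sigma=2.0, seed=0):
    # stack a few random ellipses, then Gaussian-smooth the result
    # to obtain a piecewise-smooth phantom (illustrative only)
    rng = np.random.RandomState(seed)
    yy, xx = np.mgrid[0:n, 0:n] / (n - 1.0)
    img = np.zeros((n, n))
    for _ in range(num_ellipses):
        cx, cy = rng.uniform(0.2, 0.8, size=2)    # ellipse center
        ax, ay = rng.uniform(0.05, 0.25, size=2)  # semi-axes
        amp = rng.uniform(0.5, 1.0)               # amplitude
        mask = ((xx - cx) / ax) ** 2 + ((yy - cy) / ay) ** 2 <= 1.0
        img[mask] += amp
    return gaussian_filter(img, sigma)
```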
27 |
28 | To train/test a ScaDec, simply type:
29 |
30 |     python train.py
31 |     python test.py
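
To drive the model from Python directly, here is a minimal sketch mirroring train.py in this repository (the .mat file names and variable keys are the ones the shipped scripts expect; swap in your own dataset as needed):

```python
import scipy.io as spio
from scadec.unet_bn import Unet_bn
from scadec.train import Trainer_bn
from scadec import image_util

def preprocess(data, channels):
    # reshape flat samples into a 4D batch [n, nx, ny, channels]
    nx, ny = data.shape[1], data.shape[2]
    return data.reshape((-1, nx, ny, channels))

# inputs have 2 channels (real & imaginary parts), ground truths have 1
data_mat = spio.loadmat('train_np/obhatGausWeak128_40.mat', squeeze_me=True)
truths_mat = spio.loadmat('train_np/obGausWeak128_40.mat', squeeze_me=True)
data = preprocess(data_mat['obhatGausWeak128'], 2)
truths = preprocess(truths_mat['obGausWeak128'], 1)
data_provider = image_util.SimpleDataProvider(data, truths)

# same hyper-parameters as train.py; in practice a held-out provider
# is passed for validation (the training provider stands in here)
net = Unet_bn(img_channels=2, truth_channels=1, cost="mean_squared_error",
              layers=5, conv_times=2, features_root=64, filter_size=3, pool_size=2)
trainer = Trainer_bn(net, batch_size=24, optimizer="adam", opt_kwargs={'learning_rate': 0.001})
trainer.train(data_provider, 'models', data_provider, 24, training_iters=100, epochs=1000)
```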
32 |
33 | ## Environment Requirement
34 | ```
35 | Tensorflow v1.4
36 | PIL
37 | Python 3.6
38 | Scipy
39 | ```
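
If you manage packages with pip, an environment along these lines should work (the pins are a best guess matching the versions listed above, not an officially tested configuration):

```
pip install tensorflow==1.4.0 pillow scipy
```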
40 |
41 | ## Citation
42 | If you find the paper useful in your research, please cite the paper:
43 |
44 |     @article{Sun:18,
45 |     Author = {Yu Sun and Zhihao Xia and Ulugbek S. Kamilov},
46 |     Doi = {10.1364/OE.26.014678},
47 |     Journal = {Opt. Express},
48 |     Keywords = {Image reconstruction techniques; Inverse problems; Tomographic image processing; Inverse scattering},
49 |     Month = {May},
50 |     Number = {11},
51 |     Pages = {14678--14688},
52 |     Publisher = {OSA},
53 |     Title = {Efficient and accurate inversion of multiple scattering with deep learning},
54 |     Url = {http://www.opticsexpress.org/abstract.cfm?URI=oe-26-11-14678},
55 |     Volume = {26},
56 |     Year = {2018},
57 |     Bdsk-Url-1 = {http://www.opticsexpress.org/abstract.cfm?URI=oe-26-11-14678},
58 |     Bdsk-Url-2 = {https://doi.org/10.1364/OE.26.014678}}
59 |
60 |
61 | ## Erratum
62 | In the paper, the feature dimensions of ScaDec are mistyped in Fig. 3. The corrected figure is shown below.
63 |
64 | 
65 |
66 | Thanks to Jiaming for pointing out this typo. If you have any questions about the paper, feel free to contact us: sun.yu@wustl.edu
67 |
68 |
69 |
--------------------------------------------------------------------------------
/images/ScaDecArch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunyumark/ScaDec-deep-learning-diffractive-tomography/ebd952f914ff7bb0109679b39cbf588b5e74c105/images/ScaDecArch.jpg
--------------------------------------------------------------------------------
/images/convergence.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunyumark/ScaDec-deep-learning-diffractive-tomography/ebd952f914ff7bb0109679b39cbf588b5e74c105/images/convergence.jpg
--------------------------------------------------------------------------------
/images/expExamples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunyumark/ScaDec-deep-learning-diffractive-tomography/ebd952f914ff7bb0109679b39cbf588b5e74c105/images/expExamples.jpg
--------------------------------------------------------------------------------
/images/visualExamples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sunyumark/ScaDec-deep-learning-diffractive-tomography/ebd952f914ff7bb0109679b39cbf588b5e74c105/images/visualExamples.jpg
--------------------------------------------------------------------------------
/scadec/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'Yu Sun'
2 | __version__ = '0.1.1'
3 | __credits__ = 'Wash U, Department of Computer Science & Engineering'
4 |
--------------------------------------------------------------------------------
/scadec/image_util.py:
--------------------------------------------------------------------------------
1 | # tf_unet is free software: you can redistribute it and/or modify
2 | # it under the terms of the GNU General Public License as published by
3 | # the Free Software Foundation, either version 3 of the License, or
4 | # (at your option) any later version.
5 | #
6 | # tf_unet is distributed in the hope that it will be useful,
7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 | # GNU General Public License for more details.
10 | #
11 | # You should have received a copy of the GNU General Public License
12 | # along with tf_unet. If not, see <http://www.gnu.org/licenses/>.
13 |
14 | '''
15 | Modified in Feb 2018 based on the work of jakeret
16 |
17 | author: yusun
18 | '''
19 | from __future__ import print_function, division, absolute_import, unicode_literals
20 |
21 | #import cv2
22 | import glob
23 | import numpy as np
24 | from PIL import Image
25 |
26 |
27 | class BaseDataProvider(object):
28 |
29 | def __init__(self, a_min=None, a_max=None):
30 | self.a_min = a_min if a_min is not None else -np.inf
31 |         self.a_max = a_max if a_max is not None else np.inf
32 |
33 | def __call__(self, n, fix=False):
34 | if type(n) == int and not fix:
35 | # X and Y are the images and truths
36 | train_data, truths = self._next_batch(n)
37 | elif type(n) == int and fix:
38 | train_data, truths = self._fix_batch(n)
39 | elif type(n) == str and n == 'full':
40 | train_data, truths = self._full_batch()
41 | else:
42 |             raise ValueError("Invalid batch_size: %s" % n)
43 |
44 | return train_data, truths
45 |
46 | def _next_batch(self, n):
47 | pass
48 |
49 | def _full_batch(self):
50 | pass
51 |
52 |
53 | class SimpleDataProvider(BaseDataProvider):
54 |
55 | def __init__(self, data, truths):
56 | super(SimpleDataProvider, self).__init__()
57 | self.data = np.float64(data)
58 | self.truths = np.float64(truths)
59 | self.img_channels = self.data[0].shape[2]
60 | self.truth_channels = self.truths[0].shape[2]
61 | self.file_count = data.shape[0]
62 |
63 | def _next_batch(self, n):
64 | idx = np.random.choice(self.file_count, n, replace=False)
65 | img = self.data[idx[0]]
66 | nx = img.shape[0]
67 | ny = img.shape[1]
68 | X = np.zeros((n, nx, ny, self.img_channels))
69 | Y = np.zeros((n, nx, ny, self.truth_channels))
70 | for i in range(n):
71 | X[i] = self._process_data(self.data[idx[i]])
72 | Y[i] = self._process_truths(self.truths[idx[i]])
73 | return X, Y
74 |
75 | def _fix_batch(self, n):
76 | # first n data
77 | img = self.data[0]
78 | nx = img.shape[0]
79 | ny = img.shape[1]
80 | X = np.zeros((n, nx, ny, self.img_channels))
81 | Y = np.zeros((n, nx, ny, self.truth_channels))
82 | for i in range(n):
83 | X[i] = self._process_data(self.data[i])
84 | Y[i] = self._process_truths(self.truths[i])
85 | return X, Y
86 |
87 | def _full_batch(self):
88 | return self.data, self.truths
89 |
90 | def _process_truths(self, truth):
91 | # normalization by channels
92 | truth = np.clip(np.fabs(truth), self.a_min, self.a_max)
93 | for channel in range(self.truth_channels):
94 | truth[:,:,channel] -= np.amin(truth[:,:,channel])
95 | truth[:,:,channel] /= np.amax(truth[:,:,channel])
96 | return truth
97 |
98 | def _process_data(self, data):
99 | # normalization by channels
100 | data = np.clip(np.fabs(data), self.a_min, self.a_max)
101 | for channel in range(self.img_channels):
102 | data[:,:,channel] -= np.amin(data[:,:,channel])
103 | data[:,:,channel] /= np.amax(data[:,:,channel])
104 | return data
105 |
--------------------------------------------------------------------------------
/scadec/layers.py:
--------------------------------------------------------------------------------
1 | # tf_unet is free software: you can redistribute it and/or modify
2 | # it under the terms of the GNU General Public License as published by
3 | # the Free Software Foundation, either version 3 of the License, or
4 | # (at your option) any later version.
5 | #
6 | # tf_unet is distributed in the hope that it will be useful,
7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 | # GNU General Public License for more details.
10 | #
11 | # You should have received a copy of the GNU General Public License
12 | # along with tf_unet. If not, see <http://www.gnu.org/licenses/>.
13 |
14 |
15 | '''
16 | Modified in Feb 2018 based on the work of jakeret
17 |
18 | author: yusun
19 | '''
20 | from __future__ import print_function, division, absolute_import, unicode_literals
21 |
22 | import tensorflow as tf
23 |
24 | def log(x, base):
25 | numerator = tf.log(x)
26 | denominator = tf.log(tf.constant(base, dtype=numerator.dtype))
27 | return numerator / denominator
28 |
29 | def weight_variable(shape, stddev=0.1):
30 | initial = tf.truncated_normal(shape, stddev=stddev)
31 | return tf.Variable(initial)
32 |
33 | def rescale(array_x): # convert to [0,1] per sample
34 |     amax = tf.reduce_max(array_x, axis=1, keep_dims=True)
35 |     amin = tf.reduce_min(array_x, axis=1, keep_dims=True)
36 |     rescaled = array_x - amin
37 |     rescaled = rescaled / (amax - amin)
38 |     return rescaled
39 |
40 | # receives an array of images and return the mse per image.
41 | # size ~ num of pixels in the img
42 | def mse_array(array_x, array_y, size):
43 | rescale_x = array_x
44 | rescale_y = array_y
45 | se = tf.reduce_sum(tf.squared_difference(rescale_x, rescale_y), 1)
46 | inv_size = tf.to_float(1/size)
47 | return tf.scalar_mul(inv_size, se)
48 |
49 | def conv2d_bn_relu(x, w_size, num_outputs, keep_prob_, phase, scope): # output size should be the same.
50 | conv_2d = tf.contrib.layers.conv2d(x, num_outputs, w_size,
51 | activation_fn=tf.nn.relu, # elu is an alternative
52 | normalizer_fn=tf.layers.batch_normalization,
53 | normalizer_params={'training': phase},
54 | scope=scope)
55 |
56 | return tf.nn.dropout(conv_2d, keep_prob_)
57 |
58 | def deconv2d_bn_relu(x, w_size, num_outputs, stride, keep_prob_, phase, scope):
59 | conv_2d = tf.contrib.layers.conv2d_transpose(x, num_outputs, w_size,
60 | stride=stride,
61 | activation_fn=tf.nn.relu, # elu is an alternative
62 | normalizer_fn=tf.layers.batch_normalization,
63 | normalizer_params={'training': phase},
64 | scope=scope)
65 |
66 | return tf.nn.dropout(conv_2d, keep_prob_)
67 |
68 | def conv2d_bn(x, w_size, num_outputs, keep_prob_, phase, scope):
69 | conv_2d = tf.contrib.layers.conv2d(x, num_outputs, w_size,
70 | activation_fn=None,
71 | normalizer_fn=tf.layers.batch_normalization,
72 | normalizer_params={'training': phase},
73 | scope=scope)
74 | return conv_2d
75 |
76 | def conv2d(x, w_size, num_outputs, keep_prob_, scope):
77 | conv_2d = tf.contrib.layers.conv2d(x, num_outputs, w_size,
78 | activation_fn=None,
79 | normalizer_fn=None,
80 | scope=scope)
81 | return conv_2d
82 |
83 | def max_pool(x,n):
84 | return tf.nn.max_pool(x, ksize=[1, n, n, 1], strides=[1, n, n, 1], padding='SAME')
85 |
86 | def concat(x1,x2):
87 | return tf.concat([x1, x2], 3)
88 |
--------------------------------------------------------------------------------
/scadec/nets.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division, absolute_import, unicode_literals
2 |
3 | import os
4 | import shutil
5 | import numpy as np
6 | from collections import OrderedDict
7 | import logging
8 |
9 | import tensorflow as tf
10 | import tensorflow.contrib as contrib
11 |
12 | from scadec import util
13 | from scadec.layers import *
14 |
15 |
16 | def unet_decoder(x, keep_prob, phase, img_channels, truth_channels, layers=3, conv_times=3, features_root=16, filter_size=3, pool_size=2, summaries=True):
17 | """
18 | Creates a new convolutional unet for the given parametrization.
19 |
20 | :param x: input tensor, shape [?,nx,ny,img_channels]
21 | :param keep_prob: dropout probability tensor
22 | :param img_channels: number of channels in the input image
23 | :param layers: number of layers in the net
24 | :param features_root: number of features in the first layer
25 | :param filter_size: size of the convolution filter
26 | :param pool_size: size of the max pooling operation
27 | :param summaries: Flag if summaries should be created
28 | """
29 |
30 | logging.info("Layers {layers}, features {features}, filter size {filter_size}x{filter_size}, pool size: {pool_size}x{pool_size}".format(layers=layers,
31 | features=features_root,
32 | filter_size=filter_size,
33 | pool_size=pool_size))
34 |
35 | # Placeholder for the input image
36 | nx = tf.shape(x)[1]
37 | ny = tf.shape(x)[2]
38 | x_image = tf.reshape(x, tf.stack([-1,nx,ny,img_channels]))
39 | batch_size = tf.shape(x_image)[0]
40 |
41 | pools = OrderedDict() # pooling layers
42 | deconvs = OrderedDict() # deconvolution layer
43 | dw_h_convs = OrderedDict() # down-side convs
44 | up_h_convs = OrderedDict() # up-side convs
45 |
46 | # conv the input image to desired feature maps
47 | in_node = conv2d_bn_relu(x_image, filter_size, features_root, keep_prob, phase, 'conv2feature_roots')
48 |
49 | # Down layers
50 | for layer in range(0, layers):
51 | features = 2**layer*features_root
52 | with tf.variable_scope('down_layer_' + str(layer)):
53 | for conv_iter in range(0, conv_times):
54 | scope = 'conv_bn_relu_{}'.format(conv_iter)
55 | conv = conv2d_bn_relu(in_node, filter_size, features, keep_prob, phase, scope)
56 | in_node = conv
57 |
58 | # store the intermediate result per layer
59 | dw_h_convs[layer] = in_node
60 |
61 | # down sampling
62 | if layer < layers-1:
63 | with tf.variable_scope('pooling'):
64 | pools[layer] = max_pool(dw_h_convs[layer], pool_size)
65 | in_node = pools[layer]
66 |
67 | in_node = dw_h_convs[layers-1]
68 |
69 | # Up layers
70 | for layer in range(layers-2, -1, -1):
71 | features = 2**(layer+1)*features_root
72 | with tf.variable_scope('up_layer_' + str(layer)):
73 | with tf.variable_scope('unsample_concat_layer'):
74 | # number of features = lower layer's number of features
75 | h_deconv = deconv2d_bn_relu(in_node, filter_size, features//2, pool_size, keep_prob, phase, 'unsample_layer')
76 | h_deconv_concat = concat(dw_h_convs[layer], h_deconv)
77 | deconvs[layer] = h_deconv_concat
78 | in_node = h_deconv_concat
79 |
80 | for conv_iter in range(0, conv_times):
81 | scope = 'conv_bn_relu_{}'.format(conv_iter)
82 | conv = conv2d_bn_relu(in_node, filter_size, features//2, keep_prob, phase, scope)
83 | in_node = conv
84 |
85 | up_h_convs[layer] = in_node
86 |
87 | in_node = up_h_convs[0]
88 |
89 |     # Output layer: 1x1 convolution mapping features to truth_channels
90 | with tf.variable_scope("conv2d_1by1"):
91 | output = conv2d(in_node, 1, truth_channels, keep_prob, 'conv2truth_channels')
92 | up_h_convs["out"] = output
93 |
94 | if summaries:
95 | # for i, (c1, c2) in enumerate(convs):
96 | # tf.summary.image('summary_conv_%02d_01'%i, get_image_summary(c1))
97 | # tf.summary.image('summary_conv_%02d_02'%i, get_image_summary(c2))
98 |
99 | for k in pools.keys():
100 | tf.summary.image('summary_pool_%02d'%k, get_image_summary(pools[k]))
101 |
102 | for k in deconvs.keys():
103 | tf.summary.image('summary_deconv_concat_%02d'%k, get_image_summary(deconvs[k]))
104 |
105 | for k in dw_h_convs.keys():
106 | tf.summary.histogram("dw_convolution_%02d"%k + '/activations', dw_h_convs[k])
107 |
108 | for k in up_h_convs.keys():
109 | tf.summary.histogram("up_convolution_%s"%k + '/activations', up_h_convs[k])
110 |
111 | return output
112 |
113 | def get_image_summary(img, idx=0):
114 | """
115 | Make an image summary for 4d tensor image with index idx
116 | """
117 |
118 | V = tf.slice(img, (0, 0, 0, idx), (1, -1, -1, 1))
119 | V -= tf.reduce_min(V)
120 | V /= tf.reduce_max(V)
121 | V *= 255
122 |
123 | img_w = tf.shape(img)[1]
124 | img_h = tf.shape(img)[2]
125 | V = tf.reshape(V, tf.stack((img_w, img_h, 1)))
126 | V = tf.transpose(V, (2, 0, 1))
127 | V = tf.reshape(V, tf.stack((-1, img_w, img_h, 1)))
128 | return V
--------------------------------------------------------------------------------
/scadec/train.py:
--------------------------------------------------------------------------------
1 | # tf_unet is free software: you can redistribute it and/or modify
2 | # it under the terms of the GNU General Public License as published by
3 | # the Free Software Foundation, either version 3 of the License, or
4 | # (at your option) any later version.
5 | #
6 | # tf_unet is distributed in the hope that it will be useful,
7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 | # GNU General Public License for more details.
10 | #
11 | # You should have received a copy of the GNU General Public License
12 | # along with tf_unet. If not, see <http://www.gnu.org/licenses/>.
13 |
14 |
15 | '''
16 | Modified in Feb 2018 based on the work of jakeret
17 |
18 | author: yusun
19 | '''
20 |
21 | from __future__ import print_function, division, absolute_import, unicode_literals
22 |
23 | import os
24 | import shutil
25 | import numpy as np
26 | from collections import OrderedDict
27 | import logging
28 |
29 | import tensorflow as tf
30 | from scadec import util
31 |
32 | class Trainer_bn(object):
33 | """
34 | Trains a unet instance
35 |
36 | :param net: the unet instance to train
37 | :param batch_size: size of training batch
38 | :param optimizer: (optional) name of the optimizer to use (momentum or adam)
39 | :param opt_kwargs: (optional) kwargs passed to the learning rate (momentum opt) and to the optimizer
40 |
41 |     the phase placeholder of the unet is set to True during training
42 |
43 | """
44 |
45 | def __init__(self, net, batch_size=1, optimizer="adam", opt_kwargs={}):
46 | self.net = net
47 | self.batch_size = batch_size
48 | self.optimizer = optimizer
49 | self.opt_kwargs = opt_kwargs
50 |
51 | def _get_optimizer(self, training_iters, global_step):
52 | if self.optimizer == "momentum":
53 | learning_rate = self.opt_kwargs.pop("learning_rate", 0.2)
54 | decay_rate = self.opt_kwargs.pop("decay_rate", 0.95)
55 | momentum = self.opt_kwargs.pop("momentum", 0.2)
56 |
57 | self.learning_rate_node = tf.train.exponential_decay(learning_rate=learning_rate,
58 | global_step=global_step,
59 | decay_steps=training_iters,
60 | decay_rate=decay_rate,
61 | staircase=True)
62 |
63 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
64 | with tf.control_dependencies(update_ops):
65 | optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate_node, momentum=momentum,
66 | **self.opt_kwargs).minimize(self.net.loss,
67 | global_step=global_step)
68 | elif self.optimizer == "adam":
69 | learning_rate = self.opt_kwargs.pop("learning_rate", 0.001)
70 | self.learning_rate_node = tf.Variable(learning_rate)
71 |
72 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
73 | with tf.control_dependencies(update_ops):
74 | optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate_node,
75 | **self.opt_kwargs).minimize(self.net.loss,
76 | global_step=global_step)
77 |
78 | return optimizer
79 |
80 | def _initialize(self, training_iters, output_path, restore, prediction_path):
81 | global_step = tf.Variable(0)
82 | logging.getLogger().setLevel(logging.INFO)
83 |
84 | # get optimizer
85 | self.optimizer = self._get_optimizer(training_iters, global_step)
86 | init = tf.global_variables_initializer()
87 |
88 | # get validation_path
89 | self.prediction_path = prediction_path
90 | abs_prediction_path = os.path.abspath(self.prediction_path)
91 | output_path = os.path.abspath(output_path)
92 |
93 | if not restore:
94 | logging.info("Removing '{:}'".format(abs_prediction_path))
95 | shutil.rmtree(abs_prediction_path, ignore_errors=True)
96 | logging.info("Removing '{:}'".format(output_path))
97 | shutil.rmtree(output_path, ignore_errors=True)
98 |
99 | if not os.path.exists(abs_prediction_path):
100 | logging.info("Allocating '{:}'".format(abs_prediction_path))
101 | os.makedirs(abs_prediction_path)
102 |
103 | if not os.path.exists(output_path):
104 | logging.info("Allocating '{:}'".format(output_path))
105 | os.makedirs(output_path)
106 |
107 | return init
108 |
109 | def train(self, data_provider, output_path, valid_provider, valid_size, training_iters=100, epochs=1000, dropout=0.75, display_step=1, save_epoch=50, restore=False, write_graph=False, prediction_path='validation'):
110 | """
111 |         Launches the training process
112 |
113 | :param data_provider: callable returning training and verification data
114 | :param output_path: path where to store checkpoints
115 | :param valid_provider: data provider for the validation dataset
116 | :param valid_size: batch size for validation provider
117 |         :param training_iters: number of training mini-batch iterations per epoch
118 |         :param epochs: number of epochs
119 |         :param dropout: dropout keep probability
120 | :param display_step: number of steps till outputting stats
121 | :param restore: Flag if previous model should be restored
122 | :param write_graph: Flag if the computation graph should be written as protobuf file to the output path
123 | :param prediction_path: path where to save predictions on each epoch
124 | """
125 |
126 | # initialize the training process.
127 | init = self._initialize(training_iters, output_path, restore, prediction_path)
128 |
129 | # create output path
130 | directory = os.path.join(output_path, "final/")
131 | if not os.path.exists(directory):
132 | os.makedirs(directory)
133 |
134 | save_path = os.path.join(directory, "model.cpkt")
135 | if epochs == 0:
136 | return save_path
137 |
138 | with tf.Session() as sess:
139 | if write_graph:
140 | tf.train.write_graph(sess.graph_def, output_path, "graph.pb", False)
141 |
142 | sess.run(init)
143 |
144 | if restore:
145 | ckpt = tf.train.get_checkpoint_state(output_path)
146 | if ckpt and ckpt.model_checkpoint_path:
147 | self.net.restore(sess, ckpt.model_checkpoint_path)
148 |
149 |
150 | summary_writer = tf.summary.FileWriter(output_path, graph=sess.graph)
151 | logging.info("Start optimization")
152 |
153 | # select validation dataset
154 | valid_x, valid_y = valid_provider(valid_size, fix=True)
155 | util.save_mat(valid_y, "%s/%s.mat"%(self.prediction_path, 'origin_y'))
156 | util.save_mat(valid_x, "%s/%s.mat"%(self.prediction_path, 'origin_x'))
157 |
158 | for epoch in range(epochs):
159 | total_loss = 0
160 | # batch_x, batch_y = data_provider(self.batch_size)
161 | for step in range((epoch*training_iters), ((epoch+1)*training_iters)):
162 | batch_x, batch_y = data_provider(self.batch_size)
163 | # Run optimization op (backprop)
164 | _, loss, lr, avg_psnr = sess.run([self.optimizer,
165 | self.net.loss,
166 | self.learning_rate_node,
167 | self.net.avg_psnr],
168 | feed_dict={self.net.x: batch_x,
169 | self.net.y: batch_y,
170 | self.net.keep_prob: dropout,
171 | self.net.phase: True})
172 |
173 | if step % display_step == 0:
174 | logging.info("Iter {:} (before training on the batch) Minibatch MSE= {:.4f}, Minibatch Avg PSNR= {:.4f}".format(step, loss, avg_psnr))
175 | self.output_minibatch_stats(sess, summary_writer, step, batch_x, batch_y)
176 |
177 | total_loss += loss
178 |
179 | self.record_summary(summary_writer, 'training_loss', loss, step)
180 | self.record_summary(summary_writer, 'training_avg_psnr', avg_psnr, step)
181 |
182 | # output statistics for epoch
183 | self.output_epoch_stats(epoch, total_loss, training_iters, lr)
184 | self.output_valstats(sess, summary_writer, step, valid_x, valid_y, "epoch_%s"%epoch, store_img=True)
185 |
186 | if epoch % save_epoch == 0:
187 | directory = os.path.join(output_path, "{}_cpkt/".format(step))
188 | if not os.path.exists(directory):
189 | os.makedirs(directory)
190 |                     path = os.path.join(directory, "model.cpkt")
191 | self.net.save(sess, path)
192 |
193 | save_path = self.net.save(sess, save_path)
194 |
195 | logging.info("Optimization Finished!")
196 |
197 | return save_path
198 |
199 | def output_epoch_stats(self, epoch, total_loss, training_iters, lr):
200 | logging.info("Epoch {:}, Average MSE: {:.4f}, learning rate: {:.4f}".format(epoch, (total_loss / training_iters), lr))
201 |
202 | def output_minibatch_stats(self, sess, summary_writer, step, batch_x, batch_y):
203 | # Calculate batch loss and accuracy
204 | loss, predictions, avg_psnr = sess.run([self.net.loss,
205 | self.net.recons,
206 | self.net.avg_psnr],
207 | feed_dict={self.net.x: batch_x,
208 | self.net.y: batch_y,
209 | self.net.keep_prob: 1.,
210 | self.net.phase: False})
211 |
212 | self.record_summary(summary_writer, 'minibatch_loss', loss, step)
213 | self.record_summary(summary_writer, 'minibatch_avg_psnr', avg_psnr, step)
214 |
215 | logging.info("Iter {:} (After training on the batch) Minibatch MSE= {:.4f}, Minibatch Avg PSNR= {:.4f}".format(step,loss,avg_psnr))
216 |
217 | def output_valstats(self, sess, summary_writer, step, batch_x, batch_y, name, store_img=True):
218 | prediction, loss, avg_psnr = sess.run([self.net.recons,
219 | self.net.valid_loss,
220 | self.net.valid_avg_psnr],
221 | feed_dict={self.net.x: batch_x,
222 | self.net.y: batch_y,
223 | self.net.keep_prob: 1.,
224 | self.net.phase: False})
225 |
226 | self.record_summary(summary_writer, 'valid_loss', loss, step)
227 | self.record_summary(summary_writer, 'valid_avg_psnr', avg_psnr, step)
228 |
229 | logging.info("Validation Statistics, validation loss= {:.4f}, Avg PSNR= {:.4f}".format(loss, avg_psnr))
230 |
231 | util.save_mat(prediction, "%s/%s.mat"%(self.prediction_path, name))
232 |
233 | if store_img:
234 | util.save_img(prediction[0,...], "%s/%s_img.tif"%(self.prediction_path, name))
235 |
236 | def record_summary(self, writer, name, value, step):
237 | summary=tf.Summary()
238 | summary.value.add(tag=name, simple_value = value)
239 | writer.add_summary(summary, step)
240 | writer.flush()
241 |
242 |
--------------------------------------------------------------------------------
/scadec/unet_bn.py:
--------------------------------------------------------------------------------
1 | # tf_unet is free software: you can redistribute it and/or modify
2 | # it under the terms of the GNU General Public License as published by
3 | # the Free Software Foundation, either version 3 of the License, or
4 | # (at your option) any later version.
5 | #
6 | # tf_unet is distributed in the hope that it will be useful,
7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 | # GNU General Public License for more details.
10 | #
11 | # You should have received a copy of the GNU General Public License
12 | # along with tf_unet. If not, see <http://www.gnu.org/licenses/>.
13 |
14 |
15 | '''
16 | Modified in Feb 2018 based on the work of jakeret
17 |
18 | author: yusun
19 | '''
20 |
21 | from __future__ import print_function, division, absolute_import, unicode_literals
22 |
23 | import os
24 | import shutil
25 | import math
26 | import numpy as np
27 | from collections import OrderedDict
28 | import logging
29 |
30 | import tensorflow as tf
31 |
32 | from scadec import util
33 | from scadec.layers import *
34 | from scadec.nets import *
35 |
36 |
37 | class Unet_bn(object):
38 | """
39 | A unet implementation
40 |
41 |     :param img_channels: (optional) number of channels in the input image
42 |     :param cost: (optional) name of the cost function. Default is 'mean_squared_error'
43 |     :param cost_kwargs: (optional) kwargs passed to the cost function. See Unet_bn._get_cost for more options
44 | :param kwargs: args passed to create_net function.
45 | """
46 |
47 | def __init__(self, img_channels=3, truth_channels=3, cost="mean_squared_error", cost_kwargs={}, **kwargs):
48 | tf.reset_default_graph()
49 |
50 | # basic variables
51 | self.summaries = kwargs.get("summaries", True)
52 | self.img_channels = img_channels
53 | self.truth_channels = truth_channels
54 |
55 | # placeholders for input x and y
56 | self.x = tf.placeholder("float", shape=[None, None, None, img_channels])
57 | self.y = tf.placeholder("float", shape=[None, None, None, truth_channels])
58 | self.phase = tf.placeholder(tf.bool, name='phase')
59 | self.keep_prob = tf.placeholder(tf.float32) #dropout (keep probability)
60 |
61 | # reused variables
62 | self.nx = tf.shape(self.x)[1]
63 | self.ny = tf.shape(self.x)[2]
64 | self.num_examples = tf.shape(self.x)[0]
65 |
66 | # variables need to be calculated
67 | self.recons = unet_decoder(self.x, self.keep_prob, self.phase, self.img_channels, self.truth_channels, **kwargs)
68 | self.loss = self._get_cost(cost, cost_kwargs)
69 | self.valid_loss = self._get_cost(cost, cost_kwargs)
70 | self.avg_psnr = self._get_measure('avg_psnr')
71 | self.valid_avg_psnr = self._get_measure('avg_psnr')
72 |
73 | def _get_measure(self, measure):
74 | total_pixels = self.nx * self.ny * self.truth_channels
75 | dtype = self.x.dtype
76 | flat_recons = tf.reshape(self.recons, [-1, total_pixels])
77 | flat_truths = tf.reshape(self.y, [-1, total_pixels])
78 |
79 | if measure == 'psnr':
80 |             # one MSE per image; PSNR = -10*log10(MSE), assuming images normalized to [0,1]
81 | mse = mse_array(flat_recons, flat_truths, total_pixels)
82 | term1 = log(tf.constant(1, dtype), 10.)
83 | term2 = log(mse, 10.)
84 | psnr = tf.scalar_mul(20., term1) - tf.scalar_mul(10., term2)
85 | result = psnr
86 |
87 | elif measure == 'avg_psnr':
88 |             # one MSE per image; PSNR = -10*log10(MSE), assuming images normalized to [0,1]
89 | mse = mse_array(flat_recons, flat_truths, total_pixels)
90 | term1 = log(tf.constant(1, dtype), 10.)
91 | term2 = log(mse, 10.)
92 | psnr = tf.scalar_mul(20., term1) - tf.scalar_mul(10., term2)
93 | avg_psnr = tf.reduce_mean(psnr)
94 | result = avg_psnr
95 |
96 | else:
97 |             raise ValueError("Unknown measure: %s" % measure)
98 |
99 | return result
100 |
101 | def _get_cost(self, cost_name, cost_kwargs):
102 | """
103 | Constructs the cost function.
104 |
105 | """
106 |
107 | total_pixels = self.nx * self.ny * self.truth_channels
108 | flat_recons = tf.reshape(self.recons, [-1, total_pixels])
109 | flat_truths = tf.reshape(self.y, [-1, total_pixels])
110 | if cost_name == "mean_squared_error":
111 | loss = tf.losses.mean_squared_error(flat_recons, flat_truths)
112 | # the mean_squared_error is equal to the following code
113 | # se = tf.squared_difference(flat_recons, flat_truths)
114 | # loss = tf.reduce_mean(se, 1)
115 |
116 | # add new loss function here
117 | else:
118 |             raise ValueError("Unknown cost function: %s" % cost_name)
119 |
120 | return loss
121 |
122 | # predict
123 | def predict(self, model_path, x_test, keep_prob, phase):
124 | """
125 | Uses the model to create a prediction for the given data
126 |
127 | :param model_path: path to the model checkpoint to restore
128 | :param x_test: Data to predict on. Shape [n, nx, ny, channels]
129 |         :returns prediction: the network prediction, shape [n, nx, ny, truth_channels]
130 | """
131 |
132 | init = tf.global_variables_initializer()
133 | with tf.Session() as sess:
134 | # Initialize variables
135 | sess.run(init)
136 |
137 | # Restore model weights from previously saved model
138 | self.restore(sess, model_path)
139 |
140 | prediction = sess.run(self.recons, feed_dict={self.x: x_test,
141 | self.keep_prob: keep_prob,
142 |                                                           self.phase: phase}) # phase should normally be False at prediction time (batch norm in inference mode)
143 |
144 | return prediction
145 |
146 | def save(self, sess, model_path):
147 | """
148 | Saves the current session to a checkpoint
149 |
150 | :param sess: current session
151 | :param model_path: path to file system location
152 | """
153 |
154 | saver = tf.train.Saver()
155 | save_path = saver.save(sess, model_path)
156 | return save_path
157 |
158 | def restore(self, sess, model_path):
159 | """
160 | Restores a session from a checkpoint
161 |
162 | :param sess: current session instance
163 | :param model_path: path to file system checkpoint location
164 | """
165 |
166 | saver = tf.train.Saver()
167 | saver.restore(sess, model_path)
168 | logging.info("Model restored from file: %s" % model_path)
--------------------------------------------------------------------------------
/scadec/util.py:
--------------------------------------------------------------------------------
1 | # tf_unet is free software: you can redistribute it and/or modify
2 | # it under the terms of the GNU General Public License as published by
3 | # the Free Software Foundation, either version 3 of the License, or
4 | # (at your option) any later version.
5 | #
6 | # tf_unet is distributed in the hope that it will be useful,
7 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
8 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9 | # GNU General Public License for more details.
10 | #
11 | # You should have received a copy of the GNU General Public License
12 | # along with tf_unet. If not, see <http://www.gnu.org/licenses/>.
13 |
14 |
15 | '''
16 | Modified in Feb 2018 based on the work of jakeret
17 |
18 | author: yusun
19 | '''
20 | from __future__ import print_function, division, absolute_import, unicode_literals
21 | import numpy as np
22 | import scipy.io as sio
23 | import scipy.misc as smisc
24 |
25 |
26 | def to_rgb(img):
27 | """
28 |     Converts the given array into an RGB image. If the number of channels is not
29 | 3 the array is tiled such that it has 3 channels. Finally, the values are
30 | rescaled to [0,255)
31 |
32 | :param img: the array to convert [nx, ny, channels]
33 |
34 | :returns img: the rgb image [nx, ny, 3]
35 | """
36 | img = np.atleast_3d(img)
37 | channels = img.shape[2]
38 | if channels < 3:
39 | img = np.tile(img, 3)
40 |
41 | img[np.isnan(img)] = 0
42 | img -= np.amin(img)
43 | img /= np.amax(img)
44 | img *= 255
45 | return img
46 |
47 | def to_double(img):
48 | img = np.atleast_3d(img)
49 | channels = img.shape[2]
50 | if channels < 3:
51 | img = np.tile(img, 3)
52 |
53 | img[np.isnan(img)] = 0
54 | img -= np.amin(img)
55 | img /= np.amax(img)
56 | return img
57 |
58 | def save_mat(img, path):
59 | """
60 | Writes the image to disk
61 |
62 | :param img: the rgb image to save
63 | :param path: the target path
64 | """
65 |
66 | sio.savemat(path, {'img':img})
67 |
68 |
69 | def save_img(img, path):
70 | """
71 | Writes the image to disk
72 |
73 | :param img: the rgb image to save
74 | :param path: the target path
75 | """
76 | img = to_rgb(img)
77 | smisc.imsave(path, img.round().astype(np.uint8))
78 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from scadec.unet_bn import Unet_bn
2 | from scadec.train import Trainer_bn
3 |
4 | from scadec import image_util
5 | from scadec import util
6 |
7 | import scipy.io as spio
8 | import numpy as np
9 | import os
10 |
11 | ####################################################
12 | #### PREPARE WORKSPACE ###
13 | ####################################################
14 |
15 | # specify the GPU you want to use here; if you don't have a GPU, just leave it
16 | gpu_vis = '3'
17 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_vis  # 0,1,2,3
18 |
19 | # here specify the path of the model you want to load
20 | gpu_ind = '0'
21 | model_path = 'gpu' + gpu_ind + '/models/60099_cpkt/models/final/model.cpkt'
22 |
23 | data_channels = 2
24 | truth_channels = 1
25 | level = 10  # noise-level tag used in the test file names below (placeholder value; set it to match your test set)
25 |
26 | ####################################################
27 | #### FUNCTIONS ###
28 | ####################################################
29 |
30 | # reshape the data into a 4D array [n, nx, ny, channels]
31 | def preprocess(data, channels):
32 | nx = data.shape[1]
33 | ny = data.shape[2]
34 | return data.reshape((-1, nx, ny, channels))
35 |
36 | ####################################################
37 | #### LOAD MODEL ###
38 | ####################################################
39 |
40 | # set up args for the unet, should be exactly the same as the loading model
41 | kwargs = {
42 | "layers": 5,
43 | "conv_times": 2,
44 | "features_root": 64,
45 | "filter_size": 3,
46 | "pool_size": 2,
47 | "summaries": True
48 | }
49 |
50 | net = Unet_bn(img_channels=data_channels, truth_channels=truth_channels, cost="mean_squared_error", **kwargs)
51 |
52 |
53 | ####################################################
54 | #### LOAD TRAINING DATA ###
55 | ####################################################
56 |
57 | #preparing training data
58 | data_mat = spio.loadmat('train_np/obhatGausWeak128_40.mat', squeeze_me=True)
59 | truths_mat = spio.loadmat('train_np/obGausWeak128_40.mat', squeeze_me=True)
60 |
61 | data = data_mat['obhatGausWeak128']
62 | data = preprocess(data, data_channels) # now 4-D [n, nx, ny, 2]; indexing one channel, e.g. data[:,:,:,1], gives a 3-D array
63 | truths = preprocess(truths_mat['obGausWeak128'], truth_channels)
64 |
65 | data_provider = image_util.SimpleDataProvider(data, truths)
66 |
67 |
68 | ####################################################
69 | #### LOAD TEST DATA ###
70 | ####################################################
71 |
72 | vdata_mat = spio.loadmat('test_np_noise/obhatGausWeak{}Noise128.mat'.format(level), squeeze_me=True)
73 | vtruths_mat = spio.loadmat('valid_np/obGausN1S128val.mat', squeeze_me=True)
74 |
75 | vdata = vdata_mat['obhatGausWeak128']
76 | vdata = preprocess(vdata, data_channels)
77 | vtruths = preprocess(vtruths_mat['obGausN1S128val'], truth_channels)
78 |
79 | valid_provider = image_util.SimpleDataProvider(vdata, vtruths)
80 |
81 | ####################################################
82 | #### PREDICT ###
83 | ####################################################
84 |
85 | predicts = []
86 |
87 | valid_x, valid_y = valid_provider('full')
88 | num = valid_x.shape[0]
89 |
90 | for i in range(num):
91 |
92 | print('')
93 | print('')
94 | print('************* {} *************'.format(i))
95 | print('')
96 | print('')
97 |
98 |     # pad the single test sample with 23 training samples and run with phase=True,
99 |     # so batch normalization computes its statistics over this mixed batch
100 |     x_train, y_train = data_provider(23)
101 |     x_input = valid_x[i:i+1,:,:,:]
102 |     x_input = np.concatenate((x_input, x_train), axis=0)
103 |     predict = net.predict(model_path, x_input, 1, True)  # keep_prob=1, phase=True
102 | predicts.append(predict[0:1,:,:])
103 |
104 | predicts = np.concatenate(predicts, axis=0)
105 | util.save_mat(predicts, 'test{}Noise.mat'.format(level))
106 |
107 |
108 |
109 |
110 |
111 |
112 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from scadec.unet_bn import Unet_bn
2 | from scadec.train import Trainer_bn
3 |
4 | from scadec import image_util
5 |
6 | import scipy.io as spio
7 | import numpy as np
8 | import os
9 |
10 | ####################################################
11 | #### FUNCTIONS ###
12 | ####################################################
13 |
14 | # reshape the data into a 4D array [n, nx, ny, channels]
15 | def preprocess(data, channels):
16 | nx = data.shape[1]
17 | ny = data.shape[2]
18 | return data.reshape((-1, nx, ny, channels))
19 |
20 | ####################################################
21 | #### HYPER-PARAMETERS ###
22 | ####################################################
23 |
24 | # specify the GPU you want to use here; if you don't have a GPU, just leave it
25 | gpu_ind = '2'
26 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_ind  # 0,1,2,3
27 |
28 |
29 | # Because we have real & imaginary part of our input, data_channels is set to 2
30 | data_channels = 2
31 | truth_channels = 1
32 |
33 | ####################################################
34 | #### DATA LOADING ###
35 | ####################################################
36 |
37 | """
38 | Load all the data needed for training and validation.
39 |
40 | """
41 |
42 | #-- Training Data --#
43 | data_mat = spio.loadmat('train_np/obhatGausWeak128_40.mat', squeeze_me=True)
44 | truths_mat = spio.loadmat('train_np/obGausWeak128_40.mat', squeeze_me=True)
45 |
46 | data = data_mat['obhatGausWeak128']
47 | data = preprocess(data, data_channels) # now 4-D [n, nx, ny, 2]; indexing one channel, e.g. data[:,:,:,1], gives a 3-D array
48 | truths = preprocess(truths_mat['obGausWeak128'], truth_channels)
49 |
50 | data_provider = image_util.SimpleDataProvider(data, truths)
51 |
52 | #-- Validating Data --#
53 | vdata_mat = spio.loadmat('valid_np/obhatGausWeak128val_40.mat', squeeze_me=True)
54 | vtruths_mat = spio.loadmat('valid_np/obGausWeak128val_40.mat', squeeze_me=True)
55 |
56 | vdata = vdata_mat['obhatGausWeak128val']
57 | vdata = preprocess(vdata, data_channels)
58 | vtruths = preprocess(vtruths_mat['obGausWeak128val'], truth_channels)
59 |
60 | valid_provider = image_util.SimpleDataProvider(vdata, vtruths)
61 |
62 |
63 | ####################################################
64 | #### NETWORK ###
65 | ####################################################
66 |
67 | """
68 | here we specify the neural network.
69 |
70 | """
71 |
72 | #-- Network Setup --#
73 | # set up args for the unet
74 | kwargs = {
75 | "layers": 5, # how many resolution levels we want to have
76 | "conv_times": 2, # how many times we want to convolve in each level
77 |     "features_root": 64, # number of feature maps in the first level (doubled at each level, e.g. 64, 128, 256)
78 | "filter_size": 3, # filter size used in convolution
79 | "pool_size": 2, # pooling size used in max-pooling
80 | "summaries": True
81 | }
82 |
83 | net = Unet_bn(img_channels=data_channels, truth_channels=truth_channels, cost="mean_squared_error", **kwargs)
84 |
85 |
86 | ####################################################
87 | #### TRAINING ###
88 | ####################################################
89 |
90 | # args for training
91 | batch_size = 24 # batch size for training
92 | valid_size = 24 # batch size for validating
93 | optimizer = "adam" # optimizer we want to use, 'adam' or 'momentum'
94 |
95 | # output paths for results
96 | output_path = 'gpu' + gpu_ind + '/models'
97 | prediction_path = 'gpu' + gpu_ind + '/validation'
98 | # restore_path = 'gpu001/models/50099_cpkt'
99 |
100 | # optional args
101 | opt_kwargs = {
102 | 'learning_rate': 0.001
103 | }
104 |
105 | # make a trainer for scadec
106 | trainer = Trainer_bn(net, batch_size=batch_size, optimizer=optimizer, opt_kwargs=opt_kwargs)
107 | path = trainer.train(data_provider, output_path, valid_provider, valid_size, training_iters=100, epochs=1000, display_step=20, save_epoch=100, prediction_path=prediction_path)
108 |
109 |
110 |
111 |
112 |
113 |
--------------------------------------------------------------------------------