├── .gitattributes ├── images ├── img1.png ├── img2.png └── img3.png ├── test_data ├── letter_E.npy ├── letter_S.npy ├── number_8.npy └── number_9.npy ├── LICENSE ├── demo.py ├── traning code.py ├── README.md └── model.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /images/img1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bu-cisl/Deep-Speckle-Correlation/HEAD/images/img1.png -------------------------------------------------------------------------------- /images/img2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bu-cisl/Deep-Speckle-Correlation/HEAD/images/img2.png -------------------------------------------------------------------------------- /images/img3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bu-cisl/Deep-Speckle-Correlation/HEAD/images/img3.png -------------------------------------------------------------------------------- /test_data/letter_E.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bu-cisl/Deep-Speckle-Correlation/HEAD/test_data/letter_E.npy -------------------------------------------------------------------------------- /test_data/letter_S.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bu-cisl/Deep-Speckle-Correlation/HEAD/test_data/letter_S.npy -------------------------------------------------------------------------------- /test_data/number_8.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bu-cisl/Deep-Speckle-Correlation/HEAD/test_data/number_8.npy -------------------------------------------------------------------------------- /test_data/number_9.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bu-cisl/Deep-Speckle-Correlation/HEAD/test_data/number_9.npy -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Boston University Computational Imaging Systems Lab 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a quick demo of deep speckle correlation project. 3 | 4 | Paper link: https://arxiv.org/abs/1806.04139 5 | 6 | Author: Yunzhe Li, Yujia Xue, Lei Tian 7 | 8 | Computational Imaging System Lab, @ ECE, Boston University 9 | 10 | Date: 2018.08.21 11 | """ 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | 15 | from model import get_model_deep_speckle 16 | 17 | # model is defined in model.py 18 | model = get_model_deep_speckle() 19 | # pretrained_weights.hdf5 can be downloaded from the link on our GitHub project page 20 | model.load_weights('pretrained_weights.hdf5') 21 | 22 | # test speckle patterns. 
Four types of objects (E,S,8,9), 23 | # Each object has five speckle patterns through 5 different test diffusers 24 | speckle_E = np.load('test_data/letter_E.npy') 25 | speckle_S = np.load('test_data/letter_S.npy') 26 | speckle_8 = np.load('test_data/number_8.npy') 27 | speckle_9 = np.load('test_data/number_9.npy') 28 | 29 | # prediction 30 | pred_speckle_E = model.predict(speckle_E, batch_size=2) 31 | pred_speckle_S = model.predict(speckle_S, batch_size=2) 32 | pred_speckle_8 = model.predict(speckle_8, batch_size=2) 33 | pred_speckle_9 = model.predict(speckle_9, batch_size=2) 34 | 35 | # plot results 36 | plt.figure() 37 | for i in range(5): 38 | plt.subplot(2, 5, i + 1) 39 | plt.imshow(speckle_E[i, :].squeeze(), cmap='hot') 40 | plt.axis('off') 41 | plt.subplot(2, 5, i + 1 + 5) 42 | plt.imshow(pred_speckle_E[i, :, :, 0].squeeze(), cmap='gray') 43 | plt.axis('off') 44 | 45 | plt.figure() 46 | for i in range(5): 47 | plt.subplot(2, 5, i + 1) 48 | plt.imshow(speckle_S[i, :].squeeze(), cmap='hot') 49 | plt.axis('off') 50 | plt.subplot(2, 5, i + 1 + 5) 51 | plt.imshow(pred_speckle_S[i, :, :, 0].squeeze(), cmap='gray') 52 | plt.axis('off') 53 | 54 | plt.figure() 55 | for i in range(5): 56 | plt.subplot(2, 5, i + 1) 57 | plt.imshow(speckle_8[i, :].squeeze(), cmap='hot') 58 | plt.axis('off') 59 | plt.subplot(2, 5, i + 1 + 5) 60 | plt.imshow(pred_speckle_8[i, :, :, 0].squeeze(), cmap='gray') 61 | plt.axis('off') 62 | 63 | plt.figure() 64 | for i in range(5): 65 | plt.subplot(2, 5, i + 1) 66 | plt.imshow(speckle_9[i, :].squeeze(), cmap='hot') 67 | plt.axis('off') 68 | plt.subplot(2, 5, i + 1 + 5) 69 | plt.imshow(pred_speckle_9[i, :, :, 0].squeeze(), cmap='gray') 70 | plt.axis('off') 71 | 72 | plt.show() 73 | -------------------------------------------------------------------------------- /traning code.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | from 
keras.optimizers import Adam 4 | from keras.callbacks import ModelCheckpoint 5 | from UNet_ResNet import get_unet_denseblock_x2_deeper 6 | from matplotlib import pyplot as plt 7 | #from loss_function_new import total_variation_balanced_cross_entropy, balanced_cross_entropy, gaussian_loss 8 | from parameter import save_path, num_epochs, batch_size, save_period, lr_rate 9 | 10 | from sklearn.model_selection import train_test_split 11 | 12 | # Split the data 13 | 14 | 15 | plt.switch_backend('agg') 16 | 17 | 18 | proj_name = 'MNIST_x2' 19 | show_groundtruth_flag = False 20 | 21 | def train_and_predict(): 22 | print('-' * 30) 23 | print('Loading training data...') 24 | print('-' * 30) 25 | 26 | print('-' * 30) 27 | print('create validation data...') 28 | print('-' * 30) 29 | # x_train, x_valid, y_train, y_valid = train_test_split(x_train_load, y_train_load, test_size=0.02, shuffle= True) 30 | # np.save('../data/x_valid.npy',x_valid) 31 | # np.save('../data/y_valid.npy',y_valid) 32 | # np.save('../data/x_train_new.npy',x_train) 33 | # np.save('../data/y_train_new.npy',y_train) 34 | 35 | x_train = np.load('../data/x_train.npy') 36 | y_train = np.load('../data/y_train.npy') 37 | x_valid = np.load('../data/x_vali.npy') 38 | y_valid = np.load('../data/y_vali.npy') 39 | print('-' * 30) 40 | print('Creating and compiling model...') 41 | print('-' * 30) 42 | model = get_unet_denseblock_x2_deeper() 43 | #model.load_weights('save/lr4/MNIST_x2.60.hdf5') 44 | 45 | model.compile(optimizer=Adam(lr=lr_rate), loss='binary_crossentropy') 46 | model_checkpoint = ModelCheckpoint(save_path+proj_name+'.{epoch:02d}.hdf5', monitor='loss', verbose=2, save_best_only=False, 47 | period=save_period) 48 | 49 | print('-' * 30) 50 | print('Fitting model...') 51 | print('-' * 30) 52 | 53 | history = model.fit(x_train, y_train, batch_size=batch_size, epochs=num_epochs, verbose=2, shuffle=True, 54 | callbacks=[model_checkpoint], validation_data = (x_valid, y_valid)) 55 | 56 | 57 | # summarize history 
for loss 58 | plt.plot(history.history['loss']) 59 | plt.plot(history.history['val_loss']) 60 | plt.title('model loss') 61 | plt.ylabel('loss') 62 | plt.xlabel('epoch') 63 | plt.legend(['train', 'validation'], loc='upper left') 64 | plt.savefig('total_loss_4.png') 65 | plt.close() 66 | 67 | 68 | if __name__ == '__main__': 69 | train_and_predict() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Speckle-Correlation 2 | Python implementation of paper: **Deep speckle correlation: a deep learning approach towards scalable imaging through scattering media**. We provide model, pre-trained weights(download link available below), test data and a quick demo. 3 | 4 | 5 | ### Citation 6 | If you find this project useful in your research, please consider citing our paper: 7 | 8 | [**Yunzhe Li, Yujia Xue, and Lei Tian, "Deep speckle correlation: a deep learning approach toward scalable imaging through scattering media," Optica 5, 1181-1190 (2018)**](https://www.osapublishing.org/optica/abstract.cfm?uri=optica-5-10-1181) 9 | 10 | 11 | ### Abstract 12 | Imaging through scattering is an important, yet challenging problem. Tremendous progress has been made by exploiting the deterministic input-output transmission matrix for a fixed medium. However, this one-for-one approach is highly susceptible to speckle decorrelations -- small perturbations to the scattering medium lead to model errors and severe degradation of the imaging performance. Our goal here is to develop a new framework that is highly scalable to both medium perturbations and measurement requirement. To do so, we propose a statistical one-for-all deep learning technique that encapsulates a wide range of statistical variations for the model to be resilient to speckle decorrelations. 
Specifically, we develop a convolutional neural network (CNN) that is able to learn the statistical information contained in the speckle intensity patterns captured on a set of diffusers having the same macroscopic parameter. We then show for the first time, to the best of our knowledge, that the trained CNN is able to generalize and make high-quality object prediction through an entirely different set of diffusers of the same class. Our work paves the way to a highly scalable deep learning approach for imaging through scattering media. 13 | 14 |

15 | 16 |

17 | 18 | 19 | ### Requirements 20 | python 3.6 21 | 22 | keras 2.1.2 23 | 24 | tensorflow 1.4.0 25 | 26 | numpy 1.14.3 27 | 28 | h5py 2.7.1 29 | 30 | matplotlib 2.1.2 31 | 32 | 33 | ### CNN architecture 34 |

35 | 36 |

37 | 38 | 39 | ### Download pre-trained weights 40 | You can download the pre-trained weights from [here](https://zenodo.org/records/14939667?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjYxN2Q4ZDM0LTc3MWQtNDdiOS04MDY2LWQzYjM5MmFkZGE1YSIsImRhdGEiOnt9LCJyYW5kb20iOiI4N2I4MzY5YjNkZjRmMjMzYTAyMTFiMDI5NjQwYzk0NiJ9._LjMaI7t13wCyQA6MF4cBMacQ9SI8GrmuwaTBiIOKWfRrldPYZRJxjHKr4kvciulcubskLhg8xF_U55eEqGCnQ) 41 | 42 | ### Download dataset 43 | You can download the dataset from [here](https://zenodo.org/records/15361263?token=eyJhbGciOiJIUzUxMiJ9.eyJpZCI6IjA4NWZjOWNkLTAwNzctNGIyNi04ODNkLTIzOTIxYzA2NTg1ZCIsImRhdGEiOnt9LCJyYW5kb20iOiIwMTVlNTA0YjE2N2RjNTQ3NjlmOTQ4ZWM1MDE3MmY4NyJ9.7FuO7_HZdT-pXfJY5NHey6tZ_H4YwC1QEYfROznirjCO_OZNawN-CpaB6Brb6Qrona-rabd3NeOcQWlNAcOPwg) 44 | #### Data for training: 45 | Letter speckle: for each diffuser, we use 300 images for training [image and corresponding ground truth index range: 0-299] 46 | Digit speckle: for each diffuser, we use 300 images for training [image and corresponding ground truth index range: 0-299] 47 | 48 | #### Data for testing: 49 | Seen digits and letters: 400 images for each unseen diffuser. 50 | Including: 51 | 200 letters [image index 0-199], corresponding to letter ground truth index 0-199 52 | 200 digits [image index 200-399], corresponding to digit ground truth index 0-199 53 | Unseen digits and letters: 54 | Unseen digits: 100 images are used for testing [image and corresponding ground truth index range 300-399] 55 | Unseen letters: 100 images are used for testing [image and corresponding ground truth index range 300-399] 56 | Quickdraw: different object types 57 | 58 | 59 | ### How to use 60 | After downloading the pre-trained weights file, save it as `pretrained_weights.hdf5` in the root directory and run [demo.py](demo.py). 61 | 62 | 63 | ### Results 64 |

65 | 66 |

67 | 68 | 69 | ## License 70 | This project is licensed under the terms of the MIT license. see the [LICENSE](LICENSE) file for details 71 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | python implementation of paper: Deep speckle correlation: a deep learning approach towards scalable imaging through 3 | scattering media 4 | 5 | paper link: https://arxiv.org/abs/1806.04139 6 | 7 | Author: Yunzhe li, Yujia Xue, Lei Tian 8 | 9 | Computational Imaging System Lab, @ ECE, Boston University 10 | 11 | Date: 2018.08.21 12 | """ 13 | 14 | from __future__ import print_function 15 | 16 | from keras.models import Model 17 | from keras.layers import Input, MaxPooling2D, UpSampling2D, Dropout, Conv2D, Concatenate, Activation 18 | from keras.layers.normalization import BatchNormalization 19 | from keras.regularizers import l2 20 | 21 | 22 | # define conv_factory: batch normalization + ReLU + Conv2D + Dropout (optional) 23 | def conv_factory(x, concat_axis, nb_filter, 24 | dropout_rate=None, weight_decay=1E-4): 25 | x = BatchNormalization(axis=concat_axis, 26 | gamma_regularizer=l2(weight_decay), 27 | beta_regularizer=l2(weight_decay))(x) 28 | x = Activation('relu')(x) 29 | x = Conv2D(nb_filter, (5, 5), dilation_rate=(2, 2), 30 | kernel_initializer="he_uniform", 31 | padding="same", 32 | kernel_regularizer=l2(weight_decay))(x) 33 | if dropout_rate: 34 | x = Dropout(dropout_rate)(x) 35 | 36 | return x 37 | 38 | 39 | # define dense block 40 | def denseblock(x, concat_axis, nb_layers, growth_rate, 41 | dropout_rate=None, weight_decay=1E-4): 42 | list_feat = [x] 43 | for i in range(nb_layers): 44 | x = conv_factory(x, concat_axis, growth_rate, 45 | dropout_rate, weight_decay) 46 | list_feat.append(x) 47 | x = Concatenate(axis=concat_axis)(list_feat) 48 | 49 | return x 50 | 51 | 52 | # define model U-net modified with dense block 53 | def 
get_model_deep_speckle(): 54 | inputs = Input((256, 256, 1)) 55 | print("inputs shape:", inputs.shape) 56 | 57 | conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs) 58 | print("conv1 shape:", conv1.shape) 59 | db1 = denseblock(x=conv1, concat_axis=3, nb_layers=4, growth_rate=16, dropout_rate=0.5) 60 | print("db1 shape:", db1.shape) 61 | pool1 = MaxPooling2D(pool_size=(2, 2))(db1) 62 | print("pool1 shape:", pool1.shape) 63 | 64 | conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1) 65 | print("conv2 shape:", conv2.shape) 66 | db2 = denseblock(x=conv2, concat_axis=3, nb_layers=4, growth_rate=16, dropout_rate=0.5) 67 | print("db2 shape:", db2.shape) 68 | pool2 = MaxPooling2D(pool_size=(2, 2))(db2) 69 | print("pool2 shape:", pool2.shape) 70 | 71 | conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2) 72 | print("conv3 shape:", conv3.shape) 73 | db3 = denseblock(x=conv3, concat_axis=3, nb_layers=4, growth_rate=16, dropout_rate=0.5) 74 | print("db3 shape:", db3.shape) 75 | pool3 = MaxPooling2D(pool_size=(2, 2))(db3) 76 | print("pool3 shape:", pool3.shape) 77 | 78 | conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3) 79 | print("conv4 shape:", conv4.shape) 80 | db4 = denseblock(x=conv4, concat_axis=3, nb_layers=4, growth_rate=16, dropout_rate=0.5) 81 | print("db4 shape:", db4.shape) 82 | pool4 = MaxPooling2D(pool_size=(2, 2))(db4) 83 | print("pool4 shape:", pool4.shape) 84 | 85 | conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4) 86 | print("conv5 shape:", conv5.shape) 87 | db5 = denseblock(x=conv5, concat_axis=3, nb_layers=4, growth_rate=16, dropout_rate=0.5) 88 | print("db5 shape:", db5.shape) 89 | up5 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')( 90 | UpSampling2D(size=(2, 2))(db5)) 91 | print("up5 shape:", 
up5.shape) 92 | merge5 = Concatenate(axis=3)([db4, up5]) 93 | print("merge5 shape:", merge5.shape) 94 | 95 | conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge5) 96 | print("conv6 shape:", conv6.shape) 97 | db6 = denseblock(x=conv6, concat_axis=3, nb_layers=3, growth_rate=16, dropout_rate=0.5) 98 | print("db5 shape:", db6.shape) 99 | up6 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')( 100 | UpSampling2D(size=(2, 2))(db6)) 101 | print("up6 shape:", up6.shape) 102 | merge6 = Concatenate(axis=3)([db3, up6]) 103 | print("merge6 shape:", merge6.shape) 104 | 105 | conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6) 106 | print("conv7 shape:", conv7.shape) 107 | db7 = denseblock(x=conv7, concat_axis=3, nb_layers=3, growth_rate=16, dropout_rate=0.5) 108 | print("db7 shape:", db7.shape) 109 | up7 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')( 110 | UpSampling2D(size=(2, 2))(db7)) 111 | print("up7 shape:", up7.shape) 112 | merge7 = Concatenate(axis=3)([db2, up7]) 113 | print("merge7 shape:", merge7.shape) 114 | 115 | conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7) 116 | print("conv8 shape:", conv8.shape) 117 | db8 = denseblock(x=conv8, concat_axis=3, nb_layers=3, growth_rate=16, dropout_rate=0.5) 118 | print("db8 shape:", db8.shape) 119 | up8 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')( 120 | UpSampling2D(size=(2, 2))(db8)) 121 | print("up8 shape:", up8.shape) 122 | merge8 = Concatenate(axis=3)([db1, up8]) 123 | print("merge8 shape:", merge8.shape) 124 | 125 | conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8) 126 | print("conv9 shape:", conv9.shape) 127 | db9 = denseblock(x=conv9, concat_axis=3, nb_layers=3, growth_rate=16, dropout_rate=0.5) 128 | print("db9 shape:", db9.shape) 
129 | conv10 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer='he_normal')(db9) 130 | print("conv10 shape:", conv10.shape) 131 | conv11 = Conv2D(2, 1, activation='softmax')(conv10) 132 | print("conv11 shape:", conv11.shape) 133 | 134 | model = Model(inputs=inputs, outputs=conv11) 135 | 136 | return model 137 | --------------------------------------------------------------------------------