├── requirements.txt
├── Val_data
│   └── imageId_3
│       ├── 1.bmp
│       ├── 2.bmp
│       ├── 3.bmp
│       ├── 4.bmp
│       ├── 5.bmp
│       ├── 6.bmp
│       ├── 7.bmp
│       ├── 8.bmp
│       ├── groud.bmp
│       └── AllInFocus.bmp
├── Train_data
│   └── imageId_8
│       ├── 1.bmp
│       ├── 2.bmp
│       ├── 3.bmp
│       ├── 4.bmp
│       ├── 5.bmp
│       ├── 6.bmp
│       ├── 7.bmp
│       ├── 8.bmp
│       ├── groud.bmp
│       └── AllInFocus.bmp
├── main.py
├── LICENSE
├── .gitignore
├── data.py
├── README.md
├── model.py
└── test.ipynb

/requirements.txt:
--------------------------------------------------------------------------------
tensorflow==1.5.0
tensorflow-tensorboard==1.5.0
numpy==1.14.0
keras==2.1.5
scikit-image==0.13.1
--------------------------------------------------------------------------------
/Val_data/imageId_3/1.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/1.bmp
--------------------------------------------------------------------------------
/Val_data/imageId_3/2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/2.bmp
--------------------------------------------------------------------------------
/Val_data/imageId_3/3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/3.bmp
--------------------------------------------------------------------------------
/Val_data/imageId_3/4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/4.bmp
--------------------------------------------------------------------------------
/Val_data/imageId_3/5.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/5.bmp
--------------------------------------------------------------------------------
/Val_data/imageId_3/6.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/6.bmp
--------------------------------------------------------------------------------
/Val_data/imageId_3/7.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/7.bmp
--------------------------------------------------------------------------------
/Val_data/imageId_3/8.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/8.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/1.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/1.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/2.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/2.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/3.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/3.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/4.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/4.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/5.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/5.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/6.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/6.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/7.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/7.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/8.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/8.bmp
--------------------------------------------------------------------------------
/Val_data/imageId_3/groud.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/groud.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/groud.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/groud.bmp
--------------------------------------------------------------------------------
/Val_data/imageId_3/AllInFocus.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Val_data/imageId_3/AllInFocus.bmp
--------------------------------------------------------------------------------
/Train_data/imageId_8/AllInFocus.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NerdToMars/Fringe-pattern-denoising-based-on-deep-learning/HEAD/Train_data/imageId_8/AllInFocus.bmp
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os

from model import *
from data import *

# Folder paths of your training / validation data (relative to this script).
# Suggested split: train:val = 8:2.
train_dir = "./Train_data/"
val_dir = "./Val_data/"


# Build the list of per-sample folder paths.
all_list = os.listdir(train_dir)
train_dataset = [train_dir + s_dir + '/' for s_dir in all_list]

all_list = os.listdir(val_dir)
val_dataset = [val_dir + s_dir + '/' for s_dir in all_list]


train_data_generator = data_generator(train_dataset, batch_size=1)
val_data_generator = data_generator(val_dataset, batch_size=1)

model = LSTMUnet()
model_checkpoint = ModelCheckpoint('cnn.hdf5', monitor='loss', verbose=1, save_best_only=True)
model.fit_generator(train_data_generator,
                    validation_data=val_data_generator,
                    validation_steps=10,
                    steps_per_epoch=8000,
                    epochs=1,
                    callbacks=[model_checkpoint])

# Other fit_generator keyword arguments, kept here for reference:
# train_generator,
# initial_epoch=self.epoch,
# epochs=epochs,
# steps_per_epoch=self.config.STEPS_PER_EPOCH,
# callbacks=callbacks,
# validation_data=val_generator,
# validation_steps=self.config.VALIDATION_STEPS,
# max_queue_size=1,
# workers=1,
# use_multiprocessing=False,
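
# To resume training from the saved checkpoint instead of starting from
# scratch, load the weights before calling fit_generator. A sketch (the
# checkpoint file name matches the ModelCheckpoint above):
#
#   model = LSTMUnet()
#   model.load_weights('cnn.hdf5')
#   model.fit_generator(...)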
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
BSD 2-Clause License

Copyright (c) 2019, NerdToMars
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import numpy as np
import os
from keras.preprocessing import image as KI


def data_generator(dataset, shuffle=True, target_size=(8, 256, 256), batch_size=1):
    '''Endlessly yield (image_stack, depth) training batches.

    dataset: list of per-sample folder paths (each ending with '/').
    target_size: (frames, width, height) of one sample; 8 frames are
        captured in one experiment.
    batch_size: number of samples per yielded batch.
    '''
    image_stack_ids = np.copy(dataset)
    stack_index = -1
    b = 0  # batch item index
    error_count = 0
    while True:
        try:
            # Increment index to pick the next sample. Shuffle at the start of an epoch.
            stack_index = (stack_index + 1) % len(image_stack_ids)
            if shuffle and stack_index == 0:
                np.random.shuffle(image_stack_ids)

            stack_id = image_stack_ids[stack_index]
            all_in_focus_image = KI.load_img(stack_id + 'AllInFocus.bmp',
                                             target_size=target_size[1:3], grayscale=True)
            all_in_focus_image = KI.img_to_array(all_in_focus_image)
            # Allocate the batch arrays on the first item of every batch.
            if b == 0:
                batch_image_stack = np.zeros(
                    (batch_size, target_size[0],) + all_in_focus_image.shape,
                    dtype=all_in_focus_image.dtype)
                batch_depth = np.zeros(
                    (batch_size,) + all_in_focus_image.shape, dtype=all_in_focus_image.dtype)

            # Load the image stack (1.bmp .. 8.bmp) and the depth map into the batch.
            for i in range(target_size[0]):
                np_img = KI.img_to_array(
                    KI.load_img(stack_id + str(i + 1) + '.bmp',
                                target_size=target_size[1:3], grayscale=True))
                batch_image_stack[b][i] = np_img

            np_depth = KI.img_to_array(
                KI.load_img(stack_id + 'groud.bmp',
                            target_size=target_size[1:3], grayscale=True))
            batch_depth[b] = np_depth
            b += 1
            if b >= batch_size:
                b = 0
                yield batch_image_stack, batch_depth
        except (GeneratorExit, KeyboardInterrupt):
            raise
        except Exception:
            # Log the failing sample and skip it; give up after repeated errors.
            print(image_stack_ids[stack_index])
            error_count += 1
            if error_count > 5:
                raise
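
# Example usage (a sketch): pull one batch from the generator and check the
# shapes, using the sample folder shipped with the repo.
#
#   gen = data_generator(['./Train_data/imageId_8/'], batch_size=1)
#   images, depth = next(gen)
#   # images.shape -> (1, 8, 256, 256, 1); depth.shape -> (1, 256, 256, 1)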

def data_feed(data_dir, order=True, target_size=(8, 256, 256)):
    '''Load a single image stack for testing.

    Expects data_dir to contain only fringe images named like "1-<k>.bmp".
    The substring between "-" and "." is parsed as the frame index; the
    minus sign is deliberately kept in the slice, so that the index parses
    as a negative number and "1" + str(index) reconstructs the file name.
    '''
    all_list = os.listdir(data_dir)

    image_idx = []
    for file_name in all_list:
        start_idx = file_name.find('-')  # slice starts at '-' on purpose, see docstring
        end_idx = file_name.find('.')
        file_idx = int(file_name[start_idx:end_idx])
        image_idx.append(file_idx)

    if order:
        image_idx.sort()

    # Allocate the stack from the first image's shape, then fill it in order.
    img_ki = KI.load_img(data_dir + '1' + str(image_idx[0]) + '.bmp',
                         target_size=target_size[1:3], grayscale=True)
    img_ki_array = KI.img_to_array(img_ki)
    image_stack = np.zeros((target_size[0],) + img_ki_array.shape, dtype=np.uint8)
    for idx, item in enumerate(image_idx):
        image_stack[idx] = KI.img_to_array(
            KI.load_img(data_dir + '1' + str(item) + '.bmp',
                        target_size=target_size[1:3], grayscale=True))

    return image_stack
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fringe-pattern-denoising-based-on-deep-learning
Keras model; swap in your own data or your preferred model.

## Training
Make sure your data folder structure is correct. Note that data.py also
loads AllInFocus.bmp, and that it expects the ground-truth file to be
spelled exactly `groud.bmp`:

    Train_data
    - imageId_[number]
      - 1.bmp
      - 2.bmp
      - ...
      - groud.bmp       # ground truth
      - AllInFocus.bmp

    Val_data
    - imageId_[number]
      - 1.bmp
      - 2.bmp
      - ...
      - groud.bmp       # ground truth
      - AllInFocus.bmp

1. Modify the steps and epochs in main.py.
2. Run main.py, or use the Jupyter notebook (test.ipynb). A quick sanity
   check for the data layout is sketched below.
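
Before training, the layout can be verified with a short script (a minimal sketch; the folder names match the samples shipped with this repo, so adjust the roots to your own data):

```python
import os

def check_layout(root, n_frames=8):
    """Verify each sample folder holds 1.bmp..n.bmp plus groud.bmp and AllInFocus.bmp."""
    expected = {'%d.bmp' % (i + 1) for i in range(n_frames)} | {'groud.bmp', 'AllInFocus.bmp'}
    for sample in sorted(os.listdir(root)):
        files = set(os.listdir(os.path.join(root, sample)))
        missing = expected - files
        if missing:
            print('%s is missing: %s' % (sample, sorted(missing)))

check_layout('./Train_data/')
check_layout('./Val_data/')
```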

================================================================================================

# LSTM UNET TEST


```python
from model import *
from data import *
import os
import sys

ROOTPATH = os.path.abspath("./")
sys.path.append(ROOTPATH)
train_dir = "./TRAIN/"
val_dir = "./VAL/"

test_dir = "./test/"

all_list = os.listdir(train_dir)
train_dataset = [train_dir + s_dir + '/' for s_dir in all_list]

all_list = os.listdir(val_dir)
val_dataset = [val_dir + s_dir + '/' for s_dir in all_list]
```

    /usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
      from ._conv import register_converters as _register_converters
    Using TensorFlow backend.


# Create data generator


```python
train_data_generator = data_generator(train_dataset, batch_size=1)
val_data_generator = data_generator(val_dataset, batch_size=1)
```

# TensorBoard records (optional)
Requires TensorBoard to be installed.


```python
from keras.callbacks import TensorBoard

class TB(TensorBoard):
    def __init__(self, log_every=1, **kwargs):
        super().__init__(**kwargs)
        self.log_every = log_every
        self.counter = 0

    def on_batch_end(self, batch, logs=None):
        self.counter += 1
        if self.counter % self.log_every == 0:
            for name, value in logs.items():
                if name in ['batch', 'size']:
                    continue
                summary = tf.Summary()
                summary_value = summary.value.add()
                summary_value.simple_value = value.item()
                summary_value.tag = name
                self.writer.add_summary(summary, self.counter)
            self.writer.flush()

        super().on_batch_end(batch, logs)

tensorboard_log = TB(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)
```

# Compile model


```python
model = LSTMUnet()
model_checkpoint = ModelCheckpoint('lstm_unet.hdf5', monitor='loss', verbose=1, save_best_only=True)
print(model.summary())
```

# Train


```python
model.fit_generator(
    train_data_generator,
    validation_data=val_data_generator,
    validation_steps=400,
    steps_per_epoch=2325,
    epochs=15,
    callbacks=[model_checkpoint, tensorboard_log])
```

# Testing


```python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
```


```python
model_val = LSTMUnet()
model_val.load_weights('lstm_3.hdf5')  # load your pretrained model
```


```python
test_dir = "/cole_driver/dff/PLS3Bx10/"
test_dir2 = "/cole_driver/dff/PLS3Sx50/"  # change to your test folder

input_ = data_feed(test_dir)
input_s = np.zeros((1,) + input_.shape, dtype=np.uint8)
input_s[0] = input_

out_d = model_val.predict(input_s)
imd = plt.imshow(out_d[0, ::, ::, 0])
plt.colorbar()
```
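
If a ground-truth image is available for the test stack, a rough quantitative check can follow the prediction. A minimal sketch, assuming a `groud.bmp` stored for the stack (keep it at a separate path, since `data_feed` expects its folder to contain only the numbered fringe images) and that your ground truth uses the same normalization as your training data:

```python
from keras.preprocessing import image as KI
import numpy as np

gt_path = test_dir + 'groud.bmp'  # hypothetical location of the ground truth
gt = KI.img_to_array(KI.load_img(gt_path, target_size=(256, 256), grayscale=True))

pred = out_d[0]  # (256, 256, 1), sigmoid output in [0, 1]
mse = np.mean((pred - gt / 255.0) ** 2)  # scaling gt to [0, 1] is an assumption
print('MSE vs ground truth: %.6f' % mse)
```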

================================================================================================


If you use this code, please cite: "Fringe pattern denoising based on deep learning", DOI: 10.1016/j.optcom.2018.12.058

Yan, Ketao, Yingjie Yu, Chongtian Huang, Liansheng Sui, Kemao Qian, and Anand Asundi. "Fringe pattern denoising based on deep learning." Optics Communications (2018).

    @article{yan2018fringe,
      title={Fringe pattern denoising based on deep learning},
      author={Yan, Ketao and Yu, Yingjie and Huang, Chongtian and Sui, Liansheng and Qian, Kemao and Asundi, Anand},
      journal={Optics Communications},
      year={2018},
      publisher={Elsevier}
    }
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
"""ConvLSTM + U-Net model for denoising fringe-pattern image stacks.

A stack of 8 fringe images (8 x 256 x 256 x 1) is passed through
bidirectional ConvLSTM layers, collapsed along the time axis with 3-D
convolutions, and refined by a 2-D U-Net.
"""
import tensorflow as tf
import keras.layers as KL
import keras.models as KM
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler


def getBiConvLSTM2d(input_image, filters, kernel_size, name):
    """Bidirectional ConvLSTM2D block followed by batch normalization."""
    L1 = KL.Bidirectional(KL.ConvLSTM2D(filters=filters, kernel_size=kernel_size,
                                        activation='relu', padding='same',
                                        return_sequences=True))(input_image)
    L1 = KL.BatchNormalization(name="batchNormL_" + name)(L1)
    return L1


def unet(inputs):
    """Standard 2-D U-Net head; returns a single-channel sigmoid map."""
    # Encoder: two 3x3 convolutions per level, 2x2 max pooling in between.
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
    conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    # Bottleneck.
    conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)

    # Decoder: upsample 2x, then concatenate the matching encoder features
    # (skip connection) before two more 3x3 convolutions.
    up6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

    up7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv6))
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

    up8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv7))
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

    up9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(UpSampling2D(size=(2, 2))(conv8))
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)

    return conv10
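
# Tensor shapes through LSTMUnet, for the default 8-frame 256x256 input
# (a sketch; Bidirectional concatenates both directions, 2 x 20 filters):
#
#   input_image                      (batch, 8, 256, 256, 1)
#   3 x getBiConvLSTM2d              (batch, 8, 256, 256, 40)
#   Conv3D(filters=8)                (batch, 8, 256, 256, 8)
#   Conv3D(filters=1) + BatchNorm    (batch, 8, 256, 256, 1)
#   Conv3D(channels_first)           (batch, 1, 256, 256, 1)  # 8 frames act as channels
#   Reshape                          (batch, 256, 256, 1)
#   unet                             (batch, 256, 256, 1)     # sigmoid output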

def LSTMUnet():
    """Build and compile the model.

    Input: a (batch, 8, 256, 256, 1) stack of fringe images.
    """
    row = 256
    col = 256
    input_image = KL.Input(shape=[8, row, col, 1], name="input_image")

    # Three stacked bidirectional ConvLSTM blocks over the 8-frame sequence.
    L4 = input_image
    for i in range(3):
        L4 = getBiConvLSTM2d(L4, filters=20, kernel_size=(3, 3), name='top' + str(i))

    # 3-D convolutions reduce the per-frame feature channels back to 1.
    L5 = KL.Conv3D(filters=8, kernel_size=(3, 3, 8),
                   activation='relu',
                   padding='same', data_format='channels_last')(L4)

    L6 = KL.Conv3D(filters=1, kernel_size=(3, 3, 4),
                   activation='relu',
                   padding='same', data_format='channels_last')(L5)
    L6 = KL.BatchNormalization(name="batchNormL_sel5")(L6)

    # channels_first treats the 8 frames as channels, collapsing the
    # temporal dimension into a single 256x256 map.
    L7 = KL.Conv3D(filters=1, kernel_size=(3, 3, 8),
                   activation='relu',
                   padding='same', data_format='channels_first')(L6)
    L7 = Reshape((row, col, 1))(L7)

    seg = unet(L7)
    model = KM.Model(input_image, seg)
    model.compile(loss='mean_squared_error', optimizer='adadelta')
    return model
--------------------------------------------------------------------------------
/test.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# LSTM UNET TEST" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [ 15 | { 16 | "name": "stderr", 17 | "output_type": "stream", 18 | "text": [ 19 | "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated.
In future, it will be treated as `np.float64 == np.dtype(float).type`.\n", 20 | " from ._conv import register_converters as _register_converters\n", 21 | "Using TensorFlow backend.\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "from model import *\n", 27 | "from data import *\n", 28 | "import os\n", 29 | "import sys\n", 30 | "\n", 31 | "ROOTPATH = os.path.abspath(\"./\")\n", 32 | "sys.path.append(ROOTPATH)\n", 33 | "train_dir = \"./TRAIN/\"\n", 34 | "val_dir = \"./VAL/\" \n", 35 | "\n", 36 | "test_dir = \"./test/\"\n", 37 | "\n", 38 | "all_list = os.listdir(train_dir)\n", 39 | "train_dataset = [train_dir+s_dir+'/' for s_dir in all_list]\n", 40 | "\n", 41 | "all_list = os.listdir(val_dir)\n", 42 | "val_dataset = [val_dir+s_dir+'/' for s_dir in all_list]\n", 43 | "\n", 44 | "\n", 45 | "\n" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "metadata": {}, 51 | "source": [ 52 | "# Create data generator" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 3, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "train_data_generator = data_generator(train_dataset,batch_size=1)\n", 62 | "val_data_generator = data_generator(val_dataset,batch_size=1)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "metadata": {}, 68 | "source": [ 69 | "# Tensorboard records (optional) \n", 70 | "need install tensorboard" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "from keras.callbacks import TensorBoard\n", 80 | "class TB(TensorBoard):\n", 81 | " def __init__(self, log_every=1, **kwargs):\n", 82 | " super().__init__(**kwargs)\n", 83 | " self.log_every = log_every\n", 84 | " self.counter = 0\n", 85 | " \n", 86 | " def on_batch_end(self, batch, logs=None):\n", 87 | " self.counter+=1\n", 88 | " if self.counter%self.log_every==0:\n", 89 | " for name, value in logs.items():\n", 90 | " if name in ['batch', 'size']:\n", 91 | " continue\n", 92 | " summary = tf.Summary()\n", 93 | " summary_value = summary.value.add()\n", 94 | " summary_value.simple_value = value.item()\n", 95 | " summary_value.tag = name\n", 96 | " self.writer.add_summary(summary, self.counter)\n", 97 | " self.writer.flush()\n", 98 | " \n", 99 | " super().on_batch_end(batch, logs)\n", 100 | " \n", 101 | "tensorboard_log = TB(log_dir='./Graph', histogram_freq=0, write_graph=True, write_images=True)\n", 102 | "print(model.summary())" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "# Compile model" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 5, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "model = LSTMUnet()\n", 119 | "model_checkpoint = ModelCheckpoint('lstm_unet.hdf5', monitor='loss',verbose=1, save_best_only=True)\n", 120 | "print(model.summary())\n" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "metadata": {}, 126 | "source": [ 127 | "# Train" 128 | ] 129 | }, 130 | { 131 | "cell_type": "code", 132 | "execution_count": 52, 133 | "metadata": {}, 134 | "outputs": [], 135 | "source": [ 136 | "model.fit_generator(\n", 137 | " train_data_generator,\n", 138 | " validation_data=val_data_generator,\n", 139 | " validation_steps =400,\n", 140 | " steps_per_epoch=2325,\n", 141 | " epochs=15,\n", 142 | " callbacks=[model_checkpoint, tensorboard_log])\n" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "# Testing" 150 | ] 151 | }, 152 | { 153 | 
"cell_type": "code", 154 | "execution_count": 2, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "import matplotlib.pyplot as plt\n", 159 | "from mpl_toolkits.mplot3d import Axes3D" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": 12, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "model_val = LSTMUnet()\n", 169 | "model_val.load_weights('lstm_3.hdf5') # load your pretrained model" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 15, 175 | "metadata": {}, 176 | "outputs": [], 177 | "source": [ 178 | "test_dir = \"/cole_driver/dff/PLS3Bx10/\"\n", 179 | "test_dir2 = \"/cole_driver/dff/PLS3Sx50/\" #change to your test folder\n", 180 | "\n", 181 | "input_ = data_feed(test_dir)\n", 182 | "input_s = np.zeros((1,)+input_.shape,dtype=np.uint8)\n", 183 | "input_s[0] = input_\n", 184 | "\n", 185 | "out_d = model_val.predict(input_s)\n", 186 | "imd = plt.imshow(out_d[0,::,::,0])\n", 187 | "plt.colorbar()" 188 | ] 189 | } 190 | ], 191 | "metadata": { 192 | "kernelspec": { 193 | "display_name": "Python 3", 194 | "language": "python", 195 | "name": "python3" 196 | }, 197 | "language_info": { 198 | "codemirror_mode": { 199 | "name": "ipython", 200 | "version": 3 201 | }, 202 | "file_extension": ".py", 203 | "mimetype": "text/x-python", 204 | "name": "python", 205 | "nbconvert_exporter": "python", 206 | "pygments_lexer": "ipython3", 207 | "version": "3.5.2" 208 | } 209 | }, 210 | "nbformat": 4, 211 | "nbformat_minor": 2 212 | } 213 | --------------------------------------------------------------------------------