├── .gitignore ├── README.md ├── U-Net_Predict.py ├── U-Net_Training.py ├── base_functions.py └── configuration.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- 
/README.md: -------------------------------------------------------------------------------- 1 | # U-Net for segmentation 2 | This is the reference implementation of the models and code for the U-Net which was proposed by O Ronneberger,P Fischer and T Brox. 3 | You can set parameters like file directory, image size, number of classes etc. in configuration.txt, then run U-Net_Training.py to train the model, run U-Net_Predict.py to segment images. 4 | It's mainly for gray images. In my own seg task, the gt mask is saved as png(4 classes and 1 for bg), and the corresponding gray value is [0, 85, 170, 255] 5 | -------------------------------------------------------------------------------- /U-Net_Predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author:CaoZhihui 4 | 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | import keras 9 | from keras.models import Model 10 | from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dense, Dropout, Reshape, Activation, core, \ 11 | Permute 12 | 13 | from keras import backend as K 14 | import os 15 | import configparser as ConfigParser 16 | import warnings # 不显示乱七八糟的warning 17 | from base_functions import get_test_data, pred_to_imgs 18 | import h5py 19 | from keras.models import model_from_json 20 | from PIL import Image 21 | from IPython.terminal.tests.test_help import test_profile_list_help 22 | 23 | warnings.filterwarnings("ignore") 24 | K.set_image_dim_ordering('th') 25 | 26 | # --read configuration file and get parameters-- # 27 | config = ConfigParser.RawConfigParser() 28 | config.read('./configuration.txt') 29 | path_local = config.get('unet_parameters', 'path_local') 30 | model_path = path_local + config.get('unet_parameters', 'unet_model_dir') 31 | test_images_dir = path_local + config.get('unet_parameters', 'test_images_dir') 32 | test_labels_dir = path_local + 
config.get('unet_parameters', 'test_labels_dir') 33 | img_h = int(config.get('unet_parameters', 'img_h')) 34 | img_w = int(config.get('unet_parameters', 'img_w')) 35 | C = int(config.get('unet_parameters', 'C')) 36 | 37 | print('-'*30) 38 | print('Loading model and weights...') 39 | print('-'*30) 40 | model = model_from_json(open(model_path + 'unet.json').read()) 41 | model.load_weights(model_path+'unet_weights.h5') 42 | print('Loading test data...') 43 | print('-'*30) 44 | test_x, test_y = get_test_data(test_images_dir, test_labels_dir, img_h, img_w) 45 | print('Predicting...') 46 | predictions = model.predict(test_x) 47 | 48 | pred_images = pred_to_imgs(predictions, img_h, img_w, C=C) 49 | pred_images /= np.max(pred_images) 50 | 51 | print('-' * 30 + '\n' + 'Saving predicted images...' + '\n' + '-' * 30) 52 | for i in range(pred_images.shape[0]): 53 | img = Image.fromarray(pred_images[i]*255) 54 | img = img.convert('L') 55 | img.save('./U-Net/pred_images/pred_' + str(i) + '.png') 56 | 57 | 58 | -------------------------------------------------------------------------------- /U-Net_Training.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | # Author:CaoZhihui 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import keras 8 | from keras.models import Model 9 | from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dense, Dropout, Reshape, Activation, core, \ 10 | Permute 11 | from keras.optimizers import Adam 12 | from keras.callbacks import ModelCheckpoint, LearningRateScheduler 13 | from keras import backend as K 14 | import os 15 | import configparser as ConfigParser 16 | import warnings 17 | from base_functions import get_train_data 18 | import h5py 19 | from keras.models import model_from_json 20 | from keras.optimizers import SGD 21 | from IPython.terminal.tests.test_help import test_profile_list_help 22 | 23 | 
# --U-Net_Training.py: configuration and network definition-- #
# The names used in this section are imported here so that the section is
# self-contained; re-importing an already imported name is harmless.
import ast
import warnings
import configparser as ConfigParser

from keras import backend as K
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Reshape, Activation, Permute, concatenate
from keras.optimizers import Adam

warnings.filterwarnings("ignore")
K.set_image_dim_ordering('th')  # channels-first tensors: (N, channels, h, w)

# --read configuration file-- #
config = ConfigParser.RawConfigParser()
config.read('./configuration.txt')

# --get parameters-- #
path_local = config.get('unet_parameters', 'path_local')
train_images_dir = path_local + config.get('unet_parameters', 'train_images_dir')
train_labels_dir = path_local + config.get('unet_parameters', 'train_labels_dir')
img_h = int(config.get('unet_parameters', 'img_h'))
img_w = int(config.get('unet_parameters', 'img_w'))
N_channels = int(config.get('unet_parameters', 'N_channels'))
C = int(config.get('unet_parameters', 'C'))

if C > 2:
    # literal_eval only accepts Python literals, so the configuration file
    # cannot execute arbitrary code the way eval() would allow.
    gt_list = ast.literal_eval(config.get('unet_parameters', 'gt_gray_value_list'))
else:
    gt_list = None


# --Build a net work-- #
def get_net():
    """Build and compile the U-Net.

    Uses the module-level settings N_channels, img_h, img_w and C.

    Returns:
        A compiled Keras model that takes inputs of shape
        (N_channels, img_h, img_w) and outputs, per image, a
        (img_h * img_w, C) array with a softmax over the C classes of every
        pixel.  It is compiled with Adam (lr=1e-4) and categorical
        cross-entropy.
    """

    def conv3x3(filters, tensor):
        # 3x3 same-padded ReLU convolution shared by every block.
        return Conv2D(filters, (3, 3), activation='relu', padding='same',
                      kernel_initializer='he_normal')(tensor)

    def upsample_and_concat(tensor, skip):
        # Keras 2 replacement for merge(..., mode='concat', concat_axis=1):
        # axis 1 is the channel axis because tensors are channels-first.
        return concatenate([UpSampling2D(size=(2, 2))(tensor), skip], axis=1)

    inputs = Input(shape=(N_channels, img_h, img_w))

    # Contracting path: two convolutions, then 2x2 max pooling.
    # Block 1
    conv1 = conv3x3(32, inputs)
    conv1 = conv3x3(32, conv1)
    pool1 = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(conv1)

    # Block 2
    conv2 = conv3x3(64, pool1)
    conv2 = conv3x3(64, conv2)
    pool2 = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(conv2)

    # Block 3
    conv3 = conv3x3(128, pool2)
    conv3 = conv3x3(128, conv3)
    pool3 = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(conv3)

    # Block 4
    conv4 = conv3x3(256, pool3)
    conv4 = conv3x3(256, conv4)
    pool4 = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(conv4)

    # Block 5 (bottleneck, no pooling)
    conv5 = conv3x3(512, pool4)
    conv5 = conv3x3(512, conv5)

    # Expanding path: upsample, concatenate the matching encoder output,
    # then convolve.
    up6 = upsample_and_concat(conv5, conv4)
    conv6 = conv3x3(256, up6)
    conv6 = conv3x3(256, conv6)

    up7 = upsample_and_concat(conv6, conv3)
    conv7 = conv3x3(128, up7)
    conv7 = conv3x3(128, conv7)

    up8 = upsample_and_concat(conv7, conv2)
    conv8 = conv3x3(64, up8)
    conv8 = conv3x3(64, conv8)

    up9 = upsample_and_concat(conv8, conv1)
    conv9 = conv3x3(32, up9)  # the last decoder block has a single 3x3 conv

    # 1x1 convolution producing one map per class.
    # NOTE(review): a ReLU directly before the softmax clamps the class
    # scores to >= 0; kept as in the original architecture - confirm intent.
    conv10 = Conv2D(C, (1, 1), activation='relu', kernel_initializer='he_normal')(conv9)

    # (C, h, w) -> (C, h*w) -> (h*w, C) so the softmax runs over the classes
    # of each pixel.
    reshape = Reshape((C, img_h * img_w), input_shape=(C, img_h, img_w))(conv10)
    reshape = Permute((2, 1))(reshape)

    activation = Activation('softmax')(reshape)

    # Keras 2 keyword names (the Keras 1 spellings were input= / output=).
    model = Model(inputs=inputs, outputs=activation)

    model.compile(optimizer=Adam(lr=1.0e-4), loss='categorical_crossentropy', metrics=['accuracy'])

    return model
# --U-Net_Training.py: load the data, fit the model, save the results-- #
# The names used in this section are imported here so that the section is
# self-contained; re-importing an already imported name is harmless.
import os

import numpy as np
from keras.callbacks import ModelCheckpoint
from base_functions import get_train_data

print('-' * 30)
print('Loading and pre-processing train data...')
print('-' * 30)

train_x, train_y = get_train_data(train_images_dir, train_labels_dir, img_h, img_w, C=C, gt_list=gt_list)

print('train_x size: ', np.shape(train_x))
print('train_y size: ', np.shape(train_y))

# Explicit check instead of an assert, which is stripped under `python -O`.
if train_y.shape[1] != img_h or train_y.shape[2] != img_w:
    raise ValueError('train_y has spatial size {}, expected ({}, {})'.format(
        train_y.shape[1:3], img_h, img_w))
# Flatten the pixel grid so the labels match the network output (h*w, C).
train_y = np.reshape(train_y, (train_y.shape[0], img_h*img_w, C))
model_path = path_local+config.get('unet_parameters', 'unet_model_dir')

# --Check whether the output path of the model exists or not-- #
if not os.path.isdir(model_path):
    # makedirs also creates missing parent directories (os.mkdir does not).
    os.makedirs(model_path)
# ------------------------------------ #
print('-' * 30)
print('Creating and compiling model...')
print('-' * 30)
model = get_net()
# Keep only the checkpoint with the lowest training loss seen so far.
model_checkpoint = ModelCheckpoint(model_path + '/unet.hdf5', monitor='loss', save_best_only=True)

print('-' * 30)
print('Fitting model...')
print('-' * 30)
batch_size = int(config.get('unet_parameters', 'batch_size'))
epochs = int(config.get('unet_parameters', 'N_epochs'))
val_rate = config.get('unet_parameters', 'validation_rate')
hist = model.fit(x=train_x, y=train_y, batch_size=batch_size, epochs=epochs, verbose=2, shuffle=True,
                 validation_split=float(val_rate), callbacks=[model_checkpoint], initial_epoch=0)
# Store the per-epoch training history as text.
with open(model_path + '/unet.txt', 'w') as f:
    f.write(str(hist.history))
print('-' * 30)
print('Loading saved weights...')
print('-' * 30)
# Reload the best checkpoint, then export the architecture (JSON) and the
# weights (HDF5) as two separate files.
model.load_weights(model_path + '/unet.hdf5')
json_string = model.to_json()
with open(model_path + '/unet.json', 'w') as json_file:
    json_file.write(json_string)
model.save_weights(model_path + '/unet_weights.h5')
"""base_functions.py: data loading and post-processing helpers for the U-Net scripts."""
import os

import numpy as np


def _read_image(path):
    """Read one image file with matplotlib and return it as an array."""
    # Imported here so that the NumPy-only helpers in this module can be
    # imported without matplotlib.
    import matplotlib.image as mpimg
    return mpimg.imread(path)


def _sorted_files(directory):
    """Return the entries of `directory` sorted by name.

    os.listdir returns entries in arbitrary order; sorting pairs the i-th
    image with the i-th label when both directories use matching file names.
    """
    return sorted(os.listdir(directory))


def _load_image_stack(images_dir, img_h, img_w, N_channels):
    """Load every image of a directory into one array scaled by its maximum.

    Returns an array of shape (n_files, img_h, img_w, N_channels).
    Raises ValueError if the directory contains no files.
    """
    files = _sorted_files(images_dir)
    if not files:
        raise ValueError('No image files found in {!r}'.format(images_dir))
    stack = np.zeros([len(files), img_h, img_w, N_channels])
    for idx, name in enumerate(files):
        img = _read_image(os.path.join(images_dir, name))
        if img.ndim == 2:
            img = img[:, :, np.newaxis]
        # Copy at most N_channels channels: an image with extra channels
        # (e.g. RGBA) no longer fails with a broadcasting error.
        n_ch = min(img.shape[-1], N_channels)
        stack[idx, :, :, :n_ch] = img[:, :, :n_ch]
    peak = np.max(stack)
    if peak > 0:  # all-black data would otherwise turn into NaN
        stack = stack / peak
    return stack


def _to_gt_scale(mask, gt_list):
    """Put a float mask in [0, 1] on the 0-255 scale used by `gt_list`.

    matplotlib returns PNG data as floats in [0, 1], so comparing such a
    mask with gray values like 85 or 255 would never match.  Masks that are
    not floating point, and lists whose values are all <= 1, are left
    untouched.
    """
    if np.issubdtype(mask.dtype, np.floating) and max(gt_list) > 1:
        return np.rint(mask * 255.0)
    return mask


def get_train_data(train_images_dir, train_labels_dir, img_h, img_w, N_channels=1, C=2, gt_list=None):
    """Load the training images and their one-hot encoded label masks.

    The images are scaled by their global maximum and mean-centred; the mean
    image is written to './mean_img.npy' (get_test_data reads it back).

    Args:
        train_images_dir: directory with the training images.
        train_labels_dir: directory with the label masks; files are paired
            with the images by sorted file name.
        img_h, img_w: image height and width.
        N_channels: number of image channels to load.
        C: number of classes (>= 2).  For C == 2, channel 0 marks pixels
            equal to 0 and channel 1 marks all other pixels.
        gt_list: gray value of each class in the masks; required when C > 2.

    Returns:
        (images, labels): images has shape (N, N_channels, img_h, img_w),
        labels has shape (N, img_h, img_w, C).

    Raises:
        ValueError: if C < 2, if gt_list is missing or too short for C > 2,
            if a directory is empty, or if the numbers of images and labels
            differ.
    """
    if C < 2:
        raise ValueError('C must be at least 2, got {}'.format(C))
    if C > 2:
        if gt_list is None:
            raise ValueError('There is a lack of a list of GT values!')
        if len(gt_list) < C:
            raise ValueError('gt_list has {} values but C is {}'.format(len(gt_list), C))

    print('-'*30)
    print('Loading train images...')
    print('-'*30)
    total_images = _load_image_stack(train_images_dir, img_h, img_w, N_channels)
    mean = np.mean(total_images, axis=0)
    np.save('./mean_img.npy', mean)
    total_images -= mean
    total_images = np.transpose(total_images, [0, 3, 1, 2])  # to [N, channels, h, w]

    print('-'*30)
    print('Loading train labels...')
    print('-'*30)
    label_files = _sorted_files(train_labels_dir)
    n_images = total_images.shape[0]
    if len(label_files) != n_images:
        raise ValueError('Found {} images but {} label files'.format(n_images, len(label_files)))
    total_labels = np.zeros([n_images, img_h, img_w, C])
    for idx, name in enumerate(label_files):
        ground_truth = _read_image(os.path.join(train_labels_dir, name))
        if C == 2:
            total_labels[idx, :, :, 0] = (ground_truth == 0)
            total_labels[idx, :, :, 1] = (ground_truth != 0)
        else:
            ground_truth = _to_gt_scale(ground_truth, gt_list)
            for ch in range(C):
                total_labels[idx, :, :, ch] = (ground_truth == gt_list[ch])

    return total_images, total_labels


def get_test_data(test_images_dir, test_labels_dir, img_h, img_w, N_channels=1, C=2):
    """Load the test images and their raw label masks.

    The images are scaled by their global maximum and centred with the mean
    image stored in './mean_img.npy'.  `C` is accepted for signature
    compatibility and is not used.

    Returns:
        (images, labels): images has shape (N, N_channels, img_h, img_w);
        labels has shape (n_label_files, img_h, img_w) and holds the mask
        values exactly as read from disk.

    Raises:
        ValueError: if the image directory is empty.
    """
    print('-'*30)
    print('Loading test images...')

    total_images = _load_image_stack(test_images_dir, img_h, img_w, N_channels)
    mean = np.load('./mean_img.npy')
    total_images -= mean
    total_images = np.transpose(total_images, [0, 3, 1, 2])  # transpose to shape[N, channels, h, w]

    print('-'*30)
    print('Loading test labels...')
    print('-'*30)
    label_files = _sorted_files(test_labels_dir)
    total_labels = np.zeros([len(label_files), img_h, img_w])
    for idx, name in enumerate(label_files):
        total_labels[idx] = _read_image(os.path.join(test_labels_dir, name))

    return total_images, total_labels


def _class_index_map(predictions, img_h, img_w, C):
    """Turn (N, img_h*img_w, C) class scores into (N, img_h, img_w) class indices.

    Raises ValueError if `predictions` is not 3-D or its second axis is not
    img_h * img_w.
    """
    predictions = np.asarray(predictions)
    if predictions.ndim != 3:
        raise ValueError('predictions must be 3-D, got {} dimensions'.format(predictions.ndim))
    if predictions.shape[1] != img_h * img_w:
        raise ValueError('predictions.shape[1] is {}, expected img_h*img_w = {}'.format(
            predictions.shape[1], img_h * img_w))
    n_images = predictions.shape[0]
    scores = np.reshape(predictions, [n_images, img_h, img_w, C])
    return np.argmax(scores, axis=3)


def pred_to_imgs(predictions, img_h, img_w, C=2):
    """Return the most likely class index of every pixel.

    Args:
        predictions: array of shape (N, img_h*img_w, C) with per-class scores.
        img_h, img_w: image height and width.
        C: number of classes.

    Returns:
        Integer array of shape (N, img_h, img_w) with values in [0, C).

    Raises:
        ValueError: if `predictions` does not have the expected shape.
    """
    return _class_index_map(predictions, img_h, img_w, C)


def pred_to_imgs_bak(predictions, img_h, img_w, C=2):
    """Like pred_to_imgs, but returns floats scaled by the largest index found.

    The per-pixel Python loop of the earlier version is replaced by
    np.argmax, which also picks the first maximum.  When every pixel is
    class 0 the result is all zeros instead of NaN.

    Returns:
        Float array of shape (N, img_h, img_w) with values in [0, 1].

    Raises:
        ValueError: if `predictions` does not have the expected shape.
    """
    pred_images = _class_index_map(predictions, img_h, img_w, C).astype(np.float64)
    peak = np.max(pred_images) if pred_images.size else 0
    if peak > 0:
        pred_images /= peak

    return pred_images
/configuration.txt: -------------------------------------------------------------------------------- 1 | 2 | [unet_parameters] 3 | #--------data paths-------# 4 | path_local = ./U-Net/ 5 | train_images_dir = /train_images/ 6 | train_labels_dir = /train_labels/ 7 | test_images_dir = /test_images/ 8 | test_labels_dir = /test_labels/ 9 | unet_model_dir = /model/ 10 | 11 | 12 | #-----data parameters-----# 13 | img_h = 256 14 | img_w = 256 15 | C = 2 16 | N_channels = 1 17 | gt_gray_value_list = [0, 85, 170, 255] 18 | #img_h, img_w: image height and width; C: number of classes (background included); N_channels: number of input image channels; gt_gray_value_list: gray value of each class in the label masks (read only when C > 2) 19 | 20 | 21 | #----training settings----# 22 | N_epochs = 40 23 | batch_size = 5 24 | validation_rate = 0.1 25 | 26 | 27 | --------------------------------------------------------------------------------