├── img
│   ├── 1.jpg
│   └── 2.jpg
├── myscript.sh
├── README.md
├── read_Data_list.py
├── BatchDatasetReader.py
└── train_main.py

/img/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiangwenliu/IDRiD-Lesion-Segmentation/HEAD/img/1.jpg
--------------------------------------------------------------------------------
/img/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiangwenliu/IDRiD-Lesion-Segmentation/HEAD/img/2.jpg
--------------------------------------------------------------------------------
/myscript.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# source ~/.bashrc
# hostname
# run training
python train_main.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Diabetic Retinopathy Lesion Segmentation
Automatic segmentation of retinal lesions in fundus images.

# Dependencies
* python 2.7
* tensorflow 1.6
* tensorlayer 1.8
* tensorboard

# Data
* The dataset consists of 81 images with pixel-level annotations: 54 training samples and 27 test samples.
* [Download the Indian Diabetic Retinopathy Image Dataset (IDRiD)](https://ieee-dataport.org/open-access/indian-diabetic-retinopathy-image-dataset-idrid)

# Performance Evaluation
Lesion segmentation performance is evaluated against the available binary masks. The area under the precision-recall curve (AUPR) is used to obtain a single score.
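
A minimal sketch of this computation (thresholded precision/recall points integrated with `sklearn.metrics.auc`, mirroring `computeAUPR` in `train_main.py`; the 11-point threshold grid is an assumption):

```python
import numpy as np
from sklearn.metrics import auc

def aupr(proba_map, ground_truth, thresholds=np.linspace(0.0, 1.0, 11)):
    """Area under the precision-recall curve via explicit thresholding."""
    proba_map = proba_map.reshape(-1)
    ground_truth = ground_truth.reshape(-1)
    precisions, recalls = [], []
    for t in thresholds:
        pred = (proba_map >= t).astype(np.float32)
        tp = np.count_nonzero(pred * ground_truth)   # true positives
        fp = np.count_nonzero(pred) - tp             # false positives
        p = np.count_nonzero(ground_truth)           # positives in ground truth
        if p > 0 and (tp + fp) > 0:
            precisions.append(tp / float(tp + fp))
            recalls.append(tp / float(p))
        else:
            precisions.append(1.0)
            recalls.append(0.0)
    # close the curve at (recall=0, precision=1) before integrating
    precisions.append(1.0)
    recalls.append(0.0)
    return auc(recalls, precisions)
```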

# Experiment


# Result


--------------------------------------------------------------------------------
/read_Data_list.py:
--------------------------------------------------------------------------------
import os
import random
import glob
from six.moves import cPickle as pickle
from tensorflow.python.platform import gfile


def read_dataset(data_dir):
    # build the training/validation records once and cache them in a pickle file
    pickle_filename = "iDrid.pickle"
    pickle_filepath = os.path.join(data_dir, pickle_filename)
    if not os.path.exists(pickle_filepath):
        result = create_image_lists(data_dir)
        print("Pickling ...")
        with open(pickle_filepath, 'wb') as f:
            pickle.dump(result, f, protocol=2)
    else:
        print("Found pickle file!")

    with open(pickle_filepath, 'rb') as f:
        result = pickle.load(f)
        training_records = result['training']
        validation_records = result['validation']
        del result

    return training_records, validation_records


def create_image_lists(image_dir):
    if not gfile.Exists(image_dir):
        print("Image directory '" + image_dir + "' not found.")
        return None
    directories = ['training', 'validation']
    image_list = {}
    for directory in directories:
        file_list = []
        image_list[directory] = []
        file_glob = os.path.join(image_dir, "images", directory, '*.jpg')
        file_list.extend(glob.glob(file_glob))
        if not file_list:
            print('No files found')
        else:
            for f in file_list:
                # os.path.basename handles both Windows (\\) and Linux (/) separators
                filename = os.path.splitext(os.path.basename(f))[0] + '_EX'
                annotation_file = os.path.join(image_dir, "annotations", directory, filename + '.tif')
                if os.path.exists(annotation_file):
                    record = {'image': f, 'annotation': annotation_file, 'filename': filename}
                    image_list[directory].append(record)
                else:
                    print("Annotation file not found for %s - Skipping" % filename)

        random.shuffle(image_list[directory])
        no_of_images = len(image_list[directory])
        print('No. of %s files: %d' % (directory, no_of_images))

    return image_list


# create_image_lists('Data/')
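# Usage sketch (hypothetical data_dir layout: images/{training,validation}/*.jpg
# and annotations/{training,validation}/*_EX.tif):
# train_records, valid_records = read_dataset('/path/to/idrid')
# print(train_records[0])  # {'image': ..., 'annotation': ..., 'filename': ...}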
--------------------------------------------------------------------------------
/BatchDatasetReader.py:
--------------------------------------------------------------------------------
import random

import numpy as np
import scipy.misc as misc
from PIL import Image, ImageOps, ImageEnhance


class BatchDataset:
    files = []
    images = []
    annotations = []
    image_options = {}
    batch_offset = 0
    epochs_completed = 0  # not used currently
    ratio = 0.25

    def __init__(self, records_list, image_options=None, augmentation=False):
        """
        Initialize a generic file reader with batching for a list of files.
        :param records_list: list of file name records to read -
            sample record: {'image': f, 'annotation': annotation_file, 'filename': filename}
        :param image_options: a dictionary of options for modifying the output image
            Available options:
            resize = True/False
            resize_size = size of the output image (bilinear resize)
            color = True/False
        """
        print("Initializing Batch Dataset Reader...")
        print(image_options)
        self.files = records_list
        self.image_options = image_options or {}
        self.process_image = True
        self.data_augmentation = augmentation
        self.__channels = True
        self._read_images()

    def _read_images(self):
        # load images lazily as PIL objects; cropping/resizing happens per batch
        self.imageslist = [self._get_image(filename['image']) for filename in self.files]
        self.annotationslist = [self._get_image(filename['annotation']) for filename in self.files]
        print('----self.images.length:', len(self.imageslist))
        print('----self.annotations.length:', len(self.annotationslist))

    def _get_image(self, filename):
        return Image.open(filename)

    def _transform(self, filename):
        image_object = Image.open(filename)
        # crop the 4288x2848 fundus image to the circular field of view, then pad to a square
        crop_image = image_object.crop((270, 0, 3710, 2848))
        padding = (0, 296, 0, 296)
        pad_image = ImageOps.expand(crop_image, padding)  # make width and height the same size

        if self.__channels == False:
            image = np.array(pad_image.convert('L'))
        else:
            image = np.array(pad_image)

        if self.image_options.get("resize", False):
            resize_size = int(self.image_options["resize_size"])
            resize_image = misc.imresize(image, [resize_size, resize_size],
                                         interp='bicubic')  # bicubic interpolation resize
        else:
            resize_image = image

        if self.process_image:
            resize_image = resize_image * (1.0 / 255)
            resize_image = per_image_standardization(resize_image)

        return np.array(resize_image)

    def _crop_resize_image(self, image, annotation):
        # randomly pick one of three downscale rates around 0.25
        rate = [0.248, 0.25, 0.252]
        index = random.randint(0, 2)
        old_size = image.size
        new_size = tuple([int(x * rate[index]) for x in old_size])
        resize_image = image.resize(size=new_size, resample=Image.BICUBIC)
        resize_annotation = annotation.resize(size=new_size, resample=Image.BICUBIC)
        crop_image = resize_image.crop((66, 0, 930, 712))
        crop_annotation = resize_annotation.crop((66, 0, 930, 712))
        padding = (0, 76, 0, 76)
        image = ImageOps.expand(crop_image, padding)
        annotation = ImageOps.expand(crop_annotation, padding)
        if image.size != annotation.size:
            raise ValueError("Image and annotation sizes are not equal!")

        # random crop to (resize_size, resize_size), e.g. (640, 640)
        width, height = image.size
        resize = int(self.image_options["resize_size"])
        x = random.randint(0, width - resize - 1)
        y = random.randint(0, height - resize - 1)
        image = image.crop((x, y, x + resize, y + resize))
        annotation = annotation.crop((x, y, x + resize, y + resize))

        if self.data_augmentation is True:
            # brightness
            enh_bri = ImageEnhance.Brightness(image)
            brightness = round(random.uniform(0.8, 1.2), 2)
            image = enh_bri.enhance(brightness)

            # color
            enh_col = ImageEnhance.Color(image)
            color = round(random.uniform(0.8, 1.2), 2)
            image = enh_col.enhance(color)

            # contrast
            enh_con = ImageEnhance.Contrast(image)
            contrast = round(random.uniform(0.8, 1.2), 2)
            image = enh_con.enhance(contrast)

            # random flip/transpose: PIL transpose methods 0-6; 7 means no transpose
            method = random.randint(0, 7)
            if method < 7:
                image = image.transpose(method)
                annotation = annotation.transpose(method)
            # small random rotation
            degree = random.randint(-5, 5)
            image = image.rotate(degree)
            annotation = annotation.rotate(degree)

        image_array = np.array(image)
        # scale the image to [0, 1]
        if self.process_image:
            image_array = image_array * (1.0 / 255)

        annotation_array = np.array(annotation)
        return np.array(image_array), annotation_array

    def get_records(self):
        return self.imageslist, self.annotationslist

    def reset_batch_offset(self, offset=0):
        self.batch_offset = offset

    def next_batch(self, batch_size):
        start = self.batch_offset
        self.batch_offset += batch_size
        if self.batch_offset > len(self.imageslist):
            # finished an epoch: shuffle images and annotations together
            c = list(zip(self.imageslist, self.annotationslist))
            random.shuffle(c)
            self.imageslist, self.annotationslist = zip(*c)
            # start the next epoch
            start = 0
            self.batch_offset = batch_size

        end = self.batch_offset
        image_batch = []
        annotation_batch = []
        for (image, annotation) in zip(self.imageslist[start:end], self.annotationslist[start:end]):
            img, annot = self._crop_resize_image(image, annotation)
            image_batch.append(img)
            annotation_batch.append(annot)
        return np.array(image_batch), np.expand_dims(np.array(annotation_batch), 3)

    def get_random_batch(self, batch_size):
        indexes = np.random.randint(0, len(self.imageslist), size=[batch_size]).tolist()
        image = []
        annotation = []
        for index in indexes:
            img, annot = self._crop_resize_image(self.imageslist[index], self.annotationslist[index])
            image.append(img)
            annotation.append(annot)
        return np.array(image), np.expand_dims(np.array(annotation), 3)
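
# Usage sketch (hypothetical records; see read_Data_list.read_dataset):
# reader = BatchDataset(train_records, {'resize': True, 'resize_size': 640}, augmentation=True)
# images, annotations = reader.next_batch(4)
# images.shape -> (4, 640, 640, 3); annotations.shape -> (4, 640, 640, 1)
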
def per_image_standardization(image):
    # zero-mean, unit-variance normalization, with the stddev floored at
    # 1/sqrt(num_elements) as in tf.image.per_image_standardization
    image = image.astype(np.float32, copy=False)
    mean = np.mean(image)
    stddev = np.std(image)
    adjusted_stddev = max(stddev, 1.0 / np.sqrt(np.array(image.size, dtype=np.float32)))
    return (image - mean) / adjusted_stddev
--------------------------------------------------------------------------------
/train_main.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 23 10:01:27 2018

@author: shawn
"""

import sys
import time

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from sklearn.metrics import roc_auc_score, auc
from tensorflow.python.framework import ops

import BatchDatasetReader as BDR
import read_Data_list as RDL

# path variables
logs_dir = 'logs/'
data_dir = '/home/lxw/tensorflowproject/data/idrid'

# basic constants
IMG_SIZE = 640
num_of_classes = 2
print_freq = 10
WIDTH = 4288
HEIGHT = 2848
# training constants
MAX_EPOCH = 500
batch_size = 1
test_batchsize = 1
train_nbr = 54
test_nbr = 27
gamma = 64  # scales down the positive-class weight in the loss
step_every_epoch = int(train_nbr / batch_size)
test_every_epoch = int(test_nbr / test_batchsize)
learningrate = 1.0e-4
learningrateend = 1.0e-6
# thresholds for the AUPR computation
range_threshold = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# flags parameters
FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('mode', "train", "Mode train/ test/ visualize")


class Unet:
    def __init__(self, img_rows=IMG_SIZE, img_cols=IMG_SIZE):
        self.img_rows = img_rows
        self.img_cols = img_cols

    def load_data_util(self):
        image_options = {'resize': True, 'resize_size': IMG_SIZE}  # resize all images
        train_records, valid_records = RDL.read_dataset(data_dir)  # get the record lists
        train_dataset_reader = BDR.BatchDataset(train_records, image_options, augmentation=True)
        validation_dataset_reader = BDR.BatchDataset(valid_records, image_options, augmentation=False)
        return train_dataset_reader, validation_dataset_reader

    def model(self, image, is_train=True, reuse=False):
        # U-Net-style encoder/decoder: each encoder block ends with a dilated
        # (atrous) convolution rather than pooling, so spatial resolution is
        # preserved; decoder blocks concatenate the matching encoder features
        with tf.variable_scope("model", reuse=reuse):
            tl.layers.set_name_reuse(reuse)
            W_init = tf.contrib.layers.variance_scaling_initializer()
            net = tl.layers.InputLayer(image, name='input_layer')  # input image

            n_filters = [16, 64, 64, 128, 128, 128, 256]
            concat = {}
            i = 0
            block = 1
            depth = 6
            # encoder
            for j in range(1, depth):
                i = i + 1
                net = tl.layers.Conv2d(net, n_filters[j], (3, 3), (1, 1),
                                       act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
                i = i + 1
                net = tl.layers.Conv2d(net, n_filters[j], (3, 3), (1, 1),
                                       act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
                concat[j] = net
                net = tl.layers.AtrousConv2dLayer(net, n_filters[j], (3, 3), 2,
                                                  act=tf.nn.relu, padding='SAME', W_init=W_init,
                                                  name='atro_conv' + str(block))
                block = block + 1

            # bottleneck
            i = i + 1
            net = tl.layers.Conv2d(net, n_filters[block], (3, 3), (1, 1),
                                   act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
            i = i + 1
            net = tl.layers.Conv2d(net, n_filters[block], (3, 3), (1, 1),
                                   act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))

            # decoder with skip connections
            for j in range(depth - 1, 0, -1):
                i = i + 1
                net = tl.layers.Conv2d(net, n_filters[j], (3, 3), (1, 1),
                                       act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
                net = tl.layers.ConcatLayer([net, concat[j]], 3, name='concat_' + str(j))
                i = i + 1
                net = tl.layers.Conv2d(net, n_filters[j], (3, 3), (1, 1),
                                       act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
                i = i + 1
                net = tl.layers.Conv2d(net, n_filters[j], (3, 3), (1, 1),
                                       act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
            # 1x1 convolution to per-pixel class logits
            i = i + 1
            net = tl.layers.Conv2d(net, num_of_classes, (1, 1), (1, 1),
                                   padding='SAME', W_init=W_init, name='conv_' + str(i))
            y = net.outputs  # logits tensor from the tensorlayer network
            pred = tf.argmax(y, 3, name="prediction")

        return pred, y, net

    def loss(self, logits, annotation):
        # class-balancing weight: positive pixels are rare, so weight them by
        # (#negative pixels) / (#positive pixels * gamma)
        positive_count = tf.count_nonzero(annotation, dtype=tf.float32)
        total_num = tf.constant(IMG_SIZE * IMG_SIZE, dtype=tf.float32)
        negative_count = tf.subtract(total_num, positive_count)
        alpha = tf.div(tf.multiply(negative_count, 1.0), (positive_count * gamma))

        classes_weights = [1.0, alpha]
        labels = tf.squeeze(annotation, axis=[3])

        loss = self.weighted_softmax_cross_entropy_loss(logits=logits, labels=labels, weights=classes_weights)
        return loss
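
    # Worked example of the class weight (hypothetical counts): for a 640x640
    # annotation with 2,000 positive pixels, alpha = (409600 - 2000) / (2000 * 64)
    # ~= 3.2, so each positive pixel contributes ~3.2x the loss of a negative one.
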
    def weighted_softmax_cross_entropy_loss(self, logits, labels, weights):
        """
        Computes the softmax cross-entropy loss with class weights based on the class of each pixel.

        Parameters
        ----------
        logits: TF tensor
            The network output before softmax.
        labels: TF tensor
            The desired output from the ground truth.
        weights: list of floats
            The weights associated with the different labels in the ground truth.

        Returns
        -------
        loss: TF float
            The weighted cross-entropy loss.
        """

        with tf.name_scope('loss'):
            # build a per-pixel weight map: weights[i] wherever the label is i
            weight_map = tf.to_float(tf.equal(labels, 0, name='label_map_0')) * weights[0]
            for i, weight in enumerate(weights[1:], start=1):
                weight_map = weight_map + tf.to_float(tf.equal(labels, i, name='label_map_' + str(i))) * weight

            weight_map = tf.stop_gradient(weight_map, name='stop_gradient')

            # compute cross entropy loss
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits,
                                                                           name='cross_entropy_softmax')

            # apply weights to cross entropy loss
            weighted_cross_entropy = tf.multiply(weight_map, cross_entropy, name='apply_weights')

            # get loss scalar
            loss = tf.reduce_mean(weighted_cross_entropy, name='loss')

        return loss
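
    # Weight-map example (hypothetical labels): with weights = [1.0, 3.2] and
    # labels = [[0, 1], [1, 0]], the weight map is [[1.0, 3.2], [3.2, 1.0]],
    # which multiplies the per-pixel cross-entropy before averaging.
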
    def weighted_softmax_cross_entropy_loss_with_false_positive_weights(self, logits, labels, weights,
                                                                        false_positive_factor=0.5):
        """
        Computes the softmax cross-entropy loss with class weights based on the class of each pixel,
        plus an additional weight for false-positive classifications (instances of class 0 classified
        as class 1).

        Parameters
        ----------
        logits: TF tensor
            The network output before softmax.
        labels: TF tensor
            The desired output from the ground truth.
        weights: list of floats
            The weights associated with the different labels in the ground truth.
        false_positive_factor: float
            False positives receive a loss weight of false_positive_factor * weights[1],
            the weight of the class of interest.

        Returns
        -------
        loss: TF float
            The weighted cross-entropy loss.
        """

        with tf.name_scope('loss'):
            logits = tf.reshape(logits, [-1, tf.shape(logits)[3]], name='flatten_logits')
            labels = tf.reshape(labels, [-1], name='flatten_labels')

            # get predictions from likelihoods
            prediction = tf.argmax(logits, 1, name='predictions')

            # get maps of class-of-interest pixels
            prediction_map = tf.equal(prediction, 1, name='prediction_map')
            label_map = tf.equal(labels, 1, name='label_map')

            false_positive_map = tf.logical_and(prediction_map, tf.logical_not(label_map), name='false_positive_map')

            label_map = tf.to_float(label_map)
            false_positive_map = tf.to_float(false_positive_map)

            weight_map = label_map * (weights[1] - weights[0]) + weights[0]
            weight_map = tf.add(weight_map, false_positive_map * ((false_positive_factor * weights[1]) - weights[0]),
                                name="combined_weight_map")

            weight_map = tf.stop_gradient(weight_map, name='stop_gradient')

            # compute cross entropy loss
            cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits,
                                                                           name='cross_entropy_softmax')

            # apply weights to cross entropy loss
            weighted_cross_entropy = tf.multiply(weight_map, cross_entropy, name='apply_weights')

            # get loss scalar
            loss = tf.reduce_mean(weighted_cross_entropy, name='loss')

        return loss

    def weighted_cross_entropy(self, targets, logits, pos_weight, name=None):
        """Compute weighted cross entropy; expects `logits` to already be softmax probabilities."""
        with ops.name_scope(name, "logistic_loss", [logits, targets]) as name:
            logits = ops.convert_to_tensor(logits, name="logits")
            targets = ops.convert_to_tensor(targets, name="targets")
            try:
                targets.get_shape().merge_with(logits.get_shape())
            except ValueError:
                raise ValueError(
                    "logits and targets must have the same shape (%s vs %s)" %
                    (logits.get_shape(), targets.get_shape()))

            cross_entropy = -tf.add(targets[:, :, :, 0] * tf.log(logits[:, :, :, 0]),
                                    pos_weight * targets[:, :, :, 1] * tf.log(logits[:, :, :, 1]))
            return cross_entropy

    def train(self, loss, learning_rate, globalstep):
        # Note: tf.nn.sparse_softmax_cross_entropy_with_logits may produce NaN
        # losses if probabilities are not clipped
        optimizer = tf.train.AdamOptimizer(learning_rate)
        var_list = tf.trainable_variables()
        grads = optimizer.compute_gradients(loss, var_list=var_list)
        train_op = optimizer.apply_gradients(grads, global_step=globalstep)
        return train_op


# AUPR score
def computeConfMatElements(thresholded_proba_map, ground_truth):
    # P: positives in the ground truth; TP: predicted positives that are true;
    # FP: predicted positives absent from the ground truth
    P = np.count_nonzero(ground_truth)
    TP = np.count_nonzero(thresholded_proba_map * ground_truth)
    FP = np.count_nonzero(thresholded_proba_map - (thresholded_proba_map * ground_truth))

    return P, TP, FP
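
# Worked example (hypothetical arrays): with thresholded_proba_map = [1, 1, 0, 0]
# and ground_truth = [1, 0, 1, 0], P = 2, TP = 1, FP = 1, giving
# precision = 1/2 and recall = 1/2 at this threshold.
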
def computeAUPR(proba_map, ground_truth, threshold_list):
    proba_map = proba_map.astype(np.float32)
    proba_map = proba_map.reshape(-1)
    ground_truth = ground_truth.reshape(-1)
    precision_list_threshold = []
    recall_list_threshold = []
    # loop over thresholds
    for threshold in threshold_list:
        # threshold the proba map
        thresholded_proba_map = np.zeros(np.shape(proba_map))
        thresholded_proba_map[proba_map >= threshold] = 1

        # compute P, TP, and FP for this threshold and this proba map
        P, TP, FP = computeConfMatElements(thresholded_proba_map, ground_truth)

        # check that the ground truth contains at least one positive
        if P > 0 and (TP + FP) > 0:
            precision = TP * 1. / (TP + FP)
            recall = TP * 1. / P
        else:
            precision = 1
            recall = 0

        # record precision and recall for this threshold
        precision_list_threshold.append(precision)
        recall_list_threshold.append(recall)

    # close the curve at (recall=0, precision=1) before integrating
    precision_list_threshold.append(1)
    recall_list_threshold.append(0)
    return auc(recall_list_threshold, precision_list_threshold)
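
# Toy example (hypothetical values): for a perfectly separating probability map
# proba = np.array([0.9, 0.8, 0.2, 0.1]) and gt = np.array([1, 1, 0, 0]),
# computeAUPR(proba, gt, range_threshold) -> 1.0
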
def main(argv=None):
    myUnet = Unet()
    image = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 3], name='image')  # input RGB images
    annotation = tf.placeholder(tf.int32, shape=[None, IMG_SIZE, IMG_SIZE, 1], name="annotation")

    # define inferences
    train_pred, train_logits, train_tlnetwork = myUnet.model(image, is_train=True, reuse=False)
    train_positive_prob = tf.nn.softmax(train_logits)[:, :, :, 1]
    train_loss_op = myUnet.loss(train_logits, annotation)

    # exponential learning-rate decay from LR_start to LR_fin over n_epoch epochs,
    # e.g. (1e-6 / 1e-4) ** (1/500) ~= 0.9908 per epoch
    n_epoch = MAX_EPOCH
    n_step_epoch = int(train_nbr / batch_size)
    LR_start = learningrate
    LR_fin = learningrateend
    LR_decay = (LR_fin / LR_start) ** (1.0 / n_epoch)
    step_decay = n_step_epoch
    global_steps = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(learningrate, global_steps, step_decay, LR_decay, staircase=True)
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(train_loss_op, global_step=global_steps)
    summaries = []
    summaries.append(tf.summary.scalar('learning_rate', learning_rate, collections=['learningrate']))
    # merge the learning-rate summary
    summary_lr = tf.summary.merge(summaries, name='summary_lr')

    test_pred, test_logits, test_tlnetwork = myUnet.model(image, is_train=False, reuse=True)
    test_positive_prob = tf.nn.softmax(test_logits)[:, :, :, 1]
    test_loss_op = myUnet.loss(test_logits, annotation)

    # only visualize the test images;
    # first lighten the annotation images (map label 1 to 255)
    visual_annotation = tf.where(tf.equal(annotation, 1), annotation + 254, annotation)
    visual_pred = tf.expand_dims(tf.where(tf.equal(test_pred, 1), test_pred + 254, test_pred), axis=3)
    tf.summary.image("input_image", image, max_outputs=2)
    tf.summary.image("ground_truth", tf.cast(visual_annotation, tf.uint8), max_outputs=2)
    tf.summary.image("pred_annotation", tf.cast(visual_pred, tf.uint8), max_outputs=2)

    print("Setting up summary op...")
    test_summary_op = tf.summary.merge_all()

    if FLAGS.mode == 'train':
        train_dataset_reader, validation_dataset_reader = myUnet.load_data_util()
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)

        print("Setting up Saver...")
        saver = tf.train.Saver(max_to_keep=2)
        summary_writer = tf.summary.FileWriter(logs_dir, sess.graph)

        sess.run(tf.global_variables_initializer())

        ckpt = tf.train.get_checkpoint_state(logs_dir)  # if the model has been trained, restore it
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model restored...")
        start = time.time()
        for epo in range(MAX_EPOCH):
            start_time = time.time()
            train_loss, test_loss, train_aupr, test_aupr, train_auc, test_auc = 0, 0, 0, 0, 0, 0

            for s in range(step_every_epoch):
                train_images, train_annotations = train_dataset_reader.next_batch(batch_size)
                feed_dict = {image: train_images, annotation: train_annotations}
                tra_positive_prob, train_err, _ = sess.run([train_positive_prob, train_loss_op, train_op],
                                                           feed_dict=feed_dict)

                # compute the AUC score
                temp_train_annotations = np.reshape(train_annotations, -1)
                temp_tra_positive_prob = np.reshape(tra_positive_prob, -1)
                train_sauc = roc_auc_score(temp_train_annotations, np.nan_to_num(temp_tra_positive_prob))
                # compute the AUPR
                train_saupr = computeAUPR(np.nan_to_num(tra_positive_prob).reshape(-1),
                                          train_annotations.reshape(-1), range_threshold)

                train_loss += train_err
                train_auc += train_sauc
                train_aupr += train_saupr

            if epo + 1 == 1 or (epo + 1) % print_freq == 0:
                train_loss = train_loss / step_every_epoch
                train_auc = train_auc / step_every_epoch
                train_aupr = train_aupr / step_every_epoch
                # report the training metrics
                print("epoch %d (logged every %d epochs) took %fs" % (epo, print_freq, time.time() - start_time))
                print("   train loss: %f" % train_loss)
                print("   train auc: %f" % train_auc)
                print("   train aupr: %f" % train_aupr)

                train_summary = tf.Summary(value=[
                    tf.Summary.Value(tag="train_loss", simple_value=train_loss),
                    tf.Summary.Value(tag="train_auc", simple_value=train_auc),
                    tf.Summary.Value(tag="train_aupr", simple_value=train_aupr)
                ])
                summary_writer.add_summary(train_summary, epo)
                summary_str = sess.run(summary_lr)
                summary_writer.add_summary(summary_str, epo)
                summary_writer.flush()

                for test_s in range(test_every_epoch):
                    # get validation data
                    valid_images, valid_annotations = validation_dataset_reader.next_batch(test_batchsize)
                    # compute the validation loss
                    feed_dict = {image: valid_images, annotation: valid_annotations}
                    valid_positive_prob, validation_err = sess.run([test_positive_prob, test_loss_op],
                                                                   feed_dict=feed_dict)
                    # compute the AUC score
                    temp_valid_annotations = np.reshape(valid_annotations, -1)
                    temp_valid_positive_prob = np.reshape(valid_positive_prob, -1)
                    test_sauc = roc_auc_score(temp_valid_annotations, np.nan_to_num(temp_valid_positive_prob))
                    # compute the test AUPR
                    test_saupr = computeAUPR(np.nan_to_num(valid_positive_prob).reshape(-1),
                                             valid_annotations.reshape(-1), range_threshold)

                    test_loss += validation_err
                    test_auc += test_sauc
                    test_aupr += test_saupr
                test_loss = test_loss / test_every_epoch
                test_auc = test_auc / test_every_epoch
                test_aupr = test_aupr / test_every_epoch
                print("   test aupr: %f" % test_aupr)
                test_summary = tf.Summary(value=[
                    tf.Summary.Value(tag="test_loss", simple_value=test_loss),
                    tf.Summary.Value(tag="test_auc", simple_value=test_auc),
                    tf.Summary.Value(tag="test_aupr", simple_value=test_aupr)
                ])
                summary_writer.add_summary(test_summary, epo)

                # visualize the test result (only the last batch of this epoch)
                feed_dict = {image: valid_images, annotation: valid_annotations}
                summary_str = sess.run(test_summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, epo)

                # tensorboard flush
                summary_writer.flush()
                sys.stdout.flush()
            if (epo + 1) % 100 == 0:
                saver.save(sess, logs_dir + "model.ckpt", epo)
                print('epoch %d: the model has been saved successfully' % epo)
                sys.stdout.flush()
        print('-------------------total cost time: %fs' % (time.time() - start))
        summary_writer.close()
        sess.close()


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------