├── img
│   ├── 1.jpg
│   └── 2.jpg
├── myscript.sh
├── README.md
├── read_Data_list.py
├── BatchDatasetReader.py
└── train_main.py
/img/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiangwenliu/IDRiD-Lesion-Segmentation/HEAD/img/1.jpg
--------------------------------------------------------------------------------
/img/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiangwenliu/IDRiD-Lesion-Segmentation/HEAD/img/2.jpg
--------------------------------------------------------------------------------
/myscript.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #source ~/.bashrc
3 | #hostname
4 | # write your command below
5 | python train_main.py
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Diabetic Retinopathy Lesion Segmentation
2 | Automatic segmentation of retinal lesions in fundus images.
3 |
4 | # Dependencies
5 | * Python 2.7
6 | * TensorFlow 1.6
7 | * TensorLayer 1.8
8 | * TensorBoard
9 |
10 | # Data
11 | * The dataset consists of 81 images with pixel-level annotations: 54 training samples and 27 test samples.
12 | * [Download the Indian Diabetic Retinopathy Image Dataset (IDRiD)](https://ieee-dataport.org/open-access/indian-diabetic-retinopathy-image-dataset-idrid)
13 |
14 | # Performance Evaluation
15 | Performance is evaluated against the provided binary lesion masks. The area under the precision-recall curve (AUPR) is used to obtain a single score; a minimal sketch of the metric follows this README.
16 |
17 | # Experiment
18 |
19 |
20 | # Result
21 |
22 |
23 |
--------------------------------------------------------------------------------
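The AUPR score above can be sanity-checked with scikit-learn. A minimal sketch, assuming scikit-learn is installed (computeAUPR in train_main.py approximates the same curve with a fixed threshold sweep):

import numpy as np
from sklearn.metrics import auc, precision_recall_curve

# toy flattened binary mask and predicted probabilities (illustrative values)
ground_truth = np.array([0, 0, 1, 1, 0, 1])
proba_map = np.array([0.10, 0.40, 0.35, 0.80, 0.20, 0.70])

# precision-recall pairs over all thresholds, then the area under that curve
precision, recall, _ = precision_recall_curve(ground_truth, proba_map)
print('AUPR: %.3f' % auc(recall, precision))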
/read_Data_list.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from six.moves import cPickle as pickle
4 | from tensorflow.python.platform import gfile
5 | import glob
6 |
7 |
8 |
9 |
10 | def read_dataset(data_dir):
11 | pickle_filename = "iDrid.pickle"
12 | pickle_filepath = os.path.join(data_dir, pickle_filename)
13 | if not os.path.exists(pickle_filepath):
14 | result = create_image_lists(data_dir)
15 | print ("Pickling ...")
16 | with open(pickle_filepath, 'wb') as f:
17 | pickle.dump(result, f, protocol=2)
18 | else:
19 | print ("Found pickle file!")
20 |
21 | with open(pickle_filepath, 'rb') as f:
22 | result = pickle.load(f)
23 | training_records = result['training']
24 | validation_records = result['validation']
25 | del result
26 |
27 | return training_records, validation_records
28 |
29 |
30 | def create_image_lists(image_dir):
31 | if not gfile.Exists(image_dir):
32 | print("Image directory '" + image_dir + "' not found.")
33 | return None
34 | directories = ['training', 'validation']
35 | image_list = {}
36 | for directory in directories:
37 | file_list = []
38 | image_list[directory] = []
39 | file_glob = os.path.join(image_dir, "images", directory, '*.jpg')
40 | file_list.extend(glob.glob(file_glob))
41 | if not file_list:
42 | print('No files found')
43 | else:
44 | for f in file_list:
45 | filename = os.path.splitext(os.path.basename(f))[0] + '_EX'  # os.path.basename handles both Windows (\\) and Linux (/) separators
46 | annotation_file = os.path.join(image_dir, "annotations", directory, filename + '.tif')
47 | if os.path.exists(annotation_file):
48 | record = {'image': f, 'annotation': annotation_file, 'filename': filename}
49 | image_list[directory].append(record)
50 | else:
51 | print("Annotation file not found for %s - Skipping" % filename)
52 |
53 | random.shuffle(image_list[directory])
54 | no_of_images = len(image_list[directory])
55 | print('No. of %s files: %d' % (directory, no_of_images))
56 |
57 | #print(image_list)
58 | return image_list
59 |
60 |
61 | #create_image_lists('Data/')
--------------------------------------------------------------------------------
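A hypothetical usage sketch for read_dataset. The directory layout in the comments is inferred from the glob and join patterns in create_image_lists; the data path and file names are illustrative:

# expected layout under data_dir (inferred, not documented):
#   data_dir/images/training/*.jpg          data_dir/annotations/training/*_EX.tif
#   data_dir/images/validation/*.jpg        data_dir/annotations/validation/*_EX.tif
import read_Data_list as RDL

train_records, valid_records = RDL.read_dataset('/path/to/idrid')  # illustrative path
print(len(train_records), len(valid_records))  # e.g. 54 27
print(train_records[0])  # {'image': ..., 'annotation': ..., 'filename': ...}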
/BatchDatasetReader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.misc as misc
3 | from PIL import Image, ImageOps, ImageEnhance
4 | import random
5 |
6 |
7 | class BatchDatset:
8 | files = []
9 | images = []
10 | annotations = []
11 | image_options = {}
12 | batch_offset = 0
13 | epochs_completed = 0  # currently unused
14 | ratio = 0.25
15 |
16 | def __init__(self, records_list, image_options=None, augmentation=False):
17 | """
18 | Initialize a generic file reader with batching for a list of files
19 | :param records_list: list of file name records to read -
20 | sample record: {'image': f, 'annotation': annotation_file, 'filename': filename}
21 | :param image_options: A dictionary of options for modifying the output image
22 | Available options:
23 | resize = True/ False
24 | resize_size = #size of output image - does bilinear resize
25 | color=True/False
26 | """
27 | print("Initializing Batch Dataset Reader...")
28 | print(image_options)
29 | self.files = records_list
30 | self.image_options = image_options or {}  # avoid a mutable default argument
31 | self.process_image = True
32 | self.data_augmentation = augmentation
33 | self.__channels = True
34 | self._read_images()
35 |
36 | def _read_images(self):
37 | # self.__channels = True # require gray immages
38 | # self.prosessimg = True
39 | # self.images = np.array([self._transform(filename['image']) for filename in self.files])
40 | # # self.images = np.expand_dims(self.images, 3)
41 | # self.__channels = True
42 | # self.prosessimg = False
43 | # self.annotations = np.array([self._transform(filename['annotation']) for filename in self.files])
44 | # self.annotations = np.expand_dims(self.annotations, 3)
45 | # print('self.images.shape:', self.images.shape)
46 | # print('self.annotations.shape:', self.annotations.shape)
47 |
48 | self.imageslist = [self._get_image(filename['image']) for filename in self.files]
49 | self.annotationslist = [self._get_image(filename['annotation']) for filename in self.files]
50 | print('----self.images.length:', len(self.imageslist))
51 | print('----self.annotations.length:', len(self.annotationslist))
52 |
53 | def _get_image(self, filename):
54 | image = Image.open(filename)
55 |
56 | return image
57 |
58 |
59 | def _transform(self, filename):
60 | image_object = Image.open(filename)
61 | crop_image = image_object.crop((270, 0, 3710, 2848))
62 | padding = (0, 296, 0, 296)
63 | pad_image = ImageOps.expand(crop_image, padding)  # pad so width and height are equal
64 |
65 | if not self.__channels:
66 | image = np.array(pad_image.convert('L'))
67 | else:
68 | image = np.array(pad_image)
69 |
70 | if self.image_options.get("resize", False):
71 | resize_size = int(self.image_options["resize_size"])
72 | resize_image = misc.imresize(image,
73 | [resize_size, resize_size],
74 | interp='bicubic') # bicubic interpolation resize image
75 | else:
76 | resize_image = image
77 |
78 | if self.process_image:
79 | resize_image = resize_image * (1.0 / 255)
80 | resize_image = per_image_standardization(resize_image)
81 |
82 | return np.array(resize_image)
83 |
84 | def _crop_resize_image(self, image, annotation):
85 | rate = [0.248, 0.25, 0.252]
86 | index = random.randint(0, 2)
87 | old_size = image.size
88 | new_size = tuple([int(x * rate[index]) for x in old_size])
89 | resize_image = image.resize(size=new_size, resample=3)
90 | resize_annotation = annotation.resize(size=new_size, resample=3)
91 | crop_image = resize_image.crop((66, 0, 930, 712))
92 | crop_annotation = resize_annotation.crop((66, 0, 930, 712))
93 | padding = (0, 76, 0, 76)
94 | image = ImageOps.expand(crop_image, padding)
95 | annotation = ImageOps.expand(crop_annotation, padding)
96 | if image.size != annotation.size:
97 | raise ValueError("Image and annotation sizes are not equal")
98 |
99 | # randomly crop image and annotation to (resize, resize)
100 | width, height = image.size
101 | resize = int(self.image_options["resize_size"])
102 | x = random.randint(0, width - resize - 1)
103 | y = random.randint(0, height - resize - 1)
104 | image = image.crop((x, y, x + resize, y + resize))
105 | annotation = annotation.crop((x, y, x + resize, y + resize))
106 |
107 |
108 | if self.data_augmentation:
109 | # light
110 | enh_bri = ImageEnhance.Brightness(image)
111 | brightness = round(random.uniform(0.8, 1.2), 2)
112 | image = enh_bri.enhance(brightness)
113 |
114 | # color
115 | enh_col = ImageEnhance.Color(image)
116 | color = round(random.uniform(0.8, 1.2), 2)
117 | image = enh_col.enhance(color)
118 |
119 |
120 | # contrast
121 | enh_con = ImageEnhance.Contrast(image)
122 | contrast = round(random.uniform(0.8, 1.2), 2)
123 | image = enh_con.enhance(contrast)
124 | #
125 | # enh_sha = ImageEnhance.Sharpness(image)
126 | # sharpness = round(random.uniform(0.8, 1.2), 2)
127 | # image = enh_sha.enhance(sharpness)
128 |
129 | method = random.randint(0, 7)
130 | # PIL transpose methods are 0-6; a draw of 7 leaves the pair unchanged
131 | if method < 7:
132 | image = image.transpose(method)
133 | annotation = annotation.transpose(method)
134 | degree = random.randint(-5, 5)
135 | image = image.rotate(degree)
136 | annotation = annotation.rotate(degree)
137 |
138 | image_array = np.array(image)
139 | # scale image to [0, 1]
140 | if self.process_image:
141 | image_array = image_array * (1.0 / 255)
142 | # image_array = per_image_standardization(image_array)
143 |
144 | annotation_array = np.array(annotation)
145 | return np.array(image_array), annotation_array
146 |
147 | def get_records(self):
148 | return self.imageslist, self.annotationslist
149 |
150 | def reset_batch_offset(self, offset=0):
151 | self.batch_offset = offset
152 |
153 | def next_batch(self, batch_size):
154 | start = self.batch_offset
155 | self.batch_offset += batch_size
156 | if self.batch_offset > len(self.imageslist):
157 | # Finished epoch
158 | # self.epochs_completed += 1
159 | # print("****************** Epochs completed: " + str(self.epochs_completed) + "******************")
160 | # Shuffle the data
161 | # perm = np.arange(len(self.imageslist))
162 | # np.random.shuffle(perm)
163 | c = list(zip(self.imageslist, self.annotationslist))
164 | random.shuffle(c)
165 | self.imageslist, self.annotationslist = zip(*c)
166 | # self.imageslist = self.imageslist[perm]
167 | # self.annotationslist = self.annotationslist[perm]
168 | # Start next epoch
169 | start = 0
170 | self.batch_offset = batch_size
171 |
172 | end = self.batch_offset
173 | image_batch = []
174 | annotation_batch = []
175 | for (image, annotation) in zip(self.imageslist[start:end], self.annotationslist[start:end]):
176 | img, annot = self._crop_resize_image(image, annotation)
177 | image_batch.append(img)
178 | annotation_batch.append(annot)
179 | return np.array(image_batch), np.expand_dims(np.array(annotation_batch), 3)
180 |
181 | def get_random_batch(self, batch_size):
182 | indexes = np.random.randint(0, len(self.imageslist), size=[batch_size]).tolist()
183 | image = []
184 | annotation = []
185 | for index in indexes:
186 | img, annot = self._crop_resize_image(self.imageslist[index], self.annotationslist[index])
187 | image.append(img)
188 | annotation.append(annot)
189 | return np.array(image), np.expand_dims(np.array(annotation), 3)
190 |
191 |
192 | def per_image_standardization(image):
193 | image = image.astype(np.float32, copy=False)
194 | mean = np.mean(image)
195 | stddev = np.std(image)
196 | adjusted_stddev = max(stddev, 1.0 / np.sqrt(np.array(image.size, dtype=np.float32)))
197 | im = (image - mean) / adjusted_stddev
198 | return im
--------------------------------------------------------------------------------
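A hypothetical end-to-end sketch showing how BatchDatset is driven (the path is illustrative; next_batch shuffles and restarts automatically once an epoch is exhausted):

import read_Data_list as RDL
import BatchDatasetReader as BDR

train_records, _ = RDL.read_dataset('/path/to/idrid')  # illustrative path
image_options = {'resize': True, 'resize_size': 640}
reader = BDR.BatchDatset(train_records, image_options, augmentation=True)

# one batch of augmented crops and their masks
images, annotations = reader.next_batch(1)
print(images.shape, annotations.shape)  # expected: (1, 640, 640, 3) (1, 640, 640, 1)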
/train_main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Mon Apr 23 10:01:27 2018
5 |
6 | @author: shawn
7 | """
8 |
9 | import tensorflow as tf
10 | import tensorlayer as tl
11 | import numpy as np
12 | import BatchDatasetReader as BDR
13 | import read_Data_list as RDL
14 | import sys
15 | import time
16 | from sklearn.metrics import roc_auc_score
17 | from sklearn.metrics import auc
18 | from tensorflow.python.framework import ops
19 | from tensorflow.python.ops import math_ops
20 |
21 | # path variable
22 | logs_dir = 'logs/'
23 | data_dir = '/home/lxw/tensorflowproject/data/idrid'
24 |
25 | # basic constant variable
26 | IMG_SIZE = 640
27 | num_of_classes = 2
28 | print_freq = 10
29 | WIDTH = 4288
30 | HEIGHT = 2848
31 | # training constant variable
32 | MAX_EPOCH = 500
33 | batch_size = 1
34 | test_batchsize = 1
35 | train_nbr = 54
36 | test_nbr = 27
37 | gamma = 64  # divides the positive-class weight alpha in the loss
38 | step_every_epoch = int(train_nbr / batch_size)
39 | test_every_epoch = int(test_nbr / test_batchsize)
40 | learningrate = 1.0e-4 #tf.Variable(1e-4, dtype=tf.float32)
41 | learningrateend = 1.0e-6
42 | # the parameters of aupr
43 | range_threshold = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
44 | # range_threshold = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0]
45 |
46 | # flags parameters
47 | FLAGS = tf.flags.FLAGS
48 | tf.flags.DEFINE_string('mode', "train", "Mode train/ test/ visualize")
49 |
50 |
51 | # data_dir = "Data/"
52 | class Unet:
53 | def __init__(self, img_rows=IMG_SIZE, img_cols=IMG_SIZE):
54 | self.img_rows = img_rows
55 | self.img_cols = img_cols
56 |
57 | def load_data_util(self):
58 | image_options = {'resize': True, 'resize_size': IMG_SIZE} # resize all your images
59 | train_records, valid_records = RDL.read_dataset(data_dir) # get read lists
60 | train_dataset_reader = BDR.BatchDatset(train_records, image_options, augmentation=True)
61 | validation_dataset_reader = BDR.BatchDatset(valid_records, image_options, augmentation=False)
62 | return train_dataset_reader, validation_dataset_reader
63 |
64 | def model(self, image, is_train=True, reuse=False):
65 | with tf.variable_scope("model", reuse=reuse):
66 | tl.layers.set_name_reuse(reuse)
67 | #W_init = tf.contrib.layers.xavier_initializer()
68 | W_init = tf.contrib.layers.variance_scaling_initializer()
69 | # W_init = tf.contrib.layers.variance_scaling_initializer(factor=2.0, mode='FAN_IN')
70 | net = tl.layers.InputLayer(image, name='input_layer') # input image
71 |
72 | # filter = [96 for _ in range(10)]
73 | filter = [16, 64, 64, 128, 128, 128, 256]
74 | concat = {}
75 | i = 0
76 | block = 1
77 | depth = 6
78 | for j in range(1, depth):
79 | i = i + 1
80 | net = tl.layers.Conv2d(net, filter[j], (3, 3), (1, 1),
81 | act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
82 | i = i+1
83 | net = tl.layers.Conv2d(net, filter[j], (3, 3), (1, 1),
84 | act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
85 | concat[j] = net
86 | net = tl.layers.AtrousConv2dLayer(net, filter[j], (3, 3), 2,
87 | act=tf.nn.relu, padding='SAME', W_init=W_init, name='atro_conv'+str(block))
88 | block = block + 1
89 |
90 | i = i + 1
91 | net = tl.layers.Conv2d(net, filter[block], (3, 3), (1, 1),
92 | act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
93 | i = i + 1
94 | net = tl.layers.Conv2d(net, filter[block], (3, 3), (1, 1),
95 | act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
96 |
97 | for j in range(depth-1, 0, -1):
98 | i = i + 1
99 | net = tl.layers.Conv2d(net, filter[j], (3, 3), (1, 1),
100 | act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_' + str(i))
101 | net = tl.layers.ConcatLayer([net, concat[j]], 3, name='concat_'+str(j))
102 | i = i+1
103 | net = tl.layers.Conv2d(net, filter[j], (3, 3), (1, 1),
104 | act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_'+str(i))
105 | i = i+1
106 | net = tl.layers.Conv2d(net, filter[j], (3, 3), (1, 1),
107 | act=tf.nn.relu, padding='SAME', W_init=W_init, name='conv_'+str(i))
108 | i = i+1
109 | net = tl.layers.Conv2d(net, num_of_classes, (1, 1), (1, 1),
110 | padding='SAME', W_init=W_init, name='conv_'+str(i))
111 | y = net.outputs  # extract the logits tensor from the TensorLayer network
112 | pred = tf.argmax(y, 3, name="prediction")
113 |
114 | return pred, y, net
115 |
116 | def loss(self, logits, annotation):
117 | positive_count = tf.count_nonzero(annotation, dtype=tf.float32)
118 | total_num = tf.constant(IMG_SIZE * IMG_SIZE, dtype=tf.float32)
119 | negative_count = tf.subtract(total_num, positive_count)
120 | alpha = tf.div(tf.multiply(negative_count, 1.0), positive_count * gamma)
121 |
122 | classes_weights = [1.0, alpha]
123 | labels = tf.squeeze(annotation, squeeze_dims=[3])
124 | # labels = tf.reshape(annotation,(-1,))
125 | # labels = tf.one_hot(labels, depth=num_of_classes,dtype=tf.float32)
126 |
127 | # loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
128 | # labels=tf.squeeze(annotation,squeeze_dims=[3]),
129 | # name="entropy")))
130 |
131 |
132 | # print('++++++++++++++++++++++++++++=logits.shape=', logits.get_shape())
133 | # # print('++++++++++++++++++++++++++++=labels.dtype=',labels.dtype)
134 | # print('++++++++++++++++++++++++++++=labels.shape()=', labels.get_shape())
135 | # print('------------------------positive count = ',positive_count)
136 | # print('------------------------negative count = ', negative_count)
137 | # print('------------------------classes_weights = ', classes_weights)
138 | # logits = tf.reshape(logits, (-1, num_of_classes))
139 | # logits = tf.nn.softmax(logits)
140 |
141 | # loss = tf.reduce_mean(
142 | # tf.nn.weighted_cross_entropy_with_logits(logits=logits,
143 | # targets=labels,
144 | # pos_weight=classes_weights))
145 | # loss = tf.reduce_mean(self.weighted_cross_entropy(targets=labels,
146 | # logits=logits,
147 | # pos_weight=classes_weights))
148 |
149 | loss = self.weighted_softmax_cross_entropy_loss(logits=logits, labels=labels, weights=classes_weights)
150 |
151 | # L2 = 0
152 | # for p in tl.layers.get_variables_with_name('/W', True, True):
153 | # L2 += tf.contrib.layers.l2_regularizer(0.00001)(p)
154 | # loss = loss + L2
155 | return loss
156 |
157 | def weighted_softmax_cross_entropy_loss(self, logits, labels, weights):
158 | """
159 | Computes the SoftMax Cross Entropy loss with class weights based on the class of each pixel.
160 |
161 | Parameters
162 | ----------
163 | logits: TF tensor
164 | The network output before SoftMax.
165 | labels: TF tensor
166 | The desired output from the ground truth.
167 | weights : list of floats
168 | A list of the weights associated with the different labels in the ground truth.
169 |
170 | Returns
171 | -------
172 | loss : TF float
173 | The weighted cross-entropy loss. The per-pixel weight map is
174 | computed internally but not returned.
175 |
176 |
177 | """
178 |
179 | with tf.name_scope('loss'):
180 | # logits = tf.reshape(logits, [-1, tf.shape(logits)[3]], name='flatten_logits')
181 | # labels = tf.reshape(labels, [-1], name='flatten_labels')
182 |
183 | weight_map = tf.to_float(tf.equal(labels, 0, name='label_map_0')) * weights[0]
184 | for i, weight in enumerate(weights[1:], start=1):
185 | weight_map = weight_map + tf.to_float(tf.equal(labels, i, name='label_map_' + str(i))) * weight
186 |
187 | weight_map = tf.stop_gradient(weight_map, name='stop_gradient')
188 |
189 | # compute cross entropy loss
190 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits,
191 | name='cross_entropy_softmax')
192 |
193 | # apply weights to cross entropy loss
194 | weighted_cross_entropy = tf.multiply(weight_map, cross_entropy, name='apply_weights')
195 |
196 | # get loss scalar
197 | loss = tf.reduce_mean(weighted_cross_entropy, name='loss')
198 |
199 | return loss
200 |
201 | def weighted_softmax_cross_entropy_loss_with_false_positive_weights(self, logits, labels, weights,
202 | false_positive_factor=0.5):
203 | """
204 | Computes the SoftMax Cross Entropy loss with class weights based on the class of each pixel and an additional weight
205 | for false positive classifications (instances of class 0 classified as class 1).
206 |
207 | Parameters
208 | ----------
209 | logits: TF tensor
210 | The network output before SoftMax.
211 | labels: TF tensor
212 | The desired output from the ground truth.
213 | weights : list of floats
214 | A list of the weights associated with the different labels in the ground truth.
215 | false_positive_factor: float
216 | False positives receive a loss weight of false_positive_factor * label_weights[1], the weight of the class of interest.
217 |
218 | Returns
219 | -------
220 | loss : TF float
221 | The weighted cross-entropy loss. The per-pixel weight map is
222 | computed internally but not returned.
223 |
224 |
225 | """
226 |
227 | with tf.name_scope('loss'):
228 | logits = tf.reshape(logits, [-1, tf.shape(logits)[3]], name='flatten_logits')
229 | labels = tf.reshape(labels, [-1], name='flatten_labels')
230 |
231 | # get predictions from likelihoods
232 | prediction = tf.argmax(logits, 1, name='predictions')
233 |
234 | # get maps of class_of_interest pixels
235 | prediction_map = tf.equal(prediction, 1, name='prediction_map')
236 | label_map = tf.equal(labels, 1, name='label_map')
237 |
238 | false_positive_map = tf.logical_and(prediction_map, tf.logical_not(label_map), name='false_positive_map')
239 |
240 | label_map = tf.to_float(label_map)
241 | false_positive_map = tf.to_float(false_positive_map)
242 |
243 | weight_map = label_map * (weights[1] - weights[0]) + weights[0]
244 | weight_map = tf.add(weight_map, false_positive_map * ((false_positive_factor * weights[1]) - weights[0]),
245 | name="combined_weight_map")
246 |
247 | weight_map = tf.stop_gradient(weight_map, name='stop_gradient')
248 |
249 | # compute cross entropy loss
250 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits,
251 | name='cross_entropy_softmax')
252 |
253 | # apply weights to cross entropy loss
254 | weighted_cross_entropy = tf.multiply(weight_map, cross_entropy, name='apply_weights')
255 |
256 | # get loss scalar
257 | loss = tf.reduce_mean(weighted_cross_entropy, name='loss')
258 |
259 | return loss
260 |
261 | def weighted_cross_entropy(self, targets, logits, pos_weight, name=None):
262 | """computer weight cross entropy"""
263 | with ops.name_scope(name, "logistic_loss", [logits, targets]) as name:
264 | logits = ops.convert_to_tensor(logits, name="logits")
265 | targets = ops.convert_to_tensor(targets, name="targets")
266 | try:
267 | targets.get_shape().merge_with(logits.get_shape())
268 | except ValueError:
269 | raise ValueError(
270 | "logits and targets must have the same shape (%s vs %s)" %
271 | (logits.get_shape(), targets.get_shape()))
272 |
273 | # log_weight = 1 + (pos_weight - 1) * targets
274 | # return math_ops.add(
275 | # (1 - targets) * logits,
276 | # log_weight * (math_ops.log1p(math_ops.exp(-math_ops.abs(logits))) +
277 | # nn_ops.relu(-logits)),
278 | # name=name)
279 | cross_entropy = -tf.add(targets[:, :, :, 0] * tf.log(logits[:, :, :, 0]),
280 | pos_weight*targets[:, :, :, 1] * tf.log(logits[:, :, :, 1]))
281 | return cross_entropy
282 |
283 | def train(self, loss, learning_rate, globalstep):
284 | # If tf.nn.sparse_softmax_cross_entropy_with_logits is used,
285 | # the loss may become NaN because the values are not clipped.
286 | # annotation = tf.cast(annotation,dtype = tf.float32)
287 | # prob = tf.nn.softmax(logits)
288 | # loss = -tf.reduce_mean(annotation*tf.log(tf.clip_by_value(prob,1e-11,1.0)))
289 | optimizer = tf.train.AdamOptimizer(learning_rate)
290 | var_list = tf.trainable_variables()
291 | grads = optimizer.compute_gradients(loss, var_list=var_list)
292 | train_op = optimizer.apply_gradients(grads, global_step=globalstep)
293 | return train_op
294 |
295 |
296 | # AUPR score
297 | def computeConfMatElements(thresholded_proba_map, ground_truth):
298 | P = np.count_nonzero(ground_truth)
299 | TP = np.count_nonzero(thresholded_proba_map * ground_truth)
300 | FP = np.count_nonzero(thresholded_proba_map - (thresholded_proba_map * ground_truth))
301 |
302 | return P, TP, FP
303 |
304 |
305 | def computeAUPR(proba_map, ground_truth, threshold_list):
306 | proba_map = proba_map.astype(np.float32)
307 | proba_map = proba_map.reshape(-1)
308 | ground_truth = ground_truth.reshape(-1)
309 | precision_list_threshold = []
310 | recall_list_threshold = []
311 | # loop over thresholds
312 | for threshold in threshold_list:
313 | # threshold the proba map
314 | thresholded_proba_map = np.zeros(np.shape(proba_map))
315 | thresholded_proba_map[proba_map >= threshold] = 1
316 | # print(np.shape(thresholded_proba_map)) #(400,640)
317 |
318 | # compute P, TP, and FP for this threshold and this proba map
319 | P, TP, FP = computeConfMatElements(thresholded_proba_map, ground_truth)
320 |
321 | # check that the ground truth and the prediction each contain at least one positive
322 | if (P > 0 and (TP + FP) > 0):
323 | precision = TP * 1. / (TP + FP)
324 | recall = TP * 1. / P
325 | else:
326 | precision = 1
327 | recall = 0
328 |
329 | # store precision and recall for this threshold
330 | precision_list_threshold.append(precision)
331 | recall_list_threshold.append(recall)
332 |
333 | # aupr = 0.0
334 | # for i in range(1, len(precision_list_threshold)):
335 | #     aupr = aupr + precision_list_threshold[i] * (recall_list_threshold[i] - recall_list_threshold[i - 1])
336 | precision_list_threshold.append(1)
337 | recall_list_threshold.append(0)
338 | return auc(recall_list_threshold, precision_list_threshold)
339 |
340 |
341 | def main(argv=None):
342 | myUnet = Unet()
343 | image = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 3], name='image')  # input RGB images
344 | annotation = tf.placeholder(tf.int32, shape=[None, IMG_SIZE, IMG_SIZE, 1], name="annotation")
345 | # image = tf.cast(image, tf.float32)
346 | # annotation = tf.cast(annotation, tf.int32)
347 |
348 | # define inferences
349 | train_pred, train_logits, train_tlnetwork = myUnet.model(image, is_train=True, reuse=False)
350 | train_positive_prob = tf.nn.softmax(train_logits)[:, :, :, 1]
351 | train_loss_op = myUnet.loss(train_logits, annotation)
352 |
353 | n_epoch = MAX_EPOCH
354 | n_step_epoch = int(train_nbr / batch_size)
355 | LR_start = learningrate
356 | LR_fin = learningrateend
357 | LR_decay = (LR_fin / LR_start) ** (1.0 / n_epoch)
358 | step_decay = n_step_epoch
359 | global_steps = tf.Variable(0, trainable=False)
360 | learning_rate = tf.train.exponential_decay(learningrate, global_steps, step_decay, LR_decay, staircase=True)
361 | # learning_rate = tf.Variable(learningrate, dtype=tf.float32)
362 | # train_op = myUnet.train(train_loss_op,learning_rate,global_steps)
363 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(train_loss_op, global_step=global_steps)
364 | summaries = []
365 | summaries.append(tf.summary.scalar('learning_rate', learning_rate, collections=['learningrate']))
366 | # Merge all summaries together.
367 | summary_lr = tf.summary.merge(summaries, name='summary_lr')
368 |
369 | test_pred, test_logits, test_tlnetwork = myUnet.model(image, is_train=False, reuse=True)
370 | test_positive_prob = tf.nn.softmax(test_logits)[:, :, :, 1]
371 | test_loss_op = myUnet.loss(test_logits, annotation)
372 |
373 | # lr_assign_op = tf.assign(learning_rate, learning_rate / 5) # learning_rate decay
374 |
375 | # only visualize the test images
376 | # first lighten the annotation images
377 | visual_annotation = tf.where(tf.equal(annotation, 1), annotation + 254, annotation)
378 | visual_pred = tf.expand_dims(tf.where(tf.equal(test_pred, 1), test_pred + 254, test_pred), dim=3)
379 | tf.summary.image("input_image", image, max_outputs=2)
380 | tf.summary.image("ground_truth", tf.cast(visual_annotation, tf.uint8), max_outputs=2)
381 | tf.summary.image("pred_annotation", tf.cast(visual_pred, tf.uint8), max_outputs=2)
382 |
383 | print("Setting up summary op...")
384 | test_summary_op = tf.summary.merge_all()
385 |
386 | if FLAGS.mode == 'train':
387 | train_dataset_reader, validation_dataset_reader = myUnet.load_data_util()
388 | config = tf.ConfigProto(allow_soft_placement=True)
389 | config.gpu_options.allow_growth = True
390 | sess = tf.Session(config=config)
391 |
392 | print("Setting up Saver...")
393 | saver = tf.train.Saver(max_to_keep=2)
394 | summary_writer = tf.summary.FileWriter(logs_dir, sess.graph)
395 |
396 | sess.run(tf.global_variables_initializer())
397 |
398 | ckpt = tf.train.get_checkpoint_state(logs_dir)  # if a model has been trained, restore it
399 | if ckpt and ckpt.model_checkpoint_path:
400 | saver.restore(sess, ckpt.model_checkpoint_path)
401 | print("Model restored...")
402 | start = time.time()
403 | for epo in range(MAX_EPOCH):
404 | start_time = time.time()
405 | train_loss, test_loss, train_aupr, test_aupr, train_auc, test_auc = 0, 0, 0, 0, 0, 0
406 |
407 | for s in range(step_every_epoch):
408 | train_images, train_annotations = train_dataset_reader.next_batch(batch_size)
409 | feed_dict = {image: train_images, annotation: train_annotations}
410 | tra_positive_prob, train_err, _ = sess.run([train_positive_prob, train_loss_op, train_op],
411 | feed_dict=feed_dict)
412 |
413 | # compute auc score
414 | temp_train_annotations = np.reshape(train_annotations, -1)
415 | temp_tra_positive_prob = np.reshape(tra_positive_prob, -1)
416 | train_sauc = roc_auc_score(temp_train_annotations, np.nan_to_num(temp_tra_positive_prob))
417 | # compute aupr
418 | train_saupr = computeAUPR(np.nan_to_num(tra_positive_prob).reshape(-1), train_annotations.reshape(-1), range_threshold)
419 |
420 | train_loss += train_err
421 | train_auc += train_sauc
422 | train_aupr += train_saupr
423 |
424 | if epo + 1 == 1 or (epo + 1) % print_freq == 0:
425 | train_loss = train_loss / step_every_epoch
426 | train_auc = train_auc / step_every_epoch
427 | train_aupr = train_aupr / step_every_epoch
428 | # visualize the training loss
429 | print("%d epoches %d took %fs" % (print_freq, epo, time.time() - start_time))
430 | print(" train loss: %f" % train_loss)
431 | print(" train auc: %f" % train_auc)
432 | print(" train aupr: %f" % train_aupr)
433 |
434 | train_summary = tf.Summary(value=[
435 | tf.Summary.Value(tag="train_loss", simple_value=train_loss),
436 | tf.Summary.Value(tag="train_auc", simple_value=train_auc),
437 | tf.Summary.Value(tag="train_aupr", simple_value=train_aupr)
438 | ])
439 | summary_writer.add_summary(train_summary, epo)
440 | summary_str = sess.run(summary_lr)
441 | summary_writer.add_summary(summary_str, epo)
442 | summary_writer.flush()
443 |
444 | for test_s in range(test_every_epoch):
445 | # get validation data
446 | valid_images, valid_annotations = validation_dataset_reader.next_batch(test_batchsize)
447 | # visualize the validation loss
448 | feed_dict = {image: valid_images, annotation: valid_annotations}
449 | valid_positive_prob, validation_err = sess.run([test_positive_prob, test_loss_op], feed_dict=feed_dict)
450 | # compute auc score
451 | temp_valid_annotations = np.reshape(valid_annotations, -1)
452 | temp_valid_positive_prob = np.reshape(valid_positive_prob, -1)
453 | test_sauc = roc_auc_score(temp_valid_annotations, np.nan_to_num(temp_valid_positive_prob))
454 | # compute test aupr
455 | test_saupr = computeAUPR(np.nan_to_num(valid_positive_prob).reshape(-1), valid_annotations.reshape(-1),
456 | range_threshold)
457 |
458 | test_loss += validation_err
459 | test_auc += test_sauc
460 | test_aupr += test_saupr
461 | test_loss = test_loss / test_every_epoch
462 | test_auc = test_auc / test_every_epoch
463 | test_aupr = test_aupr / test_every_epoch
464 | print(" test aupr: %f" % test_aupr)
465 | test_summary = tf.Summary(value=[
466 | tf.Summary.Value(tag="test_loss", simple_value=test_loss),
467 | tf.Summary.Value(tag="test_auc", simple_value=test_auc),
468 | tf.Summary.Value(tag="test_aupr", simple_value=test_aupr)
469 | ])
470 | summary_writer.add_summary(test_summary, epo)
471 |
472 | # visualize the test result(only visualize the last batchsize of this epoch)
473 | feed_dict = {image: valid_images, annotation: valid_annotations}
474 | summary_str = sess.run(test_summary_op, feed_dict=feed_dict)
475 | summary_writer.add_summary(summary_str, epo)
476 |
477 | # tensorboard flush
478 | summary_writer.flush()
479 | sys.stdout.flush()
480 | # if (epo+1) % 100 == 0:
481 | # sess.run(lr_assign_op)
482 | if (epo + 1) % 100 == 0:
483 | saver.save(sess, logs_dir + "model.ckpt", epo)
484 | print('epoch %d: the model has been saved successfully' % epo)
485 | sys.stdout.flush()
486 | print('-------------------total cost time: %fs' % (time.time() - start))
487 | summary_writer.close()
488 | sess.close()
489 |
490 |
491 | if __name__ == '__main__':
492 | tf.app.run()
493 |
--------------------------------------------------------------------------------
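A worked numpy sketch of the per-pixel weight map built inside weighted_softmax_cross_entropy_loss (toy values; the real code derives alpha = negative_count / (positive_count * gamma) from each batch):

import numpy as np

labels = np.array([[0, 1],
                   [0, 0]])  # toy 2x2 ground truth
weights = [1.0, 64.0]        # [background weight, lesion weight alpha]

# background pixels receive weights[0], lesion pixels weights[1]
weight_map = np.where(labels == 0, weights[0], weights[1])
print(weight_map)  # [[ 1. 64.]
                   #  [ 1.  1.]]

# the TF code then multiplies this map element-wise with the softmax
# cross entropy and averages: loss = mean(weight_map * cross_entropy)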