├── .gitignore
├── README.md
└── src
    ├── dataset.py
    ├── main.py
    ├── model.py
    ├── preprocessing.py
    ├── solver.py
    ├── tensorflow_utils.py
    └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | github_imgs/
2 | src/.idea
3 | src/__pycache__
4 | src/logs
5 | src/model
6 | src/wmap_imgs
7 | src/sample
8 | src/test
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # U-Net-TensorFlow
2 | This repository is a TensorFlow implementation of ["U-Net: Convolutional Networks for Biomedical Image Segmentation," MICCAI 2015](https://arxiv.org/pdf/1505.04597.pdf). It completely follows the original U-Net paper.
3 | 
4 | 
5 |
7 |
8 | ## EM Segmentation Challenge Dataset
9 |
10 |
12 |
13 | ## Requirements
14 | - tensorflow 1.13.1
15 | - python 3.5.3
16 | - numpy 1.15.2
17 | - scipy 1.1.0
18 | - tifffile 2019.3.18
19 | - opencv 3.3.1
20 | - matplotlib 2.2.2
21 | - elasticdeform 0.4.4
22 | - scikit-learn 0.20.0
23 |
24 | ## Implementation Details
25 | This implementation completely follows the original U-Net paper in the following aspects:
26 | - Input image size 572 x 572 x 1 vs. output label map 388 x 388 x 2
27 | - Upsampling uses fractionally strided convolution (deconvolution)
28 | - Reflective mirror padding is used for the input image
29 | - Data augmentation: random translation, random horizontal and vertical flip, random rotation, and random elastic deformation
30 | - The loss function includes a weighted cross-entropy loss and a regularization term
31 | - The weight map is calculated using equation 2 of the original paper (shown after this list)
32 | - In the test stage, predictions from 7 rotated versions of the input data are averaged
33 |
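For reference, equation 2 of the paper weights each pixel by its distances d_1(x) and d_2(x) to the borders of the two nearest cells; `preprocessing.py` uses w_c = 1, w_0 = 10, and sigma = 5:

```
w(x) = w_c(x) + w_0 * exp( -(d_1(x) + d_2(x))^2 / (2 * sigma^2) )
```
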
34 | ## Examples of the Data Augmentation
35 | - Random Translation
36 | - Random Horizontal and Vertical Flip
37 | - Random Rotation
38 | - Random Elastic Deformation
39 |
40 |
41 |
43 |
44 |
45 |
47 |
48 | ## Fundamentals of the Different Input and Output Image Sizes in the Training Process
49 | - Reflective mirror padding is applied first (white lines indicate the boundaries of the original image)
50 | - The input image, label image, and weight map are then randomly cropped
51 | - The blue rectangle of the input image and the red rectangle of the weight map are the inputs of the U-Net during training, and the red rectangle of the label image is the ground truth of the network (see the sketch below)
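
A minimal NumPy sketch of this pad-and-crop scheme (a stand-in for the repository's `utils.cropping`, whose exact signature differs; it assumes the 512 x 512 EM images used here):

```python
import numpy as np

def random_crop_pair(img, label, input_size=572, output_size=388):
    """Reflect-pad a 512 x 512 image to 696 x 696, then take an aligned
    random crop: a 572 x 572 input and its centered 388 x 388 label."""
    margin = (input_size - output_size) // 2        # 92 = (572 - 388) / 2
    img_pad = np.pad(img, margin, mode='reflect')   # 512 -> 696 (512 + 2 * 92)
    lab_pad = np.pad(label, margin, mode='reflect')
    top = np.random.randint(0, img_pad.shape[0] - input_size + 1)
    left = np.random.randint(0, img_pad.shape[1] - input_size + 1)
    x = img_pad[top:top + input_size, left:left + input_size]
    y = lab_pad[top + margin:top + margin + output_size,
                left + margin:left + margin + output_size]
    return x, y
```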
52 |
53 |
54 |
56 |
57 | ## Test Paradigm
58 | - In the test stage, 7 rotated versions of each test image are fed to the network, and the final prediction is the average of the 7 predicted results (see the sketch below).
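
A minimal sketch of this test-time averaging; `predict_probmap` is a hypothetical helper standing in for the repository's four-crop U-Net forward pass, mapping an H x W image to an H x W x 2 score map:

```python
import numpy as np
from scipy.ndimage import rotate

def tta_predict(img, predict_probmap, num=7, start=0, stop=360):
    """Average score maps over `num` evenly spaced rotations of the input."""
    preds = np.zeros((num, *img.shape, 2), dtype=np.float32)
    for i, angle in enumerate(np.linspace(start, stop, num, endpoint=False)):
        x_rot = rotate(img, angle=angle, reshape=False, order=3, mode='reflect')
        y_rot = predict_probmap(x_rot)
        # rotate the score map back to the original orientation before averaging
        preds[i] = rotate(y_rot, angle=-angle, reshape=False, order=3, mode='reflect')
    return preds.mean(axis=0).argmax(axis=-1)  # averaged scores -> class map
```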
59 |
60 |
61 |
63 |
64 |
65 |
67 |
68 | - For each rotated image, four regions (top left, top right, bottom left, and bottom right) are extracted and passed through the U-Net, and the prediction is computed by averaging the overlapping scores of the four results (a sketch follows)
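
A sketch of the overlap averaging, assuming the 512 x 512 test images and 388 x 388 network outputs used in this repository:

```python
import numpy as np

def merge_four_crops(preds, full=512, out=388):
    """`preds`: four 388 x 388 x 2 score maps in the order top-left,
    top-right, bottom-left, bottom-right of the same image."""
    acc = np.zeros((full, full, 2), dtype=np.float32)
    cnt = np.zeros((full, full, 1), dtype=np.float32)
    offsets = [(0, 0), (0, full - out), (full - out, 0), (full - out, full - out)]
    for pred, (oy, ox) in zip(preds, offsets):
        acc[oy:oy + out, ox:ox + out] += pred
        cnt[oy:oy + out, ox:ox + out] += 1.
    return acc / cnt  # pixels covered by several crops are averaged
```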
69 |
70 |
71 |
73 |
74 | **Note**: White lines indicate boundaries of the image.
75 |
76 |
77 |
79 |
80 | **Note:** The prediction results on the EM Segmentation Challenge test dataset
81 |
82 |
84 |
85 | ## Download Dataset
86 | Download the EM Segmentation Challenge dataset from the [ISBI challenge homepage](http://brainiac2.mit.edu/isbi_challenge/).
87 |
88 | ## Documentation
89 | ### Directory Hierarchy
90 | ```
91 | .
92 | ├── U-Net
93 | │   └── src
94 | │       ├── dataset.py
95 | │       ├── main.py
96 | │       ├── model.py
97 | │       ├── preprocessing.py
98 | │       ├── solver.py
99 | │       ├── tensorflow_utils.py
100 | │       └── utils.py
101 | └── Data
102 |     └── EMSegmentation
103 |         ├── test-volume.tif
104 |         ├── train-labels.tif
105 |         ├── train-wmaps.npy (generated in preprocessing)
106 |         └── train-volume.tif
107 | ```
108 | ### Preprocessing
109 | The weight maps need to be calculated from the segmentation labels of the training data first: computing them on-line during training would slow down processing, so they are calculated and saved in advance, and during training the saved weight maps are augmented together with the input and label images. Use `preprocessing.py` to calculate and save the weight maps (set `is_write=True` in its `__main__` block to write `train-wmaps.npy`). Example usage:
110 | ```
111 | python preprocessing.py
112 | ```
113 |
114 | ### Training U-Net
115 | Use `main.py` to train the U-Net. Example usage:
116 | ```
117 | python main.py
118 | ```
119 | - `gpu_index`: gpu index if you have multiple gpus, default: `0`
120 | - `dataset`: dataset name, default: `EMSegmentation`
121 | - `batch_size`: batch size for one iteration, default: `4`
122 | - `is_train`: training or inference (test) mode, default: `True (training mode)`
123 | - `learning_rate`: initial learning rate for optimizer, default: `1e-3`
124 | - `weight_decay`: weight decay for model to handle overfitting, default: `1e-4`
125 | - `iters`: number of iterations, default: `20,000`
126 | - `print_freq`: print frequency for loss information, default: `10`
127 | - `sample_freq`: sampling frequency for checking qualitative evaluation, default: `100`
128 | - `eval_freq`: evaluation frequency for the batch accuracy, default: `200`
129 | - `load_model`: folder of saved model that you wish to continue training, (e.g. 20190524-1606), default: `None`
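
To resume training from a previously saved model, pass its folder name, e.g.:
```
python main.py --load_model=20190524-1606
```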
130 |
131 | ### Test U-Net
132 | Use `main.py` to test the models. Example usage:
133 | ```
134 | python main.py --is_train=False --load_model=folder/you/wish/to/test/e.g./20190524-1606
135 | ```
136 | Please refer to the above arguments.
137 |
138 | ### Tensorboard Visualization
139 | **Note:** The following figure shows the data loss, weighted data loss, regularization term, and total loss during the training process. The batch accuracy is also shown in TensorBoard.
140 |
141 |
142 |
144 |
145 | ### Citation
146 | ```
147 | @misc{chengbinjin2019u-nettensorflow,
148 | author = {Cheng-Bin Jin},
149 | title = {U-Net Tensorflow},
150 | year = {2019},
151 | howpublished = {\url{https://github.com/ChengBinJin/U-Net-TensorFlow}},
152 | note = {commit xxxxxxx}
153 | }
154 | ```
155 |
156 | ### Attributions/Thanks
157 | - This project borrowed some code from [Zhixuhao](https://github.com/zhixuhao/unet)
158 | - Some readme formatting was borrowed from [Logan Engstrom](https://github.com/lengstrom/fast-style-transfer)
159 |
160 | ## License
161 | Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: sbkim0407@gmail.com). Free for research use, as long as proper attribution is given and this copyright notice is retained.
162 |
--------------------------------------------------------------------------------
/src/dataset.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow U-Net Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import logging
9 | import numpy as np
10 | import tifffile as tiff
11 | from scipy.ndimage import rotate
12 |
13 | import utils as utils
14 |
15 | logger = logging.getLogger(__name__) # logger
16 | logger.setLevel(logging.INFO)
17 |
18 |
19 | class Dataset(object):
20 | def __init__(self, name='EMSegmentation', log_dir=None):
21 |         # These values depend on the dataset
22 | self.input_size = 572
23 | self.output_size = 388
24 | self.input_channel = 1
25 | self.input_shape = (self.input_size, self.input_size, self.input_channel)
26 | self.output_shape = (self.output_size, self.output_size)
27 |
28 | self.name = name
29 | self.dataset_path = '../../Data/EMSegmentation'
30 |
31 | self.train_imgs = tiff.imread(os.path.join(self.dataset_path, 'train-volume.tif'))
32 | self.train_labels = tiff.imread(os.path.join(self.dataset_path, 'train-labels.tif'))
33 | self.train_wmaps = np.load(os.path.join(self.dataset_path, 'train-wmaps.npy'))
34 | self.test_imgs = tiff.imread(os.path.join(self.dataset_path, 'test-volume.tif'))
35 | self.mean_value = np.mean(self.train_imgs)
36 |
37 | self.num_train = self.train_imgs.shape[0]
38 | self.num_test = self.test_imgs.shape[0]
39 | self.img_shape = self.train_imgs[0].shape
40 |
41 | self.init_logger(log_dir)
42 |
43 | @staticmethod
44 | def init_logger(log_dir):
45 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s')
46 |
47 | # file handler
48 | file_handler = logging.FileHandler(os.path.join(log_dir, 'dataset.log'))
49 | file_handler.setFormatter(formatter)
50 |
51 | # stream handler
52 | stream_handler = logging.StreamHandler()
53 | stream_handler.setFormatter(formatter)
54 |
55 | # add handlers
56 | logger.addHandler(file_handler)
57 | logger.addHandler(stream_handler)
58 |
59 | def info(self, use_logging=True, log_dir=None):
60 | if use_logging:
61 | logger.info('- Training-img set:\t{}'.format(self.train_imgs.shape))
62 | logger.info('- Training-label set:\t{}'.format(self.train_labels.shape))
63 | logger.info('- Training-wmap set:\t{}'.format(self.train_wmaps.shape))
64 | logger.info('- Test-img set:\t\t{}'.format(self.test_imgs.shape))
65 |
66 | logger.info('- image shape:\t\t{}'.format(self.img_shape))
67 | else:
68 | print('- Training-img set:\t{}'.format(self.train_imgs.shape))
69 | print('- Training-label set:\t{}'.format(self.train_labels.shape))
70 | print('- Training-wmap set:\t{}'.format(self.train_wmaps.shape))
71 | print('- Test-img set:\t\t{}'.format(self.test_imgs.shape))
72 | print('- image shape:\t\t{}'.format(self.img_shape))
73 |
74 | print(' [*] Saving data augmented images to check U-Net fundamentals...')
75 | for idx in range(self.num_train):
76 | img_, label_, wmap_ = self.train_imgs[idx], self.train_labels[idx], self.train_wmaps[idx]
77 | utils.imshow(img_, label_, wmap_, idx, log_dir=log_dir)
78 | utils.test_augmentation(img_, label_, wmap_, idx, log_dir=log_dir)
79 | utils.test_cropping(img_, label_, wmap_, idx, self.input_size, self.output_size, log_dir=log_dir)
80 | print(' [!] Saving data augmented images to check U-Net fundamentals!')
81 |
82 | def info_test(self, test_dir):
83 | print(' [*] Saving test data cropping to check U-Net fundamentals...')
84 | for idx in range(self.num_test):
85 | img = self.test_imgs[idx]
86 | utils.test_imshow(img, idx, test_dir=test_dir)
87 | utils.test_rotate(img, idx, test_dir=test_dir)
88 |         print(' [!] Saving test data cropping to check U-Net fundamentals!')
89 |
90 |
91 | def random_batch(self, idx, batch_size=2):
92 | idx = idx % self.num_train
93 | x_img, y_label, w_map = self.train_imgs[idx], self.train_labels[idx], self.train_wmaps[idx]
94 |
95 | x_batchs = np.zeros((batch_size, self.input_size, self.input_size), dtype=np.float32)
96 | # y_batchs will be represented in one-hot in solver.train()
97 | y_batchs = np.zeros((batch_size, self.output_size, self.output_size), dtype=np.float32)
98 | w_batchs = np.zeros((batch_size, self.output_size, self.output_size), dtype=np.float32)
99 | for idx in range(batch_size):
100 | # Random translation
101 | x_batch, y_batch, w_batch = utils.aug_translate(x_img, y_label, w_map)
102 |
103 | # Random horizontal and vertical flip
104 | x_batch, y_batch, w_batch = utils.aug_flip(x_batch, y_batch, w_batch)
105 |
106 | # Random rotation
107 | x_batch, y_batch, w_batch = utils.aug_rotate(x_batch, y_batch, w_batch)
108 |
109 | # Random elastic deformation
110 | x_batch, y_batch, w_batch = utils.aug_elastic_deform(x_batch, y_batch, w_batch)
111 |
112 |             # Following the original U-Net paper:
113 |             # reflect-pad the 512 x 512 image to 696 x 696 (696 = 512 + 2 * 92), then crop
114 |             # a 572 x 572 input image and the corresponding 388 x 388 label map
115 |             # (92 = (572 - 388) / 2)
116 | x_batch, y_batch, w_batch = utils.cropping(x_batch, y_batch, w_batch, self.input_size, self.output_size)
117 |
118 | x_batchs[idx, :, :] = x_batch
119 | y_batchs[idx, :, :] = y_batch
120 | w_batchs[idx, :, :] = w_batch
121 |
122 | return self.zero_centering(x_batchs), (y_batchs / 255).astype(np.uint8), w_batchs
123 |
124 | def test_batch(self, idx, angle):
125 | x_img_ori = self.test_imgs[idx]
126 |
127 |         # Rotate image
128 | x_img = rotate(input=x_img_ori, angle=angle, axes=(0, 1), reshape=False, order=3, mode='reflect')
129 | x_batchs = utils.test_data_cropping(img=x_img,
130 | input_size=self.input_size,
131 | output_size=self.output_size,
132 | num_blocks=4)
133 |
134 | return self.zero_centering(x_batchs), x_img_ori
135 |
136 | def zero_centering(self, imgs):
137 | return imgs - self.mean_value
138 |
139 |
140 | # if __name__ == '__main__':
141 | # data = Dataset()
142 | #
143 | # for i in range(data.num_train):
144 | # img, label = data.train_imgs[i], data.train_labels[i]
145 | # utils.imshow(img, label, idx=i)
146 |
147 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow U-Net Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import logging
9 | import numpy as np
10 | import tensorflow as tf
11 | from datetime import datetime
12 |
13 | from dataset import Dataset
14 | from model import Model
15 | from solver import Solver
16 | import utils as utils
17 |
18 | FLAGS = tf.flags.FLAGS
19 | tf.flags.DEFINE_string('gpu_index', '0', 'gpu index if you have multiple gpus, default: 0')
20 | tf.flags.DEFINE_string('dataset', 'EMSegmentation', 'dataset name, default: EMSegmentation')
21 | tf.flags.DEFINE_integer('batch_size', 4, 'batch size for one iteration, default: 4')
22 | tf.flags.DEFINE_bool('is_train', True, 'training or inference mode, default: True')
23 | tf.flags.DEFINE_float('learning_rate', 1e-3, 'initial learning rate for optimizer, default: 0.001')
24 | tf.flags.DEFINE_float('weight_decay', 1e-4, 'weight decay for model to handle overfitting, default: 0.0001')
25 | tf.flags.DEFINE_integer('iters', 20000, 'number of iterations, default: 20,000')
26 | tf.flags.DEFINE_integer('print_freq', 10, 'print frequency for loss information, default: 10')
27 | tf.flags.DEFINE_integer('sample_freq', 100, 'sample frequence for checking qualitative evaluation, default: 100')
28 | tf.flags.DEFINE_integer('eval_freq', 200, 'evaluation frequency for evaluation of the batch accuracy, default: 200')
29 | tf.flags.DEFINE_string('load_model', None, 'folder of saved model that you wish to continue training '
30 | '(e.g. 20190524-1606), default: None')
31 |
32 | logger = logging.getLogger(__name__) # logger
33 | logger.setLevel(logging.INFO)
34 |
35 | def init_logger(log_dir, is_train=True):
36 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s')
37 |
38 | # file handler
39 | file_handler = logging.FileHandler(os.path.join(log_dir, 'main.log'))
40 | file_handler.setFormatter(formatter)
41 | file_handler.setLevel(logging.INFO)
42 |
43 | # stream handler
44 | stream_handler = logging.StreamHandler()
45 | stream_handler.setFormatter(formatter)
46 |
47 | # add handlers
48 | logger.addHandler(file_handler)
49 | logger.addHandler(stream_handler)
50 |
51 | if is_train:
52 | logger.info('gpu_index: \t\t{}'.format(FLAGS.gpu_index))
53 | logger.info('dataset: \t\t{}'.format(FLAGS.dataset))
54 | logger.info('batch_size: \t\t{}'.format(FLAGS.batch_size))
55 | logger.info('is_train: \t\t{}'.format(FLAGS.is_train))
56 | logger.info('learning_rate: \t{}'.format(FLAGS.learning_rate))
57 | logger.info('weight_decay: \t\t{}'.format(FLAGS.weight_decay))
58 | logger.info('iters: \t\t{}'.format(FLAGS.iters))
59 | logger.info('print_freq: \t\t{}'.format(FLAGS.print_freq))
60 | logger.info('sample_freq: \t\t{}'.format(FLAGS.sample_freq))
61 | logger.info('eval_freq: \t\t{}'.format(FLAGS.eval_freq))
62 | logger.info('load_model: \t\t{}'.format(FLAGS.load_model))
63 | else:
64 | print('-- gpu_index: \t\t{}'.format(FLAGS.gpu_index))
65 | print('-- dataset: \t\t{}'.format(FLAGS.dataset))
66 | print('-- batch_size: \t\t{}'.format(FLAGS.batch_size))
67 | print('-- is_train: \t\t{}'.format(FLAGS.is_train))
68 | print('-- learning_rate: \t{}'.format(FLAGS.learning_rate))
69 | print('-- weight_decay: \t\t{}'.format(FLAGS.weight_decay))
70 | print('-- iters: \t\t{}'.format(FLAGS.iters))
71 | print('-- print_freq: \t\t{}'.format(FLAGS.print_freq))
72 | print('-- sample_freq: \t\t{}'.format(FLAGS.sample_freq))
73 | print('-- eval_freq: \t\t{}'.format(FLAGS.eval_freq))
74 | print('-- load_model: \t\t{}'.format(FLAGS.load_model))
75 |
76 |
77 | def main(_):
78 | os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_index
79 |
80 | # Initialize model and log folders
81 | if FLAGS.load_model is None:
82 | cur_time = datetime.now().strftime("%Y%m%d-%H%M")
83 | else:
84 | cur_time = FLAGS.load_model
85 |
86 | model_dir, log_dir, sample_dir, test_dir = utils.make_folders(is_train=FLAGS.is_train, cur_time=cur_time)
87 | init_logger(log_dir=log_dir, is_train=FLAGS.is_train)
88 |
89 |     # Initialize dataset
90 | data = Dataset(name=FLAGS.dataset, log_dir=log_dir)
91 | data.info(use_logging=True, log_dir=log_dir)
92 |
93 | # Initialize session
94 | sess = tf.Session()
95 |
96 |     # Initialize model
97 | model = Model(input_shape=data.input_shape,
98 | output_shape=data.output_shape,
99 | lr=FLAGS.learning_rate,
100 | weight_decay=FLAGS.weight_decay,
101 | total_iters=FLAGS.iters,
102 | is_train=FLAGS.is_train,
103 | log_dir=log_dir,
104 | name='U-Net')
105 |
106 |     # Initialize solver
107 | solver = Solver(sess, model, data.mean_value)
108 | saver = tf.train.Saver(max_to_keep=1)
109 |
110 | if FLAGS.is_train:
111 | train(data, solver, saver, model_dir, log_dir, sample_dir)
112 | else:
113 | test(data, solver, saver, model_dir, test_dir)
114 |
115 |
116 | def train(data, solver, saver, model_dir, log_dir, sample_dir):
117 | best_acc = 0.
118 | num_evals = 0
119 |     tb_writer = tf.summary.FileWriter(log_dir, graph=solver.sess.graph)
120 | solver.init()
121 |
122 | iter_time = 0
123 | if FLAGS.load_model is not None:
124 | flag, iter_time, best_acc = load_model(saver, solver, model_dir, is_train=True)
125 | logger.info(' [!] Load Success! Iter: {}, Best acc: {:.3f}'.format(iter_time, best_acc))
126 |
127 | for iter_time in range(iter_time, FLAGS.iters):
128 | x_batch, y_batch, w_batch = data.random_batch(batch_size=FLAGS.batch_size, idx=iter_time)
129 | _, total_loss, avg_data_loss, weighted_data_loss, reg_term, summary, pred_cls = \
130 | solver.train(x_batch, y_batch, w_batch)
131 |
132 |         # Write to tensorboard
133 | tb_writer.add_summary(summary, iter_time)
134 | tb_writer.flush()
135 |
136 | if np.mod(iter_time, FLAGS.print_freq) == 0:
137 | msg = '{}/{}: \tTotal loss: {:.3f}, \tAvg. data loss: {:.3f}, \tWeighted data loss: {:.3f} \tReg. term: {:.3f}'
138 | print(msg.format(iter_time, FLAGS.iters, total_loss, avg_data_loss, weighted_data_loss, reg_term))
139 |
140 | if np.mod(iter_time, FLAGS.sample_freq) == 0:
141 | solver.save_imgs(x_batch, pred_cls, y_batch, iter_time, sample_dir)
142 |
143 | if np.mod(iter_time, FLAGS.eval_freq) == 0:
144 | x_batch, y_batch, w_batch = data.random_batch(batch_size=FLAGS.batch_size * 20,
145 | idx=np.random.randint(low=0, high=FLAGS.iters))
146 |             acc, summary = solver.evaluate(x_batch, y_batch, batch_size=FLAGS.batch_size)
147 | print('Evaluation! \tAcc: {:.3f} \tBest Acc: {:.3f}'.format(acc, best_acc))
148 |
149 | # Write to tensorboard
150 | tb_writer.add_summary(summary, num_evals)
151 | tb_writer.flush()
152 | num_evals += 1
153 |
154 | if acc > best_acc:
155 | logger.info('Acc: {:.3f}, Best Acc: {:.3f}'.format(acc, best_acc))
156 | best_acc = acc
157 | save_model(saver, solver, model_dir, iter_time, best_acc)
158 |
159 |
160 | def test(data, solver, saver, model_dir, test_dir, start=0, stop=360, num=7):
161 | # Load checkpoint
162 | flag, iter_time, best_acc = load_model(saver, solver, model_dir, is_train=False)
163 | if flag is True:
164 | print(' [!] Load Success! Iter: {}, Best acc: {:.3f}'.format(iter_time, best_acc))
165 | else:
166 | print(' [!] Load Failed!')
167 |
168 | # Test
169 | data.info_test(test_dir)
170 |
171 | for iter_time in range(data.num_test):
172 | print('iter: {}'.format(iter_time))
173 |
174 | y_preds = np.zeros((num, *data.img_shape, 2), dtype=np.float32) # [N, H, W, 2]
175 | x_ori_img = None
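        # 7 angles evenly spaced over [0, 360): 0.0, 51.4, 102.9, ..., 308.6 degrees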
176 | for i, angle in enumerate(np.linspace(start=start, stop=stop, num=num, endpoint=False)):
177 |             x_batchs, x_ori_img = data.test_batch(iter_time, angle)  # four cropped images per rotated test image
178 | y_preds[i] = solver.test(x_batchs, iter_time, angle, test_dir, is_save=True)
179 |
180 | # Merge rotated label images
181 | utils.merge_rotated_preds(y_preds, x_ori_img, iter_time, start, stop, num, test_dir, is_save=True)
182 |
183 |
184 | def save_model(saver, solver, model_dir, iter_time, best_acc):
185 | solver.save_acc_record(best_acc)
186 | saver.save(solver.sess, os.path.join(model_dir, 'model'), global_step=iter_time)
187 | logger.info(' [*] Model saved! Iter: {}, Best Acc.: {:.3f}'.format(iter_time, best_acc))
188 |
189 |
190 | def load_model(saver, solver, model_dir, is_train=False):
191 | if is_train:
192 | logger.info(' [*] Reading checkpoint...')
193 | else:
194 | print(' [*] Reading checkpoint...')
195 |
196 | ckpt = tf.train.get_checkpoint_state(model_dir)
197 | if ckpt and ckpt.model_checkpoint_path:
198 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
199 | saver.restore(solver.sess, os.path.join(model_dir, ckpt_name))
200 |
201 | meta_graph_path = ckpt.model_checkpoint_path + '.meta'
202 | iter_time = int(meta_graph_path.split('-')[-1].split('.')[0])
203 | best_acc = solver.load_acc_record()
204 |
205 | if is_train:
206 | logger.info(' [!] Load Iter: {}, Best Acc.: {:.3f}'.format(iter_time, best_acc))
207 |
208 | return True, iter_time, best_acc
209 | else:
210 | return False, 0, 0.
211 |
212 |
213 |
214 | if __name__ == '__main__':
215 | tf.app.run()
216 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow U-Net Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import tensorflow as tf
8 | import utils as utils
9 | import tensorflow_utils as tf_utils
10 |
11 | class Model(object):
12 | def __init__(self, input_shape, output_shape, lr=0.001, weight_decay=1e-4, total_iters=2e4, is_train=True,
13 | log_dir=None, name='U-Net'):
14 | self.input_shape = input_shape
15 | self.output_shape = output_shape
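        # Channel widths: contracting path 64 -> 1024, expanding path back down to 64, and a final 1x1 conv to 2 classes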
16 | self.conv_dims = [64, 64, 128, 128, 256, 256, 512, 512, 1024, 1024,
17 | 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 2]
18 | self.lr = lr
19 | self.weight_decay = weight_decay
20 | self.total_steps = total_iters
21 | self.start_decay_step = int(self.total_steps * 0.5)
22 | self.decay_steps = self.total_steps - self.start_decay_step
23 | self.is_train = is_train
24 | self.log_dir = log_dir
25 | self.name = name
26 | self.tb_lr, self.pred = None, None
27 | self.logger, self.file_handler, self.stream_handler = utils.init_logger(log_dir=self.log_dir,
28 | name=self.name,
29 | is_train=self.is_train)
30 |
31 | with tf.variable_scope(self.name):
32 | self._build_net()
33 |
34 | self._tensorboard()
35 | tf_utils.show_all_variables(logger=self.logger if self.is_train else None)
36 |
37 |
38 | def _build_net(self):
39 | # Input placeholders
40 | self.inp_img = tf.placeholder(dtype=tf.float32, shape=[None, *self.input_shape], name='input_img')
41 | self.out_img = tf.placeholder(dtype=tf.uint8, shape=[None, *self.output_shape], name='output_img')
42 | self.weight_map = tf.placeholder(dtype=tf.float32, shape=[None, *self.output_shape], name='weight_map')
43 | self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob')
44 | self.val_acc = tf.placeholder(dtype=tf.float32, name='val_acc')
45 |
46 | # Best acc record
47 | self.best_acc_record = tf.get_variable(name='best_acc_save', dtype=tf.float32, initializer=tf.constant(0.),
48 | trainable=False)
49 |
50 | # One-hot representation
51 | self.out_img_one_hot = tf.one_hot(indices=self.out_img, depth=2, axis=-1, dtype=tf.float32, name='one_hot')
52 |
53 | # U-Net building
54 | self.u_net()
55 | self.pred_cls = tf.math.argmax(input=self.pred, axis=-1)
56 |
57 | # Data loss
58 | self.data_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.out_img_one_hot, logits=self.pred)
59 | self.avg_data_loss = tf.reduce_mean(self.data_loss)
60 | self.weighted_data_loss = tf.reduce_mean(self.weight_map * self.data_loss)
61 |
62 | # Regularization term
63 | self.reg_term = self.weight_decay * tf.reduce_sum(
64 | [tf.nn.l2_loss(weight) for weight in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)])
65 |
66 | # Total loss = Data loss + Regularization term
67 | self.total_loss = self.weighted_data_loss + self.reg_term
68 |
69 | # Optimizer
70 | self.train_op = self.optimizer_fn(loss=self.total_loss, name='Adam')
71 |
72 | # Accuracy
73 | self.accuracy = tf.reduce_mean(
74 | tf.cast(tf.math.equal(x=tf.cast(tf.math.argmax(self.out_img_one_hot, axis=-1), dtype=tf.uint8),
75 | y=tf.cast(self.pred_cls, dtype=tf.uint8)), dtype=tf.float32)) * 100.
76 | self.save_best_acc_op = tf.assign(self.best_acc_record, value=self.val_acc)
77 |
78 | def u_net(self):
79 | # Stage 1
80 | tf_utils.print_activations(self.inp_img, logger=self.logger)
81 | s1_conv1 = tf_utils.conv2d(x=self.inp_img, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1,
82 | padding='VALID', initializer='He', name='s1_conv1', logger=self.logger)
83 | s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1', logger=self.logger)
84 | s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1,
85 | padding='VALID', initializer='He', name='s1_conv2', logger=self.logger)
86 | s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2', logger=self.logger)
87 |
88 | # Stage 2
89 | s2_maxpool = tf_utils.max_pool(x=s1_conv2, name='s2_maxpool', logger=self.logger)
90 | s2_conv1 = tf_utils.conv2d(x=s2_maxpool, output_dim=self.conv_dims[2], k_h=3, k_w=3, d_h=1, d_w=1,
91 | padding='VALID', initializer='He', name='s2_conv1', logger=self.logger)
92 | s2_conv1 = tf_utils.relu(s2_conv1, name='relu_s2_conv1', logger=self.logger)
93 | s2_conv2 = tf_utils.conv2d(x=s2_conv1, output_dim=self.conv_dims[3], k_h=3, k_w=3, d_h=1, d_w=1,
94 | padding='VALID', initializer='He', name='s2_conv2', logger=self.logger)
95 | s2_conv2 = tf_utils.relu(s2_conv2, name='relu_s2_conv2', logger=self.logger)
96 |
97 | # Stage 3
98 | s3_maxpool = tf_utils.max_pool(x=s2_conv2, name='s3_maxpool', logger=self.logger)
99 | s3_conv1 = tf_utils.conv2d(x=s3_maxpool, output_dim=self.conv_dims[4], k_h=3, k_w=3, d_h=1, d_w=1,
100 | padding='VALID', initializer='He', name='s3_conv1', logger=self.logger)
101 | s3_conv1 = tf_utils.relu(s3_conv1, name='relu_s3_conv1', logger=self.logger)
102 | s3_conv2 = tf_utils.conv2d(x=s3_conv1, output_dim=self.conv_dims[5], k_h=3, k_w=3, d_h=1, d_w=1,
103 | padding='VALID', initializer='He', name='s3_conv2', logger=self.logger)
104 | s3_conv2 = tf_utils.relu(s3_conv2, name='relu_s3_conv2', logger=self.logger)
105 |
106 | # Stage 4
107 | s4_maxpool = tf_utils.max_pool(x=s3_conv2, name='s4_maxpool', logger=self.logger)
108 | s4_conv1 = tf_utils.conv2d(x=s4_maxpool, output_dim=self.conv_dims[6], k_h=3, k_w=3, d_h=1, d_w=1,
109 | padding='VALID', initializer='He', name='s4_conv1', logger=self.logger)
110 | s4_conv1 = tf_utils.relu(s4_conv1, name='relu_s4_conv1', logger=self.logger)
111 | s4_conv2 = tf_utils.conv2d(x=s4_conv1, output_dim=self.conv_dims[7], k_h=3, k_w=3, d_h=1, d_w=1,
112 | padding='VALID', initializer='He', name='s4_conv2', logger=self.logger)
113 | s4_conv2 = tf_utils.relu(s4_conv2, name='relu_s4_conv2', logger=self.logger)
114 | s4_conv2_drop = tf_utils.dropout(x=s4_conv2, keep_prob=self.keep_prob, name='s4_conv2_dropout',
115 | logger=self.logger)
116 |
117 | # Stage 5
118 | s5_maxpool = tf_utils.max_pool(x=s4_conv2_drop, name='s5_maxpool', logger=self.logger)
119 | s5_conv1 = tf_utils.conv2d(x=s5_maxpool, output_dim=self.conv_dims[8], k_h=3, k_w=3, d_h=1, d_w=1,
120 | padding='VALID', initializer='He', name='s5_conv1', logger=self.logger)
121 | s5_conv1 = tf_utils.relu(s5_conv1, name='relu_s5_conv1', logger=self.logger)
122 | s5_conv2 = tf_utils.conv2d(x=s5_conv1, output_dim=self.conv_dims[9], k_h=3, k_w=3, d_h=1, d_w=1,
123 | padding='VALID', initializer='He', name='s5_conv2', logger=self.logger)
124 | s5_conv2 = tf_utils.relu(s5_conv2, name='relu_s5_conv2', logger=self.logger)
125 | s5_conv2_drop = tf_utils.dropout(x=s5_conv2, keep_prob=self.keep_prob, name='s5_conv2_dropout',
126 | logger=self.logger)
127 |
128 | # Stage 6
129 | s6_deconv1 = tf_utils.deconv2d(x=s5_conv2_drop, output_dim=self.conv_dims[10], k_h=2, k_w=2, initializer='He',
130 | name='s6_deconv1', logger=self.logger)
131 | s6_deconv1 = tf_utils.relu(s6_deconv1, name='relu_s6_deconv1', logger=self.logger)
132 |
133 | # Cropping
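        # Center-crop the skip-connection features to match the upsampled map (the paper's "copy and crop")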
134 | h1, w1 = s4_conv2_drop.get_shape().as_list()[1:3]
135 | h2, w2 = s6_deconv1.get_shape().as_list()[1:3]
136 | s4_conv2_crop = tf.image.crop_to_bounding_box(image=s4_conv2_drop,
137 | offset_height=int(0.5 * (h1 - h2)),
138 | offset_width=int(0.5 * (w1 - w2)),
139 | target_height=h2,
140 | target_width=w2)
141 | tf_utils.print_activations(s4_conv2_crop, logger=self.logger)
142 |
143 | s6_concat = tf_utils.concat(values=[s4_conv2_crop, s6_deconv1], axis=3, name='s6_concat', logger=self.logger)
144 | s6_conv2 = tf_utils.conv2d(x=s6_concat, output_dim=self.conv_dims[11], k_h=3, k_w=3, d_h=1, d_w=1,
145 | padding='VALID', initializer='He', name='s6_conv2', logger=self.logger)
146 | s6_conv2 = tf_utils.relu(s6_conv2, name='relu_s6_conv2', logger=self.logger)
147 | s6_conv3 = tf_utils.conv2d(x=s6_conv2, output_dim=self.conv_dims[12], k_h=3, k_w=3, d_h=1, d_w=1,
148 | padding='VALID', initializer='He', name='s6_conv3', logger=self.logger)
149 | s6_conv3 = tf_utils.relu(s6_conv3, name='relu_s6_conv3', logger=self.logger)
150 |
151 | # Stage 7
152 | s7_deconv1 = tf_utils.deconv2d(x=s6_conv3, output_dim=self.conv_dims[13], k_h=2, k_w=2, initializer='He',
153 | name='s7_deconv1', logger=self.logger)
154 | s7_deconv1 = tf_utils.relu(s7_deconv1, name='relu_s7_deconv1', logger=self.logger)
155 | # Cropping
156 | h1, w1 = s3_conv2.get_shape().as_list()[1:3]
157 | h2, w2 = s7_deconv1.get_shape().as_list()[1:3]
158 | s3_conv2_crop = tf.image.crop_to_bounding_box(image=s3_conv2,
159 | offset_height=int(0.5 * (h1 - h2)),
160 | offset_width=int(0.5 * (w1 - w2)),
161 | target_height=h2,
162 | target_width=w2)
163 | tf_utils.print_activations(s3_conv2_crop, logger=self.logger)
164 |
165 | s7_concat = tf_utils.concat(values=[s3_conv2_crop, s7_deconv1], axis=3, name='s7_concat', logger=self.logger)
166 | s7_conv2 = tf_utils.conv2d(x=s7_concat, output_dim=self.conv_dims[14], k_h=3, k_w=3, d_h=1, d_w=1,
167 | padding='VALID', initializer='He', name='s7_conv2', logger=self.logger)
168 | s7_conv2 = tf_utils.relu(s7_conv2, name='relu_s7_conv2', logger=self.logger)
169 | s7_conv3 = tf_utils.conv2d(x=s7_conv2, output_dim=self.conv_dims[15], k_h=3, k_w=3, d_h=1, d_w=1,
170 | padding='VALID', initializer='He', name='s7_conv3', logger=self.logger)
171 | s7_conv3 = tf_utils.relu(s7_conv3, name='relu_s7_conv3', logger=self.logger)
172 |
173 | # Stage 8
174 | s8_deconv1 = tf_utils.deconv2d(x=s7_conv3, output_dim=self.conv_dims[16], k_h=2, k_w=2, initializer='He',
175 | name='s8_deconv1', logger=self.logger)
176 | s8_deconv1 = tf_utils.relu(s8_deconv1, name='relu_s8_deconv1', logger=self.logger)
177 | # Cropping
178 | h1, w1 = s2_conv2.get_shape().as_list()[1:3]
179 | h2, w2 = s8_deconv1.get_shape().as_list()[1:3]
180 | s2_conv2_crop = tf.image.crop_to_bounding_box(image=s2_conv2,
181 | offset_height=int(0.5 * (h1 - h2)),
182 | offset_width=int(0.5 * (w1 - w2)),
183 | target_height=h2,
184 | target_width=w2)
185 | tf_utils.print_activations(s2_conv2_crop, logger=self.logger)
186 |
187 | s8_concat = tf_utils.concat(values=[s2_conv2_crop, s8_deconv1], axis=3, name='s8_concat', logger=self.logger)
188 | s8_conv2 = tf_utils.conv2d(x=s8_concat, output_dim=self.conv_dims[17], k_h=3, k_w=3, d_h=1, d_w=1,
189 | padding='VALID', initializer='He', name='s8_conv2', logger=self.logger)
190 | s8_conv2 = tf_utils.relu(s8_conv2, name='relu_s8_conv2', logger=self.logger)
191 | s8_conv3 = tf_utils.conv2d(x=s8_conv2, output_dim=self.conv_dims[18], k_h=3, k_w=3, d_h=1, d_w=1,
192 | padding='VALID', initializer='He', name='s8_conv3', logger=self.logger)
193 |         s8_conv3 = tf_utils.relu(s8_conv3, name='relu_s8_conv3', logger=self.logger)
194 |
195 | # Stage 9
196 | s9_deconv1 = tf_utils.deconv2d(x=s8_conv3, output_dim=self.conv_dims[19], k_h=2, k_w=2, initializer='He',
197 | name='s9_deconv1', logger=self.logger)
198 | s9_deconv1 = tf_utils.relu(s9_deconv1, name='relu_s9_deconv1', logger=self.logger)
199 | # Cropping
200 | h1, w1 = s1_conv2.get_shape().as_list()[1:3]
201 | h2, w2 = s9_deconv1.get_shape().as_list()[1:3]
202 | s1_conv2_crop = tf.image.crop_to_bounding_box(image=s1_conv2,
203 | offset_height=int(0.5 * (h1 - h2)),
204 | offset_width=int(0.5 * (w1 - w2)),
205 | target_height=h2,
206 | target_width=w2)
207 | tf_utils.print_activations(s1_conv2_crop, logger=self.logger)
208 |
209 | s9_concat = tf_utils.concat(values=[s1_conv2_crop, s9_deconv1], axis=3, name='s9_concat', logger=self.logger)
210 | s9_conv2 = tf_utils.conv2d(x=s9_concat, output_dim=self.conv_dims[20], k_h=3, k_w=3, d_h=1, d_w=1,
211 | padding='VALID', initializer='He', name='s9_conv2', logger=self.logger)
212 | s9_conv2 = tf_utils.relu(s9_conv2, name='relu_s9_conv2', logger=self.logger)
213 | s9_conv3 = tf_utils.conv2d(x=s9_conv2, output_dim=self.conv_dims[21], k_h=3, k_w=3, d_h=1, d_w=1,
214 | padding='VALID', initializer='He', name='s9_conv3', logger=self.logger)
215 | s9_conv3 = tf_utils.relu(s9_conv3, name='relu_s9_conv3', logger=self.logger)
216 | self.pred = tf_utils.conv2d(x=s9_conv3, output_dim=self.conv_dims[22], k_h=1, k_w=1, d_h=1, d_w=1,
217 | padding='SAME', initializer='He', name='output', logger=self.logger)
218 |
219 |     def optimizer_fn(self, loss, name=None):
220 | with tf.variable_scope(name):
221 | global_step = tf.Variable(0, dtype=tf.float32, trainable=False)
222 | start_learning_rate = self.lr
223 | end_learning_rate = 0.
224 | start_decay_step = self.start_decay_step
225 | decay_steps = self.decay_steps
226 |
227 | learning_rate = (tf.where(tf.greater_equal(global_step, start_decay_step),
228 | tf.train.polynomial_decay(start_learning_rate,
229 | global_step - start_decay_step,
230 | decay_steps, end_learning_rate, power=1.0),
231 | start_learning_rate))
232 | self.tb_lr = tf.summary.scalar('learning_rate', learning_rate)
233 |
234 | learn_step = tf.train.AdamOptimizer(learning_rate, beta1=0.99).minimize(loss, global_step=global_step)
235 |
236 | return learn_step
237 |
238 | def _tensorboard(self):
239 | self.tb_total = tf.summary.scalar('Loss/total', self.total_loss)
240 | self.tb_data = tf.summary.scalar('Loss/avg_data', self.avg_data_loss)
241 | self.tb_weighted_data = tf.summary.scalar('Loss/weighted_data', self.weighted_data_loss)
242 | self.tb_reg = tf.summary.scalar('Loss/reg_term', self.reg_term)
243 | self.summary_op = tf.summary.merge(
244 | inputs=[self.tb_lr, self.tb_total, self.tb_data, self.tb_weighted_data, self.tb_reg])
245 |
246 | self.val_acc_op = tf.summary.scalar('Acc', self.val_acc)
247 |
248 | def release_handles(self):
249 | utils.release_handles(self.logger, self.file_handler, self.stream_handler)
250 |
251 |
--------------------------------------------------------------------------------
/src/preprocessing.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow U-Net Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import cv2
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 | import tifffile as tiff
12 |
13 | def main(dataset_path=None, is_write=False):
14 | # train_imgs = tiff.imread(os.path.join(dataset_path, 'train-volume.tif'))
15 | train_labels = tiff.imread(os.path.join(dataset_path, 'train-labels.tif'))
16 |
17 | save_dir = 'wmap_imgs'
18 | if not os.path.isdir(save_dir):
19 | os.makedirs(save_dir)
20 |
21 | file_name = os.path.join(dataset_path, 'train-wmaps.npy')
22 | if is_write:
23 | wmaps = np.zeros_like(train_labels, dtype=np.float32)
24 | for idx in range(train_labels.shape[0]):
25 | print('Image index: {}'.format(idx))
26 | img = train_labels[idx]
27 | cal_weight_map(label=img, wmaps=wmaps, save_dir=save_dir, iter_time=idx)
28 |
29 | np.save(file_name, wmaps)
30 |
31 | wmaps = np.load(file_name)
32 | plot_wmaps(wmaps)
33 |
34 |
35 | def plot_wmaps(wmaps, nrows=5, ncols=6, hspace=0.2, wspace=0.1, vmin=0., vmax=12., interpolation='nearest'):
36 | # Create figure with sub-plots.
37 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 10))
38 |
39 | # Adjust vertical and horizontal spacing
40 | fig.subplots_adjust(hspace=hspace, wspace=wspace)
41 |
42 | im = None
43 | for i, ax in enumerate(axes.flat):
44 | if i < (nrows * ncols):
45 | # Plot image
46 | im = ax.imshow(wmaps[i], interpolation=interpolation, cmap=plt.cm.jet, vmin=vmin, vmax=vmax)
47 |
48 | # Show the classes as the label on the x-axis.
49 | xlabel = "Weight Map: {0}".format(str(i).zfill(2))
50 | ax.set_xlabel(xlabel)
51 |
52 | # Remove ticks from the plot.
53 | ax.set_xticks([])
54 | ax.set_yticks([])
55 |
56 | # Add colorbar
57 | # fig.colorbar(im) # cmap=plt.cm.jet, vmin=vmin, vmax=vmax,
58 |
59 | # fig.subplots_adjust(right=0.8)
60 | fig.colorbar(im, ax=axes.ravel().tolist())
61 |
62 | plt.show()
63 |
64 |
65 | def cal_weight_map(label, wmaps, save_dir, iter_time, wc=1., w0=10., sigma=5, interval=500, vmin=0, vmax=12):
66 | _, contours, _ = cv2.findContours(label, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
67 | wmap = wc * np.ones_like(label, dtype=np.float32)
68 | img = label.copy()
69 |
70 | y_points, x_points = np.where(img == 0)
71 | for idx, (y_point, x_point) in enumerate(zip(y_points, x_points)):
72 | if np.mod(idx, interval) == 0:
73 | print('{} / {}'.format(idx, len(y_points)))
74 |
75 | point = np.array([x_point, y_point]).astype(np.float32)
76 | dis_arr = []
77 |
78 | for i in range(len(contours)):
79 | cnt = (np.squeeze(contours[i])).astype(np.float32)
80 | dis_arr.append(np.amin(np.sqrt(np.sum(np.power(point - cnt, 2), axis=1))))
81 |
82 |         dis_arr.sort()  # ascending: dis_arr[0:2] are the distances to the two nearest cells
83 |         wmap[y_point, x_point] += w0 * np.exp(- np.power(np.sum(dis_arr[0:2]), 2) / (2 * sigma * sigma))  # wc is already in wmap
84 |
85 | plt.imshow(wmap, cmap=plt.cm.jet, vmin=vmin, vmax=vmax)
86 | plt.axis('off')
87 |
88 | # To solve the multiple color-bar problem
89 | if iter_time == 0:
90 | plt.colorbar()
91 |
92 | plt.savefig(os.path.join(save_dir, str(iter_time).zfill(2) + '.png'), bbox_inches='tight')
93 | wmaps[iter_time] = wmap
94 |
95 |
96 | if __name__ == '__main__':
97 | main(dataset_path='../../Data/EMSegmentation', is_write=False)
98 |
99 |
100 |
--------------------------------------------------------------------------------
/src/solver.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow U-Net Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import cv2
9 | import numpy as np
10 | import tensorflow as tf
11 |
12 | import utils as utils
13 |
14 |
15 | class Solver(object):
16 | def __init__(self, sess, model, mean_value):
17 | self.sess = sess
18 | self.model = model
19 | self.mean_value = mean_value
20 |
21 | def train(self, x, y, wmap):
22 | feed = {
23 | self.model.inp_img: np.expand_dims(x, axis=3),
24 | self.model.out_img: y,
25 | self.model.weight_map: wmap,
26 | self.model.keep_prob: 0.5
27 | }
28 |
29 | train_op = self.model.train_op
30 | total_loss = self.model.total_loss
31 |         avg_data_loss = self.model.avg_data_loss
32 | weighted_data_loss = self.model.weighted_data_loss
33 | reg_term = self.model.reg_term
34 | pred_cls = self.model.pred_cls
35 | summary = self.model.summary_op
36 |
37 |         return self.sess.run([train_op, total_loss, avg_data_loss, weighted_data_loss, reg_term, summary, pred_cls],
38 | feed_dict=feed)
39 |
40 |     def evaluate(self, x, y, batch_size=4):
41 | print(' [*] Evaluation...')
42 |
43 | num_test = x.shape[0]
44 | avg_acc = 0.
45 |
46 | for i_start in range(0, num_test, batch_size):
47 | if i_start + batch_size < num_test:
48 | i_end = i_start + batch_size
49 | else:
50 |                 i_end = num_test  # clamp so the final, smaller batch still includes the last sample
51 |
52 | x_batch = x[i_start:i_end]
53 | y_batch = y[i_start:i_end]
54 |
55 | feed = {
56 | self.model.inp_img: np.expand_dims(x_batch, axis=3),
57 | self.model.out_img: y_batch,
58 | self.model.keep_prob: 1.0
59 | }
60 |
61 | acc_op = self.model.accuracy
62 | avg_acc += self.sess.run(acc_op, feed_dict=feed)
63 |
64 | avg_acc = np.float32(avg_acc / np.ceil(num_test / batch_size))
65 | summary = self.sess.run(self.model.val_acc_op, feed_dict={self.model.val_acc: avg_acc})
66 |
67 | return avg_acc, summary
68 |
69 | def test(self, x, iter_time, angle, test_dir, is_save=False):
70 | feed = {
71 | self.model.inp_img: np.expand_dims(x, axis=3),
72 | self.model.keep_prob: 1.0
73 | }
74 |
75 | preds = self.sess.run(self.model.pred, feed_dict=feed)
76 | pred = utils.merge_preds(preds, idx=iter_time, angle=angle, test_dir=test_dir, is_save=is_save)
77 |
78 | return pred
79 |
80 | def save_imgs(self, x_imgs, pred_imgs, y_imgs, iter_time, sample_dir=None, border=5):
81 | num_cols = 3
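        # Canvas columns, left to right: input image | prediction | ground-truth label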
82 | _, H1, W1 = x_imgs.shape
83 | N, H2, W2 = pred_imgs.shape
84 | margin = int(0.5 * (H1 - H2))
85 |
86 | canvas = np.zeros((N*H1+(N+1)*border, 1*W1+(num_cols-1)*W2+(num_cols+1)*border), dtype=np.uint8)
87 | for idx in range(N):
88 | canvas[(idx+1)*border+idx*H1:(idx+1)*border+(idx+1)*H1, border:border+W1] = \
89 | x_imgs[idx] + self.mean_value
90 | canvas[(idx+1)*border+idx*H1+margin:(idx+1)*border+idx*H1+margin+H2, 2*border+W1:2*border+W1+W2] = \
91 | pred_imgs[idx] * 255
92 | canvas[(idx+1)*border+idx*H1+margin:(idx+1)*border+idx*H1+margin+H2, 3*border+W1+W2:3*border+W1+2*W2] = \
93 | y_imgs[idx] * 255
94 |
95 | cv2.imwrite(os.path.join(sample_dir, str(iter_time).zfill(5) + '.png'), canvas)
96 |
97 | def save_acc_record(self, acc):
98 | self.sess.run(self.model.save_best_acc_op, feed_dict={self.model.val_acc: acc})
99 |
100 | def load_acc_record(self):
101 | best_acc = self.sess.run(self.model.best_acc_record)
102 | return best_acc
103 |
104 | def init(self):
105 | self.sess.run(tf.global_variables_initializer())
106 |
--------------------------------------------------------------------------------
/src/tensorflow_utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow Utils Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import numpy as np
8 | import tensorflow as tf
9 | from tensorflow.python.training import moving_averages
10 |
11 |
12 | def padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='pad2d'):
13 | if pad_type == 'REFLECT':
14 | return tf.pad(x, [[0, 0], [p_h, p_h], [p_w, p_w], [0, 0]], 'REFLECT', name=name)
15 |
16 |
17 | def conv2d(x, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, initializer=None, padding='SAME', name='conv2d',
18 | is_print=True, logger=None):
19 | with tf.variable_scope(name):
20 | if initializer is None:
21 | init_op = tf.truncated_normal_initializer(stddev=stddev)
22 | elif initializer == 'He':
23 | init_op = tf.initializers.he_normal()
24 | else:
25 | raise NotImplementedError
26 |
27 | w = tf.get_variable('w', [k_h, k_w, x.get_shape()[-1], output_dim], initializer=init_op)
28 | conv = tf.nn.conv2d(x, w, strides=[1, d_h, d_w, 1], padding=padding)
29 |
30 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
31 | # conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
32 | conv = tf.nn.bias_add(conv, biases)
33 |
34 | if is_print:
35 | print_activations(conv, logger)
36 |
37 | return conv
38 |
39 |
40 | def deconv2d(x, output_dim, k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, initializer=None, padding_='SAME',
41 | output_size=None, name='deconv2d', with_w=False, is_print=True, logger=None):
42 | with tf.variable_scope(name):
43 | input_shape = x.get_shape().as_list()
44 |
45 |         # calculate output size
46 |         # (by default, a stride-2 transposed convolution doubles the spatial size)
47 |         h_output, w_output = output_size if output_size else (input_shape[1] * 2, input_shape[2] * 2)
48 |         # output_shape = [input_shape[0], h_output, w_output, output_dim]
49 |         #   (fails when batch_size is not defined, hence tf.shape(x)[0] below)
50 | output_shape = [tf.shape(x)[0], h_output, w_output, output_dim]
51 |
52 | # conv2d transpose
53 | if initializer is None:
54 | init_op = tf.random_normal_initializer(stddev=stddev)
55 | elif initializer == 'He':
56 | init_op = tf.initializers.he_normal()
57 | else:
58 | raise NotImplementedError
59 | w = tf.get_variable('w', [k_h, k_w, output_dim, input_shape[3]], initializer=init_op)
60 | deconv = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=[1, d_h, d_w, 1], padding=padding_)
61 |
62 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
63 | deconv = tf.nn.bias_add(deconv, biases)
64 |
65 | if is_print:
66 | print_activations(deconv, logger)
67 |
68 | if with_w:
69 | return deconv, w, biases
70 | else:
71 | return deconv
72 |
73 |
74 | def concat(values, axis, name='concat', is_print=True, logger=None):
75 | output = tf.concat(values=values, axis=axis, name=name)
76 |
77 | if is_print:
78 | print_activations(output, logger)
79 |
80 | return output
81 |
82 |
83 | def upsampling2d(x, size=(2, 2), name='upsampling2d'):
84 | with tf.name_scope(name):
85 | shape = x.get_shape().as_list()
86 | return tf.image.resize_nearest_neighbor(x, size=(size[0] * shape[1], size[1] * shape[2]))
87 |
88 | def flatten(x, name='flatten', data_format='channels_last', is_print=True, logger=None):
89 | try:
90 | output = tf.layers.flatten(inputs=x, name=name, data_format=data_format)
91 | except(RuntimeError, TypeError, NameError):
92 | print('[*] Catch the flatten function Error!')
93 | output = tf.contrib.layers.flatten(inputs=x, scope=name)
94 |
95 | if is_print:
96 | print_activations(output, logger)
97 | return output
98 |
99 |
100 | def linear(x, output_size, bias_start=0.0, with_w=False, name='fc', is_print=True, logger=None):
101 | shape = x.get_shape().as_list()
102 |
103 | with tf.variable_scope(name):
104 | matrix = tf.get_variable(name="matrix", shape=[shape[1], output_size],
105 | dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer())
106 | bias = tf.get_variable(name="bias", shape=[output_size],
107 | initializer=tf.constant_initializer(bias_start))
108 | output = tf.matmul(x, matrix) + bias
109 |
110 | if is_print:
111 | print_activations(output, logger)
112 |
113 | if with_w:
114 | return output, matrix, bias
115 | else:
116 | return output
117 |
118 |
119 | def norm(x, name, _type, _ops, is_train=True):
120 | if _type == 'batch':
121 | return batch_norm(x, name=name, _ops=_ops, is_train=is_train)
122 | elif _type == 'instance':
123 | return instance_norm(x, name=name)
124 | else:
125 | raise NotImplementedError
126 |
127 |
128 | def batch_norm(x, name, _ops, is_train=True):
129 | """Batch normalization."""
130 | with tf.variable_scope(name):
131 | params_shape = [x.get_shape()[-1]]
132 |
133 | beta = tf.get_variable('beta', params_shape, tf.float32,
134 | initializer=tf.constant_initializer(0.0, tf.float32))
135 | gamma = tf.get_variable('gamma', params_shape, tf.float32,
136 | initializer=tf.constant_initializer(1.0, tf.float32))
137 |
138 | if is_train is True:
139 | mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments')
140 |
141 | moving_mean = tf.get_variable('moving_mean', params_shape, tf.float32,
142 | initializer=tf.constant_initializer(0.0, tf.float32),
143 | trainable=False)
144 | moving_variance = tf.get_variable('moving_variance', params_shape, tf.float32,
145 | initializer=tf.constant_initializer(1.0, tf.float32),
146 | trainable=False)
147 |
148 | _ops.append(moving_averages.assign_moving_average(moving_mean, mean, 0.9))
149 | _ops.append(moving_averages.assign_moving_average(moving_variance, variance, 0.9))
150 | else:
151 | mean = tf.get_variable('moving_mean', params_shape, tf.float32,
152 | initializer=tf.constant_initializer(0.0, tf.float32), trainable=False)
153 | variance = tf.get_variable('moving_variance', params_shape, tf.float32,
154 | initializer=tf.constant_initializer(1.0, tf.float32), trainable=False)
155 |
156 | # epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net.
157 | y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-5)
158 | y.set_shape(x.get_shape())
159 |
160 | return y
161 |
162 |
163 | def instance_norm(x, name='instance_norm', mean=1.0, stddev=0.02, epsilon=1e-5):
164 | with tf.variable_scope(name):
165 | depth = x.get_shape()[3]
166 | scale = tf.get_variable(
167 | 'scale', [depth], tf.float32,
168 | initializer=tf.random_normal_initializer(mean=mean, stddev=stddev, dtype=tf.float32))
169 | offset = tf.get_variable('offset', [depth], initializer=tf.constant_initializer(0.0))
170 |
171 |         # calculate mean and variance per instance
172 | mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
173 |
174 | # normalization
175 | inv = tf.rsqrt(variance + epsilon)
176 | normalized = (x - mean) * inv
177 |
178 | return scale * normalized + offset
179 |
180 |
181 | def n_res_blocks(x, _ops=None, norm_='instance', is_train=True, num_blocks=6, is_print=True, logger=None):
182 | output = None
183 | for idx in range(1, num_blocks+1):
184 | output = res_block(x, x.get_shape()[3], _ops=_ops, norm_=norm_, is_train=is_train,
185 | name='res{}'.format(idx))
186 | x = output
187 |
188 | if is_print:
189 | print_activations(output, logger)
190 |
191 | return output
192 |
193 |
194 | # norm(x, name, _type, _ops, is_train=True)
195 | def res_block(x, k, _ops=None, norm_='instance', is_train=True, pad_type=None, name=None):
196 | with tf.variable_scope(name):
197 | conv1, conv2 = None, None
198 |
199 | # 3x3 Conv-Batch-Relu S1
200 | with tf.variable_scope('layer1'):
201 | if pad_type is None:
202 | conv1 = conv2d(x, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv')
203 | elif pad_type == 'REFLECT':
204 | padded1 = padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='padding')
205 | conv1 = conv2d(padded1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv')
206 | normalized1 = norm(conv1, name='norm', _type=norm_, _ops=_ops, is_train=is_train)
207 | relu1 = tf.nn.relu(normalized1)
208 |
209 | # 3x3 Conv-Batch S1
210 | with tf.variable_scope('layer2'):
211 | if pad_type is None:
212 | conv2 = conv2d(relu1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv')
213 | elif pad_type == 'REFLECT':
214 | padded2 = padding2d(relu1, p_h=1, p_w=1, pad_type='REFLECT', name='padding')
215 | conv2 = conv2d(padded2, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv')
216 | normalized2 = norm(conv2, name='norm', _type=norm_, _ops=_ops, is_train=is_train)
217 |
218 | # sum layer1 and layer2
219 | output = x + normalized2
220 | return output
221 |
222 |
223 | def identity(x, name='identity', is_print=True, logger=None):
224 | output = tf.identity(x, name=name)
225 | if is_print:
226 | print_activations(output, logger)
227 |
228 | return output
229 |
230 |
231 | def max_pool(x, name='max_pool', ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], is_print=True, logger=None):
232 | output = tf.nn.max_pool(value=x, ksize=ksize, strides=strides, padding='SAME', name=name)
233 | if is_print:
234 | print_activations(output, logger)
235 | return output
236 |
237 | def dropout(x, keep_prob=0.5, seed=None, name='dropout', is_print=True, logger=None):
238 |     try:
239 |         output = tf.nn.dropout(x=x,
240 |                                rate=1. - keep_prob,  # TF >= 1.13: rate is the DROP probability
241 |                                seed=seed,
242 |                                name=name)
243 |     except(RuntimeError, TypeError, NameError):
244 |         print('[*] Catch the dropout function Error!')
245 |         output = tf.nn.dropout(x=x,
246 |                                keep_prob=keep_prob,  # older TF: keep probability
247 |                                seed=seed,
248 |                                name=name)
249 |
250 | if is_print:
251 | print_activations(output, logger)
252 |
253 | return output
254 |
255 | def sigmoid(x, name='sigmoid', is_print=True, logger=None):
256 | output = tf.nn.sigmoid(x, name=name)
257 | if is_print:
258 | print_activations(output, logger)
259 |
260 | return output
261 |
262 |
263 | def tanh(x, name='tanh', is_print=True, logger=None):
264 | output = tf.nn.tanh(x, name=name)
265 | if is_print:
266 | print_activations(output, logger)
267 |
268 | return output
269 |
270 |
271 | def relu(x, name='relu', is_print=True, logger=None):
272 | output = tf.nn.relu(x, name=name)
273 | if is_print:
274 | print_activations(output, logger)
275 |
276 | return output
277 |
278 |
279 | def lrelu(x, leak=0.2, name='lrelu', is_print=True, logger=None):
280 | output = tf.maximum(x, leak*x, name=name)
281 | if is_print:
282 | print_activations(output, logger)
283 |
284 | return output
285 |
286 |
287 | def elu(x, name='elu', is_print=True, logger=None):
288 | output = tf.nn.elu(x, name=name)
289 | if is_print:
290 | print_activations(output, logger)
291 |
292 | return output
293 |
294 |
295 | def xavier_init(in_dim):
296 | xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
297 | return xavier_stddev
298 |
299 |
300 | def print_activations(t, logger=None):
301 | if logger is None:
302 |         print(t.op.name, t.get_shape().as_list())
303 | else:
304 | logger.info(t.op.name + '{}'.format(t.get_shape().as_list()))
305 |
306 |
307 | def show_all_variables(logger=None):
308 | total_count = 0
309 |
310 | for idx, op in enumerate(tf.trainable_variables()):
311 | shape = op.get_shape()
312 | count = np.prod(shape)
313 |
314 | if logger is None:
315 | print("[%2d] %s %s = %s" % (idx, op.name, shape, count))
316 | else:
317 | logger.info("[%2d] %s %s = %s" % (idx, op.name, shape, count))
318 |
319 | total_count += int(count)
320 |
321 | if logger is None:
322 | print("[Total] variable size: %s" % "{:,}".format(total_count))
323 | else:
324 | logger.info("[Total] variable size: %s" % "{:,}".format(total_count))
325 |
326 |
327 | def batch_convert2int(images):
328 | # images: 4D float tensor (batch_size, image_size, image_size, depth)
329 | return tf.map_fn(convert2int, images, dtype=tf.uint8)
330 |
331 |
332 | def convert2int(image):
333 | # transform from float tensor ([-1.,1.]) to int image ([0,255])
334 | return tf.image.convert_image_dtype((image + 1.0) / 2.0, tf.uint8)
335 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------
2 | # Tensorflow U-Net Implementation
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # Written by Cheng-Bin Jin
5 | # Email: sbkim0407@gmail.com
6 | # ---------------------------------------------------------
7 | import os
8 | import sys
9 | import cv2
10 | import logging
11 | import elasticdeform
12 | import numpy as np
13 | from sklearn.metrics import confusion_matrix
14 | from scipy.ndimage import rotate
15 |
16 |
17 | def init_logger(log_dir, name, is_train):
18 | logger = logging.getLogger(__name__) # logger
19 | logger.setLevel(logging.INFO)
20 |
21 | file_handler, stream_handler = None, None
22 | if is_train:
23 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s')
24 |
25 | # file handler
26 | file_handler = logging.FileHandler(os.path.join(log_dir, name + '.log'))
27 | file_handler.setFormatter(formatter)
28 | file_handler.setLevel(logging.INFO)
29 |
30 | # stream handler
31 | stream_handler = logging.StreamHandler()
32 | stream_handler.setFormatter(formatter)
33 |
34 | # add handlers
35 | logger.addHandler(file_handler)
36 | logger.addHandler(stream_handler)
37 |
38 | return logger, file_handler, stream_handler
39 |
40 |
41 | def release_handles(logger, file_handler, stream_handler):
42 |     for handler in [file_handler, stream_handler]:  # handlers are None when is_train=False
43 |         if handler is not None:
44 |             handler.close()
45 |             logger.removeHandler(handler)
46 |
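# Typical lifecycle (a sketch; the directory and name below are placeholder
# values): create the logger once per run, log through it, then release the
# handlers so a later re-init does not produce duplicated log lines.
#
#   logger, fh, sh = init_logger(log_dir='logs/20190101-0000', name='unet', is_train=True)
#   logger.info('training started')
#   release_handles(logger, fh, sh)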
47 |
48 | def make_folders(is_train=True, cur_time=None):
49 | if is_train:
50 | model_dir = os.path.join('model', '{}'.format(cur_time))
51 | log_dir = os.path.join('logs', '{}'.format(cur_time))
52 | sample_dir = os.path.join('sample', '{}'.format(cur_time))
53 | test_dir = os.path.join('test', '{}'.format(cur_time))
54 |
55 | if not os.path.isdir(model_dir):
56 | os.makedirs(model_dir)
57 |
58 | if not os.path.isdir(log_dir):
59 | os.makedirs(log_dir)
60 |
61 | if not os.path.isdir(sample_dir):
62 | os.makedirs(sample_dir)
63 |
64 | else:
65 | model_dir = os.path.join('model', '{}'.format(cur_time))
66 | log_dir = os.path.join('logs', '{}'.format(cur_time))
67 | sample_dir = os.path.join('sample', '{}'.format(cur_time))
68 | test_dir = os.path.join('test', '{}'.format(cur_time))
69 |
70 | if not os.path.isdir(test_dir):
71 | os.makedirs(test_dir)
72 |
73 | return model_dir, log_dir, sample_dir, test_dir
74 |
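# Usage sketch: one call serves both phases. With is_train=True the model,
# log, and sample folders are created (test is only named); with is_train=False
# only the test folder is created. `cur_time` is a caller-chosen timestamp
# string (the value below is hypothetical):
#
#   model_dir, log_dir, sample_dir, test_dir = make_folders(is_train=True, cur_time='20190314-1200')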
75 |
76 | def imshow(img, label, wmap, idx, alpha=0.6, delay=1, log_dir=None, show=False):
77 | img_dir = os.path.join(log_dir, 'img')
78 | if not os.path.isdir(img_dir):
79 | os.makedirs(img_dir)
80 |
81 | if len(img.shape) == 2:
82 | img = np.dstack((img, img, img))
83 |
84 | # Convert to pseudo color map from gray-scale image
85 | pseudo_label = None
86 | if len(label.shape) == 2:
87 | pseudo_label = pseudoColor(label)
88 |
89 | beta = 1. - alpha
90 | overlap = cv2.addWeighted(src1=img,
91 | alpha=alpha,
92 | src2=pseudo_label,
93 | beta=beta,
94 | gamma=0.0)
95 |
96 | # Weight-map
97 | wmap_color = cv2.applyColorMap(normalize_uint8(wmap), cv2.COLORMAP_JET)
98 |
99 | canvas = np.hstack((img, pseudo_label, overlap, wmap_color))
100 | cv2.imwrite(os.path.join(img_dir, 'GT_' + str(idx).zfill(2) + '.png'), canvas)
101 |
102 | if show:
103 | cv2.imshow('Show', canvas)
104 |
105 | if cv2.waitKey(delay) & 0xFF == 27:
106 | sys.exit('Esc clicked!')
107 |
108 | def normalize_uint8(x, x_min=0, x_max=12, fit=255):
109 |     x_norm = np.uint8(fit * np.clip((x - x_min) / (x_max - x_min), 0., 1.))  # clip to avoid uint8 wrap-around
110 | return x_norm
111 |
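# Quick check of the mapping (a sketch): with the defaults x_min=0, x_max=12,
# a weight of 6 maps to about 127 (255 * 6 / 12, truncated) and 12 maps to
# 255, so typical U-Net weight maps (w0 = 10 in Eq. 2 of the paper) fill most
# of the uint8 range.
#
#   normalize_uint8(np.array([[0., 6., 12.]]))  # -> [[  0, 127, 255]] (uint8)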
112 |
113 | def pseudoColor(label, thickness=3):
114 | img = label.copy()
115 |     img, contours, hierarchy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)  # OpenCV 3.x returns 3 values
116 | img = np.dstack((img, img, img))
117 |
118 | for i in range(len(contours)):
119 | cnt = contours[i]
120 | cv2.drawContours(img, [cnt], contourIdx=-1, color=(0, 255, 0), thickness=thickness)
121 | cv2.fillPoly(img, [cnt], color=randomColors(i))
122 |
123 | return img
124 |
125 |
126 | def randomColors(idx):
127 | Sky = [128, 128, 128]
128 | Building = [128, 0, 0]
129 | Pole = [192, 192, 128]
130 | Road = [128, 64, 128]
131 | Pavement = [60, 40, 222]
132 | Tree = [128, 128, 0]
133 | SignSymbol = [192, 128, 128]
134 | Fence = [64, 64, 128]
135 | Car = [64, 0, 128]
136 | Pedestrian = [64, 64, 0]
137 | Bicyclist = [0, 128, 192]
138 | DarkRed = [0, 0, 139]
139 | PaleVioletRed = [147, 112, 219]
140 | Orange = [0, 165, 255]
141 | Teal = [128, 128, 0]
142 |
143 | color_dict = [Sky, Building, Pole, Road, Pavement,
144 | Tree, SignSymbol, Fence, Car, Pedestrian,
145 | Bicyclist, DarkRed, PaleVioletRed, Orange, Teal]
146 |
147 | return color_dict[idx % len(color_dict)]
148 |
149 |
150 | def test_augmentation(img, label, wmap, idx, margin=10, log_dir=None):
151 | img_dir = os.path.join(log_dir, 'img')
152 | if not os.path.isdir(img_dir):
153 | os.makedirs(img_dir)
154 |
155 | img_tran, label_tran, wmap_tran = aug_translate(img, label, wmap) # random translation
156 | img_flip, label_flip, wmap_flip = aug_flip(img, label, wmap) # random horizontal and vertical flip
157 | img_rota, label_rota, wmap_rota = aug_rotate(img, label, wmap) # random rotation
158 | img_defo, label_defo, wmap_defo = aug_elastic_deform(img, label, wmap) # random elastic deformation
159 | # img_pert, label_pert, wmap_pert = aug_perturbation(img, label, wmap) # random intensity perturbation
160 |
161 | # Arrange the images in a canvas and save them into the log file
162 | imgs = [img, img_tran, img_flip, img_rota, img_defo]
163 | labels = [label, label_tran, label_flip, label_rota, label_defo]
164 | wmaps = [wmap, wmap_tran, wmap_flip, wmap_rota, wmap_defo]
165 | h, w = img.shape
166 | canvas = np.zeros((3 * h + 4 * margin, len(imgs) * w + (len(imgs) + 1) * margin), dtype=np.uint8)
167 |
168 | for i, (img, label, wmap) in enumerate(zip(imgs, labels, wmaps)):
169 | canvas[1*margin:1*margin+h, (i+1) * margin + i * w:(i+1) * margin + (i + 1) * w] = img
170 | canvas[2*margin+1*h:2*margin+2*h, (i+1) * margin + i * w:(i+1) * margin + (i + 1) * w] = label
171 | canvas[3*margin+2*h:3*margin+3*h, (i+1) * margin + i * w:(i+1) * margin + (i + 1) * w] = normalize_uint8(wmap)
172 |
173 | cv2.imwrite(os.path.join(img_dir, 'augmentation_' + str(idx).zfill(2) + '.png'), canvas)
174 |
175 |
176 | def test_cropping(img, label, wmap, idx, input_size, output_size, log_dir=None,
177 | white=(255, 255, 255), blue=(255, 141, 47), red=(91, 70, 246), thickness=2, margin=10):
178 | img_dir = os.path.join(log_dir, 'img')
179 | if not os.path.isdir(img_dir):
180 | os.makedirs(img_dir)
181 |
182 | img_crop, label_crop, wmap_crop, img_pad, rand_pos_h, rand_pos_w = cropping(
183 | img, label, wmap, input_size, output_size, is_extend=True)
184 | border_size = int((input_size - output_size) * 0.5)
185 |
186 | # Convert gray images to BGR images
187 | img_pad = np.dstack((img_pad, img_pad, img_pad))
188 | label_show = np.dstack((label, label, label))
189 | wmap_show = normalize_uint8(np.dstack((wmap, wmap, wmap)))
190 |
191 | # Draw boundary lines
192 | img_pad = cv2.line(img=img_pad,
193 | pt1=(0, border_size),
194 | pt2=(img_pad.shape[1]-1, border_size),
195 | color=white,
196 | thickness=thickness)
197 | img_pad = cv2.line(img=img_pad,
198 | pt1=(0, img_pad.shape[0]-1-border_size),
199 | pt2=(img_pad.shape[1]-1, img_pad.shape[0]-1-border_size),
200 | color=white,
201 | thickness=thickness)
202 | img_pad = cv2.line(img=img_pad,
203 | pt1=(border_size, 0),
204 | pt2=(border_size, img_pad.shape[0]-1),
205 | color=white,
206 | thickness=thickness)
207 | img_pad = cv2.line(img=img_pad,
208 | pt1=(img_pad.shape[1]-1-border_size, 0),
209 | pt2=(img_pad.shape[1]-1-border_size, img_pad.shape[0]-1),
210 | color=white,
211 | thickness=thickness)
212 |
213 |     # Draw the cropped ROIs: red for the output region, blue for the input region
214 | img_pad = cv2.rectangle(img=img_pad,
215 | pt1=(rand_pos_w+border_size, rand_pos_h+border_size),
216 | pt2=(rand_pos_w+border_size+output_size, rand_pos_h+border_size+output_size),
217 | color=red,
218 | thickness=thickness+1)
219 | img_pad = cv2.rectangle(img=img_pad,
220 | pt1=(rand_pos_w, rand_pos_h),
221 | pt2=(rand_pos_w+input_size, rand_pos_h+input_size),
222 | color=blue,
223 | thickness=thickness+1)
224 | label_show = cv2.rectangle(img=label_show,
225 | pt1=(rand_pos_w, rand_pos_h),
226 | pt2=(rand_pos_w+output_size, rand_pos_h+output_size),
227 | color=red,
228 | thickness=thickness+1)
229 | wmap_show = cv2.rectangle(img=wmap_show,
230 | pt1=(rand_pos_w, rand_pos_h),
231 | pt2=(rand_pos_w+output_size, rand_pos_h+output_size),
232 | color=red,
233 | thickness=thickness+1)
234 |
235 | img_crop = cv2.rectangle(img=np.dstack((img_crop, img_crop, img_crop)),
236 | pt1=(2, 2),
237 | pt2=(img_crop.shape[1]-2, img_crop.shape[0]-2),
238 | color=blue,
239 | thickness=thickness+1)
240 | label_crop = cv2.rectangle(img=np.dstack((label_crop, label_crop, label_crop)),
241 | pt1=(2, 2),
242 | pt2=(label_crop.shape[1]-2, label_crop.shape[0]-2),
243 | color=red,
244 | thickness=thickness+1)
245 | wmap_crop = cv2.rectangle(img=normalize_uint8(np.dstack((wmap_crop, wmap_crop, wmap_crop))),
246 | pt1=(2, 2),
247 | pt2=(wmap_crop.shape[1]-2, wmap_crop.shape[0]-2),
248 | color=red,
249 | thickness=thickness+1)
250 |
251 | canvas = np.zeros((img_pad.shape[0] + label_show.shape[0] + wmap_show.shape[0] + 4 * margin,
252 | img_pad.shape[1] + img_crop.shape[1] + 3 * margin, 3), dtype=np.uint8)
253 |
254 | # Copy img_pad
255 | h_start = margin
256 | w_start = margin
257 | canvas[h_start:h_start + img_pad.shape[0], w_start:w_start+img_pad.shape[1], :] = img_pad
258 |
259 | # Copy label_show
260 | h_start = 2 * margin + img_pad.shape[0]
261 | w_start = margin
262 | canvas[h_start:h_start + label_show.shape[0], w_start:w_start + label_show.shape[1], :] = label_show
263 |
264 | # Copy wmap_show
265 | h_start = 3 * margin + img_pad.shape[0] + label_show.shape[0]
266 | w_start = margin
267 | canvas[h_start:h_start + wmap_show.shape[0], w_start:w_start + wmap_show.shape[1], :] = wmap_show
268 |
269 | # Draw connections between the left and right images
270 | # Four connections for the input images
271 | canvas = cv2.line(img=canvas,
272 | pt1=(margin+rand_pos_w, margin+rand_pos_h),
273 | pt2=(2*margin+img_pad.shape[1], margin),
274 | color=blue,
275 | thickness=thickness+1)
276 | canvas = cv2.line(img=canvas,
277 | pt1=(margin+rand_pos_w+input_size, margin+rand_pos_h),
278 | pt2=(2*margin+img_pad.shape[1]+img_crop.shape[1], margin),
279 | color=blue,
280 | thickness=thickness+1)
281 | canvas = cv2.line(img=canvas,
282 | pt1=(margin+rand_pos_w, margin+rand_pos_h+input_size),
283 | pt2=(2*margin+img_pad.shape[1], margin+input_size),
284 | color=blue,
285 | thickness=thickness+1)
286 | canvas = cv2.line(img=canvas,
287 | pt1=(margin+rand_pos_w+input_size, margin+rand_pos_h+input_size),
288 | pt2=(2*margin+img_pad.shape[1]+img_crop.shape[1], margin+input_size),
289 | color=blue,
290 | thickness=thickness+1)
291 |
292 | # Four connections for the label images
293 | canvas = cv2.line(img=canvas,
294 |                       pt1=(margin+rand_pos_w, 2*margin+img_pad.shape[0]+rand_pos_h),
295 |                       pt2=(2*margin+img_pad.shape[1], 2*margin+img_pad.shape[0]),
296 | color=red,
297 | thickness=thickness+1)
298 | canvas = cv2.line(img=canvas,
299 | pt1=(margin+output_size+rand_pos_w, 2*margin+img_pad.shape[0]+rand_pos_h),
300 | pt2=(2*margin+img_pad.shape[1]+output_size, 2*margin+img_pad.shape[0]),
301 | color=red,
302 | thickness=thickness+1)
303 | canvas = cv2.line(img=canvas,
304 | pt1=(margin+rand_pos_w, 2*margin+img_pad.shape[0]+output_size+rand_pos_h),
305 | pt2=(2*margin+img_pad.shape[1], 2*margin+img_pad.shape[0]+output_size),
306 | color=red,
307 | thickness=thickness+1)
308 | canvas = cv2.line(img=canvas,
309 |                       pt1=(margin+rand_pos_w+output_size, 2*margin+img_pad.shape[0]+output_size+rand_pos_h),
310 |                       pt2=(2*margin+img_pad.shape[1]+output_size, 2*margin+img_pad.shape[0]+output_size),
311 | color=red,
312 | thickness=thickness+1)
313 |
314 | # Four connections for the weight-map images
315 | canvas = cv2.line(img=canvas,
316 | pt1=(margin+rand_pos_w, 3*margin+img_pad.shape[0]+label_show.shape[0]+rand_pos_h),
317 | pt2=(2*margin+img_pad.shape[1], 3*margin+img_pad.shape[0]+label_show.shape[0]),
318 | color=red,
319 | thickness=thickness+1)
320 | canvas = cv2.line(img=canvas,
321 | pt1=(margin+output_size+rand_pos_w, 3*margin+img_pad.shape[0]+label_show.shape[0]+rand_pos_h),
322 | pt2=(2*margin+img_pad.shape[1]+output_size, 3*margin+img_pad.shape[0]+label_show.shape[0]),
323 | color=red,
324 | thickness=thickness+1)
325 | canvas = cv2.line(img=canvas,
326 | pt1=(margin+rand_pos_w, 3*margin+img_pad.shape[0]+label_show.shape[0]+output_size+rand_pos_h),
327 | pt2=(2*margin+img_pad.shape[1], 3*margin+img_pad.shape[0]+label_show.shape[0]+output_size),
328 | color=red,
329 | thickness=thickness+1)
330 | canvas = cv2.line(img=canvas,
331 | pt1=(margin+rand_pos_w+output_size, 3*margin+img_pad.shape[0]+label_show.shape[0]+output_size+rand_pos_h),
332 | pt2=(2*margin+img_pad.shape[1]+output_size, 3*margin+img_pad.shape[0]+label_show.shape[0]+output_size),
333 | color=red,
334 | thickness=thickness+1)
335 |
336 | # Copy img_crop
337 | h_start = margin
338 | w_start = 2 * margin + img_pad.shape[1]
339 | canvas[h_start:h_start + img_crop.shape[0], w_start:w_start + img_crop.shape[1], :] = img_crop
340 |
341 | # Copy label_crop
342 | h_start = 2 * margin + img_pad.shape[0]
343 | w_start = 2 * margin + img_pad.shape[1]
344 | canvas[h_start:h_start + label_crop.shape[0], w_start:w_start + label_crop.shape[1], :] = label_crop
345 |
346 | # Copy wmap_crop
347 | h_start = 3 * margin + img_pad.shape[0] + label_show.shape[0]
348 | w_start = 2 * margin + img_pad.shape[1]
349 | canvas[h_start:h_start + wmap_crop.shape[0], w_start:w_start + wmap_crop.shape[1], :] = wmap_crop
350 |
351 | cv2.imwrite(os.path.join(img_dir, 'crop_' + str(idx).zfill(2) + '.png'), canvas)
352 |
353 |
354 | def aug_translate(img, label, wmap, max_factor=1.2):
355 | assert len(img.shape) == 2 and len(label.shape) == 2 and len(wmap.shape) == 2
356 |
357 |     # Resize the original image
358 | resize_factor = np.random.uniform(low=1., high=max_factor)
359 | img_bigger = cv2.resize(src=img.copy(), dsize=None, fx=resize_factor, fy=resize_factor,
360 | interpolation=cv2.INTER_LINEAR)
361 | label_bigger = cv2.resize(src=label.copy(), dsize=None, fx=resize_factor, fy=resize_factor,
362 | interpolation=cv2.INTER_NEAREST)
363 | wmap_bigger = cv2.resize(src=wmap.copy(), dsize=None, fx=resize_factor, fy=resize_factor,
364 | interpolation=cv2.INTER_NEAREST)
365 |
366 | # Generate random positions for horizontal and vertical axes
367 | h_bigger, w_bigger = img_bigger.shape
368 |     h_start = np.random.randint(low=0, high=h_bigger - img.shape[0] + 1)
369 |     w_start = np.random.randint(low=0, high=w_bigger - img.shape[1] + 1)
370 |
371 |     # Crop a patch of the original size from the bigger one (rows first, then columns)
372 |     img_crop = img_bigger[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1]]
373 |     label_crop = label_bigger[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1]]
374 |     wmap_crop = wmap_bigger[h_start:h_start+img.shape[0], w_start:w_start+img.shape[1]]
375 |
376 | return img_crop, label_crop, wmap_crop
377 |
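# Shape check (a sketch): the outputs keep the original spatial size; only the
# field of view shifts and zooms, so downstream cropping is unaffected.
#
#   img = np.zeros((512, 512), dtype=np.uint8)
#   out = aug_translate(img, img.copy(), img.astype(np.float32))
#   assert all(o.shape == (512, 512) for o in out)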
378 |
379 | def aug_flip(img, label, wmap):
380 | assert len(img.shape) == 2 and len(label.shape) == 2 and len(wmap.shape) == 2
381 |
382 |     # Random horizontal flip (cv2 flipCode=1 flips around the vertical axis)
383 |     if np.random.uniform(low=0., high=1.) > 0.5:
384 |         img_hflip = cv2.flip(src=img, flipCode=1)
385 |         label_hflip = cv2.flip(src=label, flipCode=1)
386 |         wmap_hflip = cv2.flip(src=wmap, flipCode=1)
387 | else:
388 | img_hflip = img.copy()
389 | label_hflip = label.copy()
390 | wmap_hflip = wmap.copy()
391 |
392 |     # Random vertical flip (cv2 flipCode=0 flips around the horizontal axis)
393 |     if np.random.uniform(low=0., high=1.) > 0.5:
394 |         img_vflip = cv2.flip(src=img_hflip, flipCode=0)
395 |         label_vflip = cv2.flip(src=label_hflip, flipCode=0)
396 |         wmap_vflip = cv2.flip(src=wmap_hflip, flipCode=0)
397 | else:
398 | img_vflip = img_hflip.copy()
399 | label_vflip = label_hflip.copy()
400 | wmap_vflip = wmap_hflip.copy()
401 |
402 | return img_vflip, label_vflip, wmap_vflip
403 |
404 | def test_rotate(img, iter_time, test_dir=None, margin=5, start=0, stop=360, num=7):
405 | h, w = img.shape
406 | canvas = np.zeros((2*margin+h, num*w+(num+1)*margin), dtype=np.uint8)
407 |
408 |     # Rotate the test image in 360/num-degree steps (about 51.4 degrees for num=7)
409 | for i, angle in enumerate(np.linspace(start=start, stop=stop, num=num, endpoint=False)):
410 | # print('angle: {:.3f}'.format(angle))
411 | img_rotate = rotate(input=img, angle=angle, axes=(0, 1), reshape=False, order=3, mode='reflect')
412 | canvas[margin:margin+h, (i+1)*margin+i*w:(i+1)*margin+(i+1)*w] = img_rotate
413 |
414 | cv2.imwrite(os.path.join(test_dir, 'Rotate_' + str(iter_time).zfill(2) + '.png'), canvas)
415 |
416 |
417 | def aug_rotate(img, label, wmap):
418 |     assert len(img.shape) == 2 and len(label.shape) == 2 and len(wmap.shape) == 2
419 |
420 | # Random rotate image
421 | angle = np.random.randint(low=0, high=360, size=None)
422 | img_rotate = rotate(input=img, angle=angle, axes=(0, 1), reshape=False, order=3, mode='reflect')
423 | label_rotate = rotate(input=label, angle=angle, axes=(0, 1), reshape=False, order=0, mode='reflect')
424 | wmap_rotate = rotate(input=wmap, angle=angle, axes=(0, 1), reshape=False, order=0, mode='reflect')
425 |
426 | # Correct label map
427 | ret, label_rotate = cv2.threshold(src=label_rotate, thresh=127.5, maxval=255, type=cv2.THRESH_BINARY)
428 |
429 | return img_rotate, label_rotate, wmap_rotate
430 |
431 |
432 | def aug_elastic_deform(img, label, wmap):
433 | assert len(img.shape) == 2 and len(label.shape) == 2 and len(wmap.shape) == 2
434 |
435 | # Apply deformation with a random 3 x 3 grid to inputs X and Y,
436 | # with a different interpolation for each input
437 | img_distort, label_distort, wmap_distort = elasticdeform.deform_random_grid(X=[img, label, wmap],
438 | sigma=10,
439 | points=3,
440 | order=[2, 0, 0],
441 | mode='mirror')
442 |
443 | return img_distort, label_distort, wmap_distort
444 |
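# Note on the orders above: order=2 (spline) interpolates the gray image
# smoothly, while order=0 (nearest) keeps the label binary and the weight map
# un-smeared. Minimal call (a sketch):
#
#   img_d, label_d, wmap_d = aug_elastic_deform(img, label, wmap)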
445 |
446 | def aug_perturbation(img, label, wmap, low=0.8, high=1.2):
447 | pertur_map = np.random.uniform(low=low, high=high, size=img.shape)
448 |     img_en = np.round(img * pertur_map)
449 |     img_en = np.clip(img_en, a_min=0, a_max=255).astype(np.uint8)  # clip before the uint8 cast to avoid wrap-around
450 | return img_en, label, wmap
451 |
452 |
453 | def test_imshow(img, idx, input_size=572, output_size=388, test_dir=None, margin=5, white=(255, 255, 255), thickness=2):
454 | border_size = int((input_size - output_size) * 0.5)
455 | img_pad = cv2.copyMakeBorder(img, border_size, border_size, border_size, border_size, cv2.BORDER_REFLECT_101)
456 |
457 | # Draw boundary lines
458 | img_pad = cv2.line(img=img_pad,
459 | pt1=(0, border_size),
460 | pt2=(img_pad.shape[1] - 1, border_size),
461 | color=white,
462 | thickness=thickness)
463 | img_pad = cv2.line(img=img_pad,
464 | pt1=(0, img_pad.shape[0] - 1 - border_size),
465 | pt2=(img_pad.shape[1] - 1, img_pad.shape[0] - 1 - border_size),
466 | color=white,
467 | thickness=thickness)
468 | img_pad = cv2.line(img=img_pad,
469 | pt1=(border_size, 0),
470 | pt2=(border_size, img_pad.shape[0] - 1),
471 | color=white,
472 | thickness=thickness)
473 | img_pad = cv2.line(img=img_pad,
474 | pt1=(img_pad.shape[1] - 1 - border_size, 0),
475 | pt2=(img_pad.shape[1] - 1 - border_size, img_pad.shape[0] - 1),
476 | color=white,
477 | thickness=thickness)
478 |
479 | # Crop 4 corners from the padded image
480 | img0 = img_pad[:input_size, :input_size]
481 | img1 = img_pad[:input_size, -input_size:]
482 | img2 = img_pad[-input_size:, :input_size]
483 | img3 = img_pad[-input_size:, -input_size:]
484 |
485 |     canvas = np.zeros((2*input_size+3*margin, 2*input_size+img_pad.shape[1]+4*margin), dtype=np.uint8)
486 | canvas[margin:margin+img_pad.shape[0], margin:margin+img_pad.shape[1]] = img_pad
487 | canvas[margin:margin+input_size, 2*margin+img_pad.shape[1]:2*margin+img_pad.shape[1]+input_size] = img0
488 | canvas[margin:margin+input_size, 3*margin+img_pad.shape[1]+input_size:3*margin+img_pad.shape[1]+2*input_size] = img1
489 | canvas[2*margin+input_size:2*margin+2*input_size, 2*margin+img_pad.shape[1]:2*margin+img_pad.shape[1]+input_size] = img2
490 | canvas[2*margin+input_size:2*margin+2*input_size, 3*margin+img_pad.shape[1]+input_size:3*margin+img_pad.shape[1]+2*input_size] = img3
491 |
492 | cv2.imwrite(os.path.join(test_dir, 'GT_' + str(idx).zfill(2) + '.png'), canvas)
493 |
494 | def test_data_cropping(img, input_size, output_size, num_blocks=4):
495 | x_batchs = np.zeros((num_blocks, input_size, input_size), dtype=np.float32)
496 |
497 | border_size = int((input_size - output_size) * 0.5)
498 | img_pad = cv2.copyMakeBorder(img, border_size, border_size, border_size, border_size, cv2.BORDER_REFLECT_101)
499 |
500 | # Crop 4 corners from the padded image
501 | x_batchs[0] = img_pad[:input_size, :input_size]
502 | x_batchs[1] = img_pad[:input_size, -input_size:]
503 | x_batchs[2] = img_pad[-input_size:, :input_size]
504 | x_batchs[3] = img_pad[-input_size:, -input_size:]
505 | return x_batchs
506 |
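# Geometry sketch: with input_size=572 and output_size=388 the mirror border
# is (572 - 388) / 2 = 92 px, so a 512x512 EM slice becomes 696x696 after
# padding and the four 572x572 corner crops cover it with overlap.
#
#   batch = test_data_cropping(img, input_size=572, output_size=388)
#   # batch.shape == (4, 572, 572)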
507 |
508 | def cropping(img, label, wmap, input_size, output_size, is_extend=False):
509 | border_size = int((input_size - output_size) * 0.5)
510 | rand_pos_h = np.random.randint(low=0, high=img.shape[0] - output_size)
511 | rand_pos_w = np.random.randint(low=0, high=img.shape[1] - output_size)
512 |
513 | img_pad = cv2.copyMakeBorder(img, border_size, border_size, border_size, border_size, cv2.BORDER_REFLECT_101)
514 | img_crop = img_pad[rand_pos_h:rand_pos_h+input_size, rand_pos_w:rand_pos_w+input_size].copy()
515 | label_crop = label[rand_pos_h:rand_pos_h+output_size, rand_pos_w:rand_pos_w+output_size].copy()
516 | wmap_crop = wmap[rand_pos_h:rand_pos_h+output_size, rand_pos_w:rand_pos_w+output_size].copy()
517 |
518 | if is_extend:
519 | return img_crop, label_crop, wmap_crop, img_pad, rand_pos_h, rand_pos_w
520 | else:
521 | return img_crop, label_crop, wmap_crop
522 |
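# Training-crop geometry (a sketch): img_crop is an input_size window on the
# mirror-padded image, while label_crop and wmap_crop are the matching
# output_size windows on the unpadded label/weight map. Sharing
# (rand_pos_h, rand_pos_w) keeps all three aligned, because the padding offset
# equals border_size on each side.
#
#   x, y, w = cropping(img, label, wmap, input_size=572, output_size=388)
#   # x.shape == (572, 572); y.shape == w.shape == (388, 388)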
523 |
524 | def merge_rotated_preds(preds, img, iter_time, start, stop, num, test_dir, margin=5, alpha=0.6, is_save=False):
525 | inv_preds = np.zeros((num, *preds[0].shape), dtype=np.float32)
526 |
527 | for i, angle in enumerate(np.linspace(start=start, stop=stop, num=num, endpoint=False)):
528 | pred = preds[i].copy()
529 | inv_angle = 360. - angle
530 | inv_preds[i] = rotate(input=pred, angle=inv_angle, axes=(0, 1), reshape=False, order=0, mode='constant', cval=0.)
531 |
532 | y_pred = np.zeros_like(inv_preds[0])
533 | for i in range(num):
534 | y_pred += inv_preds[i]
535 |
536 | y_pred_cls = np.uint8(np.argmax(y_pred, axis=2) * 255.)
537 |
538 | if is_save:
539 | h, w = preds[0].shape[0], preds[0].shape[1]
540 | canvas = np.zeros((3*margin+2*h, (num+2)*margin+(num+1)*w), dtype=np.uint8)
541 |
542 | for i in range(num):
543 |             canvas[margin:margin+h, (i+1)*margin+i*w:(i+1)*margin+(i+1)*w] = np.uint8(np.argmax(preds[i], axis=2) * 255.)
544 |             canvas[2*margin+h:2*margin+2*h, (i+1)*margin+i*w:(i+1)*margin+(i+1)*w] = np.uint8(np.argmax(inv_preds[i], axis=2) * 255.)
545 |
546 | canvas[2 * margin + h:2 * margin + 2 * h, (num+1) * margin + num * w:(num+1) * margin + (num + 1) * w] = y_pred_cls
547 | cv2.imwrite(os.path.join(test_dir, 'Merge_' + str(iter_time).zfill(2) + '.png'), canvas)
548 |
549 |     # pseudo-color representation
550 | pseudo_label = pseudoColor(y_pred_cls)
551 | beta = 1. - alpha
552 |
553 | img_3c = np.dstack((img, img, img))
554 | overlap = cv2.addWeighted(src1=img_3c,
555 | alpha=alpha,
556 | src2=pseudo_label,
557 | beta=beta,
558 | gamma=0.0)
559 |
560 | canvas01 = np.hstack((img_3c, overlap))
561 | cv2.imwrite(os.path.join(test_dir, 'Pred1_' + str(iter_time).zfill(2) + '.png'), canvas01)
562 |
563 | canvas02 = np.hstack((img, y_pred_cls))
564 | cv2.imwrite(os.path.join(test_dir, 'Pred2_' + str(iter_time).zfill(2) + '.png'), canvas02)
565 |
566 | return y_pred_cls
567 |
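# Test-paradigm sketch: `preds` holds score maps for `num` rotated copies of
# one test image (num=7 here, i.e. 360/7 ~= 51.4-degree steps). Each map is
# rotated back, the maps are summed, and argmax over the two channels yields
# the fused binary mask. Hypothetical call:
#
#   mask = merge_rotated_preds(preds, img, iter_time=0, start=0, stop=360, num=7, test_dir='test')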
568 |
569 | def merge_preds(preds, idx, ori_size=512, output_size=388, angle=0., test_dir=None, is_save=False, margin=5):
570 | result = np.zeros((ori_size, ori_size, 2), dtype=np.float32)
571 |
572 |     # Paste the four corner predictions
573 | result[:output_size, :output_size] += preds[0]
574 | result[:output_size, -output_size:] += preds[1]
575 | result[-output_size:, :output_size] += preds[2]
576 | result[-output_size:, -output_size:] += preds[3]
577 |
578 | if is_save:
579 | # Score to class
580 | result_cls = np.argmax(result, axis=2)
581 |
582 | canvas = np.zeros((2*output_size+3*margin, 2*output_size+ori_size+4*margin), dtype=np.uint8)
583 |
584 | # Score to prediction class
585 | preds_cls = np.argmax(preds, axis=3)
586 |
587 | # Copy five results
588 | canvas[margin:margin+output_size, margin:margin+output_size] = np.uint8(preds_cls[0] * 255.)
589 | canvas[margin:margin+output_size, 2*margin+output_size:2*margin+2*output_size] = np.uint8(preds_cls[1] * 255.)
590 | canvas[2*margin+output_size:2*margin+2*output_size, margin:margin+output_size] = np.uint8(preds_cls[2] * 255.)
591 | canvas[2*margin+output_size:2*margin+2*output_size, 2*margin+output_size:2*margin+2*output_size] = np.uint8(preds_cls[3] * 255.)
592 | canvas[margin:margin+ori_size, 3*margin+2*output_size:3*margin+2*output_size+ori_size] = np.uint8(result_cls * 255.)
593 |
594 | # Save results
595 | cv2.imwrite(os.path.join(test_dir, 'Merge_' + str(idx).zfill(2) + '_' + str(int(angle)).zfill(3) + '.png'), canvas)
596 |
597 | return result
598 |
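# Note on the fusion above: the four output_size score maps are summed into
# the ori_size canvas, so pixels covered by several crops accumulate more
# evidence before the argmax; the decision is effectively a vote over the
# overlapping predictions. A minimal call (a sketch):
#
#   preds = np.random.rand(4, 388, 388, 2).astype(np.float32)
#   scores = merge_preds(preds, idx=0)  # (512, 512, 2) summed score map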
599 |
600 | def pre_bilaterFilter(img, d=3, sigmaColor=75, sigmaSpace=75):
601 |     pre_img = cv2.bilateralFilter(src=img, d=d, sigmaColor=sigmaColor, sigmaSpace=sigmaSpace)
602 | return pre_img
603 |
604 |
605 | def acc_measure(true_arr, pred_arr):
606 | cm = confusion_matrix(true_arr, pred_arr)
607 | acc = 1. * (cm[0, 0] + cm[1, 1]) / np.sum(cm)
608 | return acc
609 |
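# Worked example (a sketch): flatten the binary masks first, since
# confusion_matrix expects 1-D label arrays; the result is pixel accuracy,
# (TP + TN) / total.
#
#   y_true = (gt_mask.ravel() > 127).astype(np.uint8)
#   y_pred = (pred_mask.ravel() > 127).astype(np.uint8)
#   acc = acc_measure(y_true, y_pred)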
--------------------------------------------------------------------------------