├── .gitignore ├── README.md └── src ├── dataset.py ├── main.py ├── model.py ├── preprocessing.py ├── solver.py ├── tensorflow_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | github_imgs/ 2 | src/.idea 3 | src/__pycache__ 4 | src/logs 5 | src/model 6 | src/wmap_imgs 7 | src/sample 8 | src/test 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # U-Net-TensorFlow 2 | This repository is a TensorFlow implementation of ["U-Net: Convolutional Networks for Biomedical Image Segmentation," MICCAI 2015](https://arxiv.org/pdf/1505.04597.pdf). It closely follows the original U-Net paper. 3 | 4 |
5 | 7 | 8 | ## EM Segmentation Challenge Dataset 9 |
10 | 12 | 13 | ## Requirements 14 | - tensorflow 1.13.1 15 | - python 3.5.3 16 | - numpy 1.15.2 17 | - scipy 1.1.0 18 | - tifffile 2019.3.18 19 | - opencv 3.3.1 20 | - matplotlib 2.2.2 21 | - elasticdeform 0.4.4 22 | - scikit-learn 0.20.0 23 | 24 | ## Implementation Details 25 | This implementation follows the original U-Net paper in the following aspects: 26 | - Input image size of 572 x 572 x 1 and output label map of 388 x 388 x 2 27 | - Upsampling uses fractionally strided convolution (deconvolution) 28 | - Reflection (mirror) padding is applied to the input image 29 | - Data augmentation: random translation, random horizontal and vertical flips, random rotation, and random elastic deformation 30 | - The loss function combines a weighted cross-entropy loss and a regularization term 31 | - The weight map is calculated using equation 2 of the original paper 32 | - In the test stage, the prediction is averaged over 7 rotated versions of the input data 33 | 34 | ## Examples of the Data Augmentation 35 | - Random Translation 36 | - Random Horizontal and Vertical Flip 37 | - Random Rotation 38 | - Random Elastic Deformation 39 | 40 |
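The augmentations listed above must be applied identically to the image, its label, and its weight map, otherwise the three stop lining up. The sketch below shows this for the flip and rotation steps only. It is not the repository's `utils.aug_flip` or `utils.aug_rotate` (their code is not shown here); the helper name `augment_triplet`, the angle range, and the interpolation orders are assumptions. Translation and elastic deformation (presumably done with the `elasticdeform` package from the requirements) are left out. `np.random.RandomState` is used because it also exists in the numpy 1.15 listed above.

```python
import numpy as np
from scipy.ndimage import rotate


def augment_triplet(img, label, wmap, rng, max_angle=180.0):
    """Apply the same random flips and rotation to an image, its label, and its weight map.

    Hypothetical helper for illustration; not the repository's utils.aug_* functions.
    """
    if rng.rand() < 0.5:  # horizontal flip
        img, label, wmap = img[:, ::-1], label[:, ::-1], wmap[:, ::-1]
    if rng.rand() < 0.5:  # vertical flip
        img, label, wmap = img[::-1, :], label[::-1, :], wmap[::-1, :]

    angle = rng.uniform(-max_angle, max_angle)
    # Cubic interpolation for the image, nearest neighbour for the label (keeps it binary),
    # linear for the weight map; 'reflect' mirrors the borders instead of filling with zeros.
    img = rotate(img, angle, reshape=False, order=3, mode='reflect')
    label = rotate(label, angle, reshape=False, order=0, mode='reflect')
    wmap = rotate(wmap, angle, reshape=False, order=1, mode='reflect')
    return img, label, wmap


rng = np.random.RandomState(0)
img = rng.rand(64, 64)
label = (img > 0.5).astype(np.uint8) * 255
x, y, w = augment_triplet(img, label, np.ones((64, 64)), rng)
print(x.shape, np.unique(y))
```

Nearest-neighbour interpolation (`order=0`) is the key design choice for the label: any higher order would create intermediate values between 0 and 255 at cell boundaries.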
41 | 43 | 44 |
45 | 47 | 48 | ## Fundamentals of Different-Sized Input and Output Images in Training 49 | - Reflection mirror padding is applied first (white lines indicate the boundaries of the image) 50 | - The input image, label image, and weight map are then randomly cropped 51 | - During training, the blue rectangle of the input image is fed to U-Net, the red rectangle of the weight map weights the loss, and the red rectangle of the label image is the network's ground truth. 52 | 53 |
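The 572 to 388 size relation, and hence the 92-pixel border, follows from using only unpadded 3x3 convolutions. The sketch below, assuming the standard four-level U-Net and a 512 x 512 ISBI image, first computes the output size and then performs a reflection-padded random crop in which the 388 x 388 label and weight-map windows line up with the center of the 572 x 572 input window. `unet_output_size` and `mirror_pad_and_crop` are hypothetical names; the repository's own `utils.cropping` is not shown in this excerpt.

```python
import numpy as np


def unet_output_size(size, depth=4):
    """Spatial output size of a U-Net that uses only unpadded (valid) 3x3 convolutions."""
    for _ in range(depth):   # contracting path: two 3x3 convs (-4), then 2x2 max-pool (/2)
        size = (size - 4) // 2
    size -= 4                # bottleneck: two 3x3 convs
    for _ in range(depth):   # expansive path: 2x2 up-conv (x2), then two 3x3 convs (-4)
        size = size * 2 - 4
    return size


def mirror_pad_and_crop(img, label, wmap, input_size=572, output_size=388, rng=None):
    """Reflect-pad the image, then take a random input crop aligned with its label crop."""
    rng = np.random.RandomState() if rng is None else rng
    border = (input_size - output_size) // 2        # 92 pixels on each side
    padded = np.pad(img, border, mode='reflect')     # 512 x 512 -> 696 x 696
    h, w = label.shape
    top = rng.randint(0, h - output_size + 1)
    left = rng.randint(0, w - output_size + 1)
    # Padded row/column p corresponds to original p - border, so the center 388 x 388 of this
    # 572 x 572 window covers original rows top..top+387 and columns left..left+387.
    x = padded[top:top + input_size, left:left + input_size]
    y = label[top:top + output_size, left:left + output_size]
    wm = wmap[top:top + output_size, left:left + output_size]
    return x, y, wm


print(unet_output_size(572))  # 388
```

Because the label window always lies inside the original image, only the network input ever sees mirrored pixels; the ground truth and weights come from real data.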
54 | 56 | 57 | ## Test Paradigm 58 | - In the test stage, each test image is rotated to 7 equally spaced angles (multiples of 360/7 degrees starting at 0, as set in `main.py`), and the final prediction is the average of the 7 predicted results. 59 | 60 |
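A minimal sketch of this rotation-averaging scheme follows. It assumes a hypothetical `predict_fn` that maps an image to per-pixel class scores of the same height and width, and it rotates each prediction back before averaging; how the repository's `utils.merge_rotated_preds` aligns the predictions is not shown in this excerpt. The angles and the `scipy.ndimage.rotate` arguments mirror the calls in `main.py` and `Dataset.test_batch`.

```python
import numpy as np
from scipy.ndimage import rotate


def rotation_tta(img, predict_fn, num=7):
    """Average class scores over `num` equally spaced rotations of `img` (sketch)."""
    angles = np.linspace(start=0, stop=360, num=num, endpoint=False)
    preds = []
    for angle in angles:
        rotated = rotate(img, angle, axes=(0, 1), reshape=False, order=3, mode='reflect')
        scores = predict_fn(rotated)  # [H, W, C]
        # Rotate the scores back so that every prediction is aligned with the original image
        preds.append(rotate(scores, -angle, axes=(0, 1), reshape=False, order=1, mode='reflect'))
    return np.mean(preds, axis=0)


def toy_predict(x):
    """Toy predictor: class-1 score is the pixel intensity, class-0 score is its complement."""
    return np.stack([1.0 - x, x], axis=-1)


avg = rotation_tta(np.full((64, 64), 0.3), toy_predict)
print(avg.shape)  # (64, 64, 2)
```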
61 | 63 | 64 |
65 | 67 | 68 | - For each rotated image, four regions (top left, top right, bottom left, and bottom right) are extracted and passed through U-Net, and the prediction is computed by averaging the scores where the four results overlap. 69 | 70 |
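The four-corner scheme can be sketched as follows, assuming a hypothetical `predict_fn` that maps a 572 x 572 input tile to 388 x 388 x 2 scores, and an image whose sides are between 388 and 776 pixels (512 x 512 for the ISBI data), so that four corner tiles cover every pixel. Each tile is cut from the reflection-padded image, its scores are accumulated into a full-size map, and overlapping pixels are averaged by dividing by a per-pixel count. This is an illustration, not the repository's `utils.test_data_cropping`.

```python
import numpy as np


def predict_four_corners(img, predict_fn, input_size=572, output_size=388, num_classes=2):
    """Tile an image with four overlapping corner crops and average the overlapping scores."""
    border = (input_size - output_size) // 2
    padded = np.pad(img, border, mode='reflect')
    h, w = img.shape
    scores = np.zeros((h, w, num_classes))
    counts = np.zeros((h, w, 1))
    for top in (0, h - output_size):            # top and bottom rows of tiles
        for left in (0, w - output_size):       # left and right columns of tiles
            tile = padded[top:top + input_size, left:left + input_size]
            scores[top:top + output_size, left:left + output_size] += predict_fn(tile)
            counts[top:top + output_size, left:left + output_size] += 1
    # Valid while output_size <= h, w <= 2 * output_size: every pixel is covered at least once
    return scores / counts


def center_predict(tile, border=92):
    """Toy predictor: uses the center 388 x 388 crop of the tile as the class-1 score."""
    center = tile[border:-border, border:-border]
    return np.stack([1.0 - center, center], axis=-1)


img = np.random.RandomState(0).rand(512, 512)
out = predict_four_corners(img, center_predict)
print(out.shape)  # (512, 512, 2)
```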
71 | 73 | 74 | **Note**: White lines indicate boundaries of the image. 75 | 76 |
77 | 79 | 80 | **Note:** Prediction results on the EM Segmentation Challenge test dataset. 81 |
82 | 84 | 85 | ## Download Dataset 86 | Download the EM Segmentation Challenge dataset from the [ISBI challenge homepage](http://brainiac2.mit.edu/isbi_challenge/). 87 | 88 | ## Documentation 89 | ### Directory Hierarchy 90 | ``` 91 | . 92 | │ U-Net 93 | │ ├── src 94 | │ │ ├── dataset.py 95 | │ │ ├── main.py 96 | │ │ ├── model.py 97 | │ │ ├── preprocessing.py 98 | │ │ ├── solver.py 99 | │ │ ├── tensorflow_utils.py 100 | │ │ └── utils.py 101 | │ Data 102 | │ └── EMSegmentation 103 | │ │ ├── test-volume.tif 104 | │ │ ├── train-labels.tif 105 | │ │ ├── train-wmaps.npy (generated in preprocessing) 106 | │ │ └── train-volume.tif 107 | ``` 108 | ### Preprocessing 109 | Weight maps must first be calculated from the segmentation labels of the training data. Computing them on the fly during training would slow training down, so they are calculated and saved once in advance, and the saved weight maps are then augmented together with the input and label images. Use `preprocessing.py` to calculate the weight maps.
Example usage: 110 | ``` 111 | python preprocessing.py 112 | ``` 113 | 114 | ### Training U-Net 115 | Use `main.py` to train U-Net. Example usage: 116 | ``` 117 | python main.py 118 | ``` 119 | - `gpu_index`: GPU index if you have multiple GPUs, default: `0` 120 | - `dataset`: dataset name, default: `EMSegmentation` 121 | - `batch_size`: batch size for one iteration, default: `4` 122 | - `is_train`: training or inference (test) mode, default: `True` (training mode) 123 | - `learning_rate`: initial learning rate for the optimizer, default: `1e-3` 124 | - `weight_decay`: coefficient of the L2 regularization term, used to reduce overfitting, default: `1e-4` 125 | - `iters`: number of iterations, default: `20,000` 126 | - `print_freq`: print frequency for loss information, default: `10` 127 | - `sample_freq`: sampling frequency for qualitative evaluation, default: `100` 128 | - `eval_freq`: evaluation frequency for batch accuracy, default: `200` 129 | - `load_model`: folder of a saved model from which to continue training (e.g., `20190524-1606`), default: `None` 130 | 131 | ### Test U-Net 132 | Use `main.py` to test the models. Example usage: 133 | ``` 134 | python main.py --is_train=False --load_model=folder/you/wish/to/test/e.g./20190524-1606 135 | ``` 136 | The arguments are the same as those listed above. 137 | 138 | ### Tensorboard Visualization 139 | **Note:** The following figure shows the data loss, weighted data loss, regularization term, and total loss during training. Batch accuracy is also logged to TensorBoard. 140 | 141 |
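The logged losses can be reproduced outside the graph for sanity checks. The NumPy sketch below builds a weight map in the spirit of equation 2 of the U-Net paper, w(x) = w_c(x) + w0 · exp(-(d1(x) + d2(x))² / (2σ²)), using the paper's w0 = 10 and σ = 5 pixels, and then evaluates the loss terms the way `model.py` defines them: the plain mean of the per-pixel softmax cross-entropy, the mean of the weight map times that cross-entropy, and `weight_decay` times the sum of `tf.nn.l2_loss` values (each sum(W²)/2). Assumptions, not taken from `preprocessing.py`: cells are the non-zero label pixels, w_c is a simple inverse-class-frequency balance, the separation term is applied to background pixels only, and both function names are hypothetical. Stacking one distance map per cell is memory-hungry for images with many cells.

```python
import numpy as np
from scipy import ndimage


def weight_map(label, w0=10.0, sigma=5.0):
    """Sketch of eq. 2: w(x) = w_c(x) + w0 * exp(-(d1(x) + d2(x))**2 / (2 * sigma**2))."""
    fg = label > 0                                   # assumption: cells > 0, membranes == 0
    n_fg, n_bg = max(fg.sum(), 1), max((~fg).sum(), 1)
    wc = np.where(fg, fg.size / (2.0 * n_fg), fg.size / (2.0 * n_bg))  # inverse class frequency
    cells, n = ndimage.label(fg)                     # connected components = individual cells
    if n < 2:
        return wc
    # Euclidean distance from every pixel to each cell, shape [n, H, W]
    dists = np.stack([ndimage.distance_transform_edt(cells != i) for i in range(1, n + 1)])
    dists.sort(axis=0)
    d1, d2 = dists[0], dists[1]                      # nearest and second-nearest cell
    sep = w0 * np.exp(-((d1 + d2) ** 2) / (2.0 * sigma ** 2))
    return wc + sep * (~fg)                          # separation term on background pixels only


def unet_loss(logits, labels, wmap, weights=(), weight_decay=1e-4):
    """Return (avg data loss, weighted data loss, reg. term, total loss) as defined in model.py."""
    z = logits - logits.max(axis=-1, keepdims=True)   # numerically stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    one_hot = np.eye(logits.shape[-1])[labels]         # labels hold class indices 0/1
    data_loss = -(one_hot * log_probs).sum(axis=-1)    # per-pixel cross-entropy
    avg_data = data_loss.mean()
    weighted = (wmap * data_loss).mean()
    reg = weight_decay * sum(0.5 * np.sum(w ** 2) for w in weights)  # tf.nn.l2_loss = sum(w**2) / 2
    return avg_data, weighted, reg, weighted + reg


label = np.zeros((20, 20))
label[:, :8] = 255        # two cells separated by a 4-pixel membrane gap
label[:, 12:] = 255
wm = weight_map(label)
print(wm[10, 10] > wm[10, 2])  # True: the gap between the cells gets a larger weight
```

The separation term is what makes the thin membranes between touching cells dominate the weighted loss, which is why the weighted and average data losses diverge in TensorBoard.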
142 | 144 | 145 | ### Citation 146 | ``` 147 | @misc{chengbinjin2019u-nettensorflow, 148 | author = {Cheng-Bin Jin}, 149 | title = {U-Net Tensorflow}, 150 | year = {2019}, 151 | howpublished = {\url{https://github.com/ChengBinJin/U-Net-TensorFlow}}, 152 | note = {commit xxxxxxx} 153 | } 154 | ``` 155 | 156 | ### Attributions/Thanks 157 | - This project borrowed some code from [Zhixuhao](https://github.com/zhixuhao/unet) 158 | - Some README formatting was borrowed from [Logan Engstrom](https://github.com/lengstrom/fast-style-transfer) 159 | 160 | ## License 161 | Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: sbkim0407@gmail.com). Free for research use, as long as proper attribution is given and this copyright notice is retained. 162 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow U-Net Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import os 8 | import logging 9 | import numpy as np 10 | import tifffile as tiff 11 | from scipy.ndimage import rotate 12 | 13 | import utils as utils 14 | 15 | logger = logging.getLogger(__name__) # logger 16 | logger.setLevel(logging.INFO) 17 | 18 | 19 | class Dataset(object): 20 | def __init__(self, name='EMSegmentation', log_dir=None): 21 | # These values depend on the dataset 22 | self.input_size = 572 23 | self.output_size = 388 24 | self.input_channel = 1 25 | self.input_shape = (self.input_size, self.input_size, self.input_channel) 26 | self.output_shape = (self.output_size, self.output_size) 27 | 28 | self.name = name 29 | self.dataset_path = '../../Data/EMSegmentation' 30 | 31 |
self.train_imgs = tiff.imread(os.path.join(self.dataset_path, 'train-volume.tif')) 32 | self.train_labels = tiff.imread(os.path.join(self.dataset_path, 'train-labels.tif')) 33 | self.train_wmaps = np.load(os.path.join(self.dataset_path, 'train-wmaps.npy')) 34 | self.test_imgs = tiff.imread(os.path.join(self.dataset_path, 'test-volume.tif')) 35 | self.mean_value = np.mean(self.train_imgs) 36 | 37 | self.num_train = self.train_imgs.shape[0] 38 | self.num_test = self.test_imgs.shape[0] 39 | self.img_shape = self.train_imgs[0].shape 40 | 41 | self.init_logger(log_dir) 42 | 43 | @staticmethod 44 | def init_logger(log_dir): 45 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s') 46 | 47 | # file handler 48 | file_handler = logging.FileHandler(os.path.join(log_dir, 'dataset.log')) 49 | file_handler.setFormatter(formatter) 50 | 51 | # stream handler 52 | stream_handler = logging.StreamHandler() 53 | stream_handler.setFormatter(formatter) 54 | 55 | # add handlers 56 | logger.addHandler(file_handler) 57 | logger.addHandler(stream_handler) 58 | 59 | def info(self, use_logging=True, log_dir=None): 60 | if use_logging: 61 | logger.info('- Training-img set:\t{}'.format(self.train_imgs.shape)) 62 | logger.info('- Training-label set:\t{}'.format(self.train_labels.shape)) 63 | logger.info('- Training-wmap set:\t{}'.format(self.train_wmaps.shape)) 64 | logger.info('- Test-img set:\t\t{}'.format(self.test_imgs.shape)) 65 | 66 | logger.info('- image shape:\t\t{}'.format(self.img_shape)) 67 | else: 68 | print('- Training-img set:\t{}'.format(self.train_imgs.shape)) 69 | print('- Training-label set:\t{}'.format(self.train_labels.shape)) 70 | print('- Training-wmap set:\t{}'.format(self.train_wmaps.shape)) 71 | print('- Test-img set:\t\t{}'.format(self.test_imgs.shape)) 72 | print('- image shape:\t\t{}'.format(self.img_shape)) 73 | 74 | print(' [*] Saving data augmented images to check U-Net fundamentals...') 75 | for idx in range(self.num_train): 76 | img_, label_, wmap_ 
= self.train_imgs[idx], self.train_labels[idx], self.train_wmaps[idx] 77 | utils.imshow(img_, label_, wmap_, idx, log_dir=log_dir) 78 | utils.test_augmentation(img_, label_, wmap_, idx, log_dir=log_dir) 79 | utils.test_cropping(img_, label_, wmap_, idx, self.input_size, self.output_size, log_dir=log_dir) 80 | print(' [!] Saving data augmented images to check U-Net fundamentals!') 81 | 82 | def info_test(self, test_dir): 83 | print(' [*] Saving test data cropping to check U-Net fundamentals...') 84 | for idx in range(self.num_test): 85 | img = self.test_imgs[idx] 86 | utils.test_imshow(img, idx, test_dir=test_dir) 87 | utils.test_rotate(img, idx, test_dir=test_dir) 88 | print(' [!] Saving test data cropping to check U-Net fundamentals!') 89 | 90 | 91 | def random_batch(self, idx, batch_size=2): 92 | idx = idx % self.num_train 93 | x_img, y_label, w_map = self.train_imgs[idx], self.train_labels[idx], self.train_wmaps[idx] 94 | 95 | x_batchs = np.zeros((batch_size, self.input_size, self.input_size), dtype=np.float32) 96 | # y_batchs holds class labels; they are converted to one-hot inside the model graph (tf.one_hot in model.py) 97 | y_batchs = np.zeros((batch_size, self.output_size, self.output_size), dtype=np.float32) 98 | w_batchs = np.zeros((batch_size, self.output_size, self.output_size), dtype=np.float32) 99 | for i in range(batch_size): # i indexes the batch; idx still refers to the selected training image 100 | # Random translation 101 | x_batch, y_batch, w_batch = utils.aug_translate(x_img, y_label, w_map) 102 | 103 | # Random horizontal and vertical flip 104 | x_batch, y_batch, w_batch = utils.aug_flip(x_batch, y_batch, w_batch) 105 | 106 | # Random rotation 107 | x_batch, y_batch, w_batch = utils.aug_rotate(x_batch, y_batch, w_batch) 108 | 109 | # Random elastic deformation 110 | x_batch, y_batch, w_batch = utils.aug_elastic_deform(x_batch, y_batch, w_batch) 111 | 112 | # Following the original U-Net paper 113 | # Mirror-pad the 512 x 512 image to 696 x 696 (696 = 512 + 92*2), then crop a 572 x 572 input image 114 | # and the corresponding 388 x 388 label map 115 | # 92 = (572 - 388) / 2 116 | x_batch,
y_batch, w_batch = utils.cropping(x_batch, y_batch, w_batch, self.input_size, self.output_size) 117 | 118 | x_batchs[i, :, :] = x_batch 119 | y_batchs[i, :, :] = y_batch 120 | w_batchs[i, :, :] = w_batch 121 | 122 | return self.zero_centering(x_batchs), (y_batchs / 255).astype(np.uint8), w_batchs 123 | 124 | def test_batch(self, idx, angle): 125 | x_img_ori = self.test_imgs[idx] 126 | 127 | # Rotate image 128 | x_img = rotate(input=x_img_ori, angle=angle, axes=(0, 1), reshape=False, order=3, mode='reflect') 129 | x_batchs = utils.test_data_cropping(img=x_img, 130 | input_size=self.input_size, 131 | output_size=self.output_size, 132 | num_blocks=4) 133 | 134 | return self.zero_centering(x_batchs), x_img_ori 135 | 136 | def zero_centering(self, imgs): 137 | return imgs - self.mean_value 138 | 139 | 140 | # if __name__ == '__main__': 141 | # data = Dataset() 142 | # 143 | # for i in range(data.num_train): 144 | # img, label = data.train_imgs[i], data.train_labels[i] 145 | # utils.imshow(img, label, idx=i) 146 | 147 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow U-Net Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import os 8 | import logging 9 | import numpy as np 10 | import tensorflow as tf 11 | from datetime import datetime 12 | 13 | from dataset import Dataset 14 | from model import Model 15 | from solver import Solver 16 | import utils as utils 17 | 18 | FLAGS = tf.flags.FLAGS 19 | tf.flags.DEFINE_string('gpu_index', '0', 'gpu index if you have multiple gpus, default: 0') 20 | tf.flags.DEFINE_string('dataset', 'EMSegmentation', 'dataset name, default: EMSegmentation') 21 |
tf.flags.DEFINE_integer('batch_size', 4, 'batch size for one iteration, default: 4') 22 | tf.flags.DEFINE_bool('is_train', True, 'training or inference mode, default: True') 23 | tf.flags.DEFINE_float('learning_rate', 1e-3, 'initial learning rate for optimizer, default: 0.001') 24 | tf.flags.DEFINE_float('weight_decay', 1e-4, 'weight decay for model to handle overfitting, default: 0.0001') 25 | tf.flags.DEFINE_integer('iters', 20000, 'number of iterations, default: 20,000') 26 | tf.flags.DEFINE_integer('print_freq', 10, 'print frequency for loss information, default: 10') 27 | tf.flags.DEFINE_integer('sample_freq', 100, 'sample frequency for checking qualitative evaluation, default: 100') 28 | tf.flags.DEFINE_integer('eval_freq', 200, 'evaluation frequency for evaluation of the batch accuracy, default: 200') 29 | tf.flags.DEFINE_string('load_model', None, 'folder of saved model that you wish to continue training ' 30 | '(e.g. 20190524-1606), default: None') 31 | 32 | logger = logging.getLogger(__name__) # logger 33 | logger.setLevel(logging.INFO) 34 | 35 | def init_logger(log_dir, is_train=True): 36 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s') 37 | 38 | # file handler 39 | file_handler = logging.FileHandler(os.path.join(log_dir, 'main.log')) 40 | file_handler.setFormatter(formatter) 41 | file_handler.setLevel(logging.INFO) 42 | 43 | # stream handler 44 | stream_handler = logging.StreamHandler() 45 | stream_handler.setFormatter(formatter) 46 | 47 | # add handlers 48 | logger.addHandler(file_handler) 49 | logger.addHandler(stream_handler) 50 | 51 | if is_train: 52 | logger.info('gpu_index: \t\t{}'.format(FLAGS.gpu_index)) 53 | logger.info('dataset: \t\t{}'.format(FLAGS.dataset)) 54 | logger.info('batch_size: \t\t{}'.format(FLAGS.batch_size)) 55 | logger.info('is_train: \t\t{}'.format(FLAGS.is_train)) 56 | logger.info('learning_rate: \t{}'.format(FLAGS.learning_rate)) 57 | logger.info('weight_decay: \t\t{}'.format(FLAGS.weight_decay)) 58 |
logger.info('iters: \t\t{}'.format(FLAGS.iters)) 59 | logger.info('print_freq: \t\t{}'.format(FLAGS.print_freq)) 60 | logger.info('sample_freq: \t\t{}'.format(FLAGS.sample_freq)) 61 | logger.info('eval_freq: \t\t{}'.format(FLAGS.eval_freq)) 62 | logger.info('load_model: \t\t{}'.format(FLAGS.load_model)) 63 | else: 64 | print('-- gpu_index: \t\t{}'.format(FLAGS.gpu_index)) 65 | print('-- dataset: \t\t{}'.format(FLAGS.dataset)) 66 | print('-- batch_size: \t\t{}'.format(FLAGS.batch_size)) 67 | print('-- is_train: \t\t{}'.format(FLAGS.is_train)) 68 | print('-- learning_rate: \t{}'.format(FLAGS.learning_rate)) 69 | print('-- weight_decay: \t\t{}'.format(FLAGS.weight_decay)) 70 | print('-- iters: \t\t{}'.format(FLAGS.iters)) 71 | print('-- print_freq: \t\t{}'.format(FLAGS.print_freq)) 72 | print('-- sample_freq: \t\t{}'.format(FLAGS.sample_freq)) 73 | print('-- eval_freq: \t\t{}'.format(FLAGS.eval_freq)) 74 | print('-- load_model: \t\t{}'.format(FLAGS.load_model)) 75 | 76 | 77 | def main(_): 78 | os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_index 79 | 80 | # Initialize model and log folders 81 | if FLAGS.load_model is None: 82 | cur_time = datetime.now().strftime("%Y%m%d-%H%M") 83 | else: 84 | cur_time = FLAGS.load_model 85 | 86 | model_dir, log_dir, sample_dir, test_dir = utils.make_folders(is_train=FLAGS.is_train, cur_time=cur_time) 87 | init_logger(log_dir=log_dir, is_train=FLAGS.is_train) 88 | 89 | # Initialize dataset 90 | data = Dataset(name=FLAGS.dataset, log_dir=log_dir) 91 | data.info(use_logging=True, log_dir=log_dir) 92 | 93 | # Initialize session 94 | sess = tf.Session() 95 | 96 | # Initialize model 97 | model = Model(input_shape=data.input_shape, 98 | output_shape=data.output_shape, 99 | lr=FLAGS.learning_rate, 100 | weight_decay=FLAGS.weight_decay, 101 | total_iters=FLAGS.iters, 102 | is_train=FLAGS.is_train, 103 | log_dir=log_dir, 104 | name='U-Net') 105 | 106 | # Initialize solver 107 | solver = Solver(sess, model, data.mean_value) 108 | saver =
tf.train.Saver(max_to_keep=1) 109 | 110 | if FLAGS.is_train: 111 | train(data, solver, saver, model_dir, log_dir, sample_dir) 112 | else: 113 | test(data, solver, saver, model_dir, test_dir) 114 | 115 | 116 | def train(data, solver, saver, model_dir, log_dir, sample_dir): 117 | best_acc = 0. 118 | num_evals = 0 119 | tb_writer = tf.summary.FileWriter(log_dir, graph=solver.sess.graph) 120 | solver.init() 121 | 122 | iter_time = 0 123 | if FLAGS.load_model is not None: 124 | flag, iter_time, best_acc = load_model(saver, solver, model_dir, is_train=True) 125 | logger.info(' [!] Load Success! Iter: {}, Best acc: {:.3f}'.format(iter_time, best_acc) if flag else ' [!] Load Failed! Training from scratch.') 126 | 127 | for iter_time in range(iter_time, FLAGS.iters): 128 | x_batch, y_batch, w_batch = data.random_batch(batch_size=FLAGS.batch_size, idx=iter_time) 129 | _, total_loss, avg_data_loss, weighted_data_loss, reg_term, summary, pred_cls = \ 130 | solver.train(x_batch, y_batch, w_batch) 131 | 132 | # Write to tensorboard 133 | tb_writer.add_summary(summary, iter_time) 134 | tb_writer.flush() 135 | 136 | if np.mod(iter_time, FLAGS.print_freq) == 0: 137 | msg = '{}/{}: \tTotal loss: {:.3f}, \tAvg. data loss: {:.3f}, \tWeighted data loss: {:.3f} \tReg. term: {:.3f}' 138 | print(msg.format(iter_time, FLAGS.iters, total_loss, avg_data_loss, weighted_data_loss, reg_term)) 139 | 140 | if np.mod(iter_time, FLAGS.sample_freq) == 0: 141 | solver.save_imgs(x_batch, pred_cls, y_batch, iter_time, sample_dir) 142 | 143 | if np.mod(iter_time, FLAGS.eval_freq) == 0: 144 | x_batch, y_batch, w_batch = data.random_batch(batch_size=FLAGS.batch_size * 20, 145 | idx=np.random.randint(low=0, high=FLAGS.iters)) 146 | acc, summary = solver.evalate(x_batch, y_batch, batch_size=FLAGS.batch_size) 147 | print('Evaluation!
\tAcc: {:.3f} \tBest Acc: {:.3f}'.format(acc, best_acc)) 148 | 149 | # Write to tensorboard 150 | tb_writer.add_summary(summary, num_evals) 151 | tb_writer.flush() 152 | num_evals += 1 153 | 154 | if acc > best_acc: 155 | logger.info('Acc: {:.3f}, Best Acc: {:.3f}'.format(acc, best_acc)) 156 | best_acc = acc 157 | save_model(saver, solver, model_dir, iter_time, best_acc) 158 | 159 | 160 | def test(data, solver, saver, model_dir, test_dir, start=0, stop=360, num=7): 161 | # Load checkpoint 162 | flag, iter_time, best_acc = load_model(saver, solver, model_dir, is_train=False) 163 | if flag: 164 | print(' [!] Load Success! Iter: {}, Best acc: {:.3f}'.format(iter_time, best_acc)) 165 | else: 166 | print(' [!] Load Failed!') 167 | 168 | # Test 169 | data.info_test(test_dir) 170 | 171 | for iter_time in range(data.num_test): 172 | print('iter: {}'.format(iter_time)) 173 | 174 | y_preds = np.zeros((num, *data.img_shape, 2), dtype=np.float32) # [num_rotations, H, W, 2] 175 | x_ori_img = None 176 | for i, angle in enumerate(np.linspace(start=start, stop=stop, num=num, endpoint=False)): 177 | x_batchs, x_ori_img = data.test_batch(iter_time, angle) # four cropped images for one test image 178 | y_preds[i] = solver.test(x_batchs, iter_time, angle, test_dir, is_save=True) 179 | 180 | # Merge rotated label images 181 | utils.merge_rotated_preds(y_preds, x_ori_img, iter_time, start, stop, num, test_dir, is_save=True) 182 | 183 | 184 | def save_model(saver, solver, model_dir, iter_time, best_acc): 185 | solver.save_acc_record(best_acc) 186 | saver.save(solver.sess, os.path.join(model_dir, 'model'), global_step=iter_time) 187 | logger.info(' [*] Model saved!
Iter: {}, Best Acc.: {:.3f}'.format(iter_time, best_acc)) 188 | 189 | 190 | def load_model(saver, solver, model_dir, is_train=False): 191 | if is_train: 192 | logger.info(' [*] Reading checkpoint...') 193 | else: 194 | print(' [*] Reading checkpoint...') 195 | 196 | ckpt = tf.train.get_checkpoint_state(model_dir) 197 | if ckpt and ckpt.model_checkpoint_path: 198 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 199 | saver.restore(solver.sess, os.path.join(model_dir, ckpt_name)) 200 | 201 | meta_graph_path = ckpt.model_checkpoint_path + '.meta' 202 | iter_time = int(meta_graph_path.split('-')[-1].split('.')[0]) 203 | best_acc = solver.load_acc_record() 204 | 205 | if is_train: 206 | logger.info(' [!] Load Iter: {}, Best Acc.: {:.3f}'.format(iter_time, best_acc)) 207 | 208 | return True, iter_time, best_acc 209 | else: 210 | return False, 0, 0. 211 | 212 | 213 | 214 | if __name__ == '__main__': 215 | tf.app.run() 216 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow U-Net Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import tensorflow as tf 8 | import utils as utils 9 | import tensorflow_utils as tf_utils 10 | 11 | class Model(object): 12 | def __init__(self, input_shape, output_shape, lr=0.001, weight_decay=1e-4, total_iters=2e4, is_train=True, 13 | log_dir=None, name='U-Net'): 14 | self.input_shape = input_shape 15 | self.output_shape = output_shape 16 | self.conv_dims = [64, 64, 128, 128, 256, 256, 512, 512, 1024, 1024, 17 | 512, 512, 512, 256, 256, 256, 128, 128, 128, 64, 64, 64, 2] 18 | self.lr = lr 19 | self.weight_decay = weight_decay 20 | self.total_steps = total_iters 21 | 
self.start_decay_step = int(self.total_steps * 0.5) 22 | self.decay_steps = self.total_steps - self.start_decay_step 23 | self.is_train = is_train 24 | self.log_dir = log_dir 25 | self.name = name 26 | self.tb_lr, self.pred = None, None 27 | self.logger, self.file_handler, self.stream_handler = utils.init_logger(log_dir=self.log_dir, 28 | name=self.name, 29 | is_train=self.is_train) 30 | 31 | with tf.variable_scope(self.name): 32 | self._build_net() 33 | 34 | self._tensorboard() 35 | tf_utils.show_all_variables(logger=self.logger if self.is_train else None) 36 | 37 | 38 | def _build_net(self): 39 | # Input placeholders 40 | self.inp_img = tf.placeholder(dtype=tf.float32, shape=[None, *self.input_shape], name='input_img') 41 | self.out_img = tf.placeholder(dtype=tf.uint8, shape=[None, *self.output_shape], name='output_img') 42 | self.weight_map = tf.placeholder(dtype=tf.float32, shape=[None, *self.output_shape], name='weight_map') 43 | self.keep_prob = tf.placeholder(dtype=tf.float32, name='keep_prob') 44 | self.val_acc = tf.placeholder(dtype=tf.float32, name='val_acc') 45 | 46 | # Best acc record 47 | self.best_acc_record = tf.get_variable(name='best_acc_save', dtype=tf.float32, initializer=tf.constant(0.), 48 | trainable=False) 49 | 50 | # One-hot representation 51 | self.out_img_one_hot = tf.one_hot(indices=self.out_img, depth=2, axis=-1, dtype=tf.float32, name='one_hot') 52 | 53 | # U-Net building 54 | self.u_net() 55 | self.pred_cls = tf.math.argmax(input=self.pred, axis=-1) 56 | 57 | # Data loss 58 | self.data_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.out_img_one_hot, logits=self.pred) 59 | self.avg_data_loss = tf.reduce_mean(self.data_loss) 60 | self.weighted_data_loss = tf.reduce_mean(self.weight_map * self.data_loss) 61 | 62 | # Regularization term 63 | self.reg_term = self.weight_decay * tf.reduce_sum( 64 | [tf.nn.l2_loss(weight) for weight in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]) 65 | 66 | # Total loss = Data loss + 
Regularization term 67 | self.total_loss = self.weighted_data_loss + self.reg_term 68 | 69 | # Optimizer 70 | self.train_op = self.optimizer_fn(loss=self.total_loss, name='Adam') 71 | 72 | # Accuracy 73 | self.accuracy = tf.reduce_mean( 74 | tf.cast(tf.math.equal(x=tf.cast(tf.math.argmax(self.out_img_one_hot, axis=-1), dtype=tf.uint8), 75 | y=tf.cast(self.pred_cls, dtype=tf.uint8)), dtype=tf.float32)) * 100. 76 | self.save_best_acc_op = tf.assign(self.best_acc_record, value=self.val_acc) 77 | 78 | def u_net(self): 79 | # Stage 1 80 | tf_utils.print_activations(self.inp_img, logger=self.logger) 81 | s1_conv1 = tf_utils.conv2d(x=self.inp_img, output_dim=self.conv_dims[0], k_h=3, k_w=3, d_h=1, d_w=1, 82 | padding='VALID', initializer='He', name='s1_conv1', logger=self.logger) 83 | s1_conv1 = tf_utils.relu(s1_conv1, name='relu_s1_conv1', logger=self.logger) 84 | s1_conv2 = tf_utils.conv2d(x=s1_conv1, output_dim=self.conv_dims[1], k_h=3, k_w=3, d_h=1, d_w=1, 85 | padding='VALID', initializer='He', name='s1_conv2', logger=self.logger) 86 | s1_conv2 = tf_utils.relu(s1_conv2, name='relu_s1_conv2', logger=self.logger) 87 | 88 | # Stage 2 89 | s2_maxpool = tf_utils.max_pool(x=s1_conv2, name='s2_maxpool', logger=self.logger) 90 | s2_conv1 = tf_utils.conv2d(x=s2_maxpool, output_dim=self.conv_dims[2], k_h=3, k_w=3, d_h=1, d_w=1, 91 | padding='VALID', initializer='He', name='s2_conv1', logger=self.logger) 92 | s2_conv1 = tf_utils.relu(s2_conv1, name='relu_s2_conv1', logger=self.logger) 93 | s2_conv2 = tf_utils.conv2d(x=s2_conv1, output_dim=self.conv_dims[3], k_h=3, k_w=3, d_h=1, d_w=1, 94 | padding='VALID', initializer='He', name='s2_conv2', logger=self.logger) 95 | s2_conv2 = tf_utils.relu(s2_conv2, name='relu_s2_conv2', logger=self.logger) 96 | 97 | # Stage 3 98 | s3_maxpool = tf_utils.max_pool(x=s2_conv2, name='s3_maxpool', logger=self.logger) 99 | s3_conv1 = tf_utils.conv2d(x=s3_maxpool, output_dim=self.conv_dims[4], k_h=3, k_w=3, d_h=1, d_w=1, 100 | padding='VALID', 
initializer='He', name='s3_conv1', logger=self.logger) 101 | s3_conv1 = tf_utils.relu(s3_conv1, name='relu_s3_conv1', logger=self.logger) 102 | s3_conv2 = tf_utils.conv2d(x=s3_conv1, output_dim=self.conv_dims[5], k_h=3, k_w=3, d_h=1, d_w=1, 103 | padding='VALID', initializer='He', name='s3_conv2', logger=self.logger) 104 | s3_conv2 = tf_utils.relu(s3_conv2, name='relu_s3_conv2', logger=self.logger) 105 | 106 | # Stage 4 107 | s4_maxpool = tf_utils.max_pool(x=s3_conv2, name='s4_maxpool', logger=self.logger) 108 | s4_conv1 = tf_utils.conv2d(x=s4_maxpool, output_dim=self.conv_dims[6], k_h=3, k_w=3, d_h=1, d_w=1, 109 | padding='VALID', initializer='He', name='s4_conv1', logger=self.logger) 110 | s4_conv1 = tf_utils.relu(s4_conv1, name='relu_s4_conv1', logger=self.logger) 111 | s4_conv2 = tf_utils.conv2d(x=s4_conv1, output_dim=self.conv_dims[7], k_h=3, k_w=3, d_h=1, d_w=1, 112 | padding='VALID', initializer='He', name='s4_conv2', logger=self.logger) 113 | s4_conv2 = tf_utils.relu(s4_conv2, name='relu_s4_conv2', logger=self.logger) 114 | s4_conv2_drop = tf_utils.dropout(x=s4_conv2, keep_prob=self.keep_prob, name='s4_conv2_dropout', 115 | logger=self.logger) 116 | 117 | # Stage 5 118 | s5_maxpool = tf_utils.max_pool(x=s4_conv2_drop, name='s5_maxpool', logger=self.logger) 119 | s5_conv1 = tf_utils.conv2d(x=s5_maxpool, output_dim=self.conv_dims[8], k_h=3, k_w=3, d_h=1, d_w=1, 120 | padding='VALID', initializer='He', name='s5_conv1', logger=self.logger) 121 | s5_conv1 = tf_utils.relu(s5_conv1, name='relu_s5_conv1', logger=self.logger) 122 | s5_conv2 = tf_utils.conv2d(x=s5_conv1, output_dim=self.conv_dims[9], k_h=3, k_w=3, d_h=1, d_w=1, 123 | padding='VALID', initializer='He', name='s5_conv2', logger=self.logger) 124 | s5_conv2 = tf_utils.relu(s5_conv2, name='relu_s5_conv2', logger=self.logger) 125 | s5_conv2_drop = tf_utils.dropout(x=s5_conv2, keep_prob=self.keep_prob, name='s5_conv2_dropout', 126 | logger=self.logger) 127 | 128 | # Stage 6 129 | s6_deconv1 = 
tf_utils.deconv2d(x=s5_conv2_drop, output_dim=self.conv_dims[10], k_h=2, k_w=2, initializer='He', 130 | name='s6_deconv1', logger=self.logger) 131 | s6_deconv1 = tf_utils.relu(s6_deconv1, name='relu_s6_deconv1', logger=self.logger) 132 | 133 | # Cropping 134 | h1, w1 = s4_conv2_drop.get_shape().as_list()[1:3] 135 | h2, w2 = s6_deconv1.get_shape().as_list()[1:3] 136 | s4_conv2_crop = tf.image.crop_to_bounding_box(image=s4_conv2_drop, 137 | offset_height=int(0.5 * (h1 - h2)), 138 | offset_width=int(0.5 * (w1 - w2)), 139 | target_height=h2, 140 | target_width=w2) 141 | tf_utils.print_activations(s4_conv2_crop, logger=self.logger) 142 | 143 | s6_concat = tf_utils.concat(values=[s4_conv2_crop, s6_deconv1], axis=3, name='s6_concat', logger=self.logger) 144 | s6_conv2 = tf_utils.conv2d(x=s6_concat, output_dim=self.conv_dims[11], k_h=3, k_w=3, d_h=1, d_w=1, 145 | padding='VALID', initializer='He', name='s6_conv2', logger=self.logger) 146 | s6_conv2 = tf_utils.relu(s6_conv2, name='relu_s6_conv2', logger=self.logger) 147 | s6_conv3 = tf_utils.conv2d(x=s6_conv2, output_dim=self.conv_dims[12], k_h=3, k_w=3, d_h=1, d_w=1, 148 | padding='VALID', initializer='He', name='s6_conv3', logger=self.logger) 149 | s6_conv3 = tf_utils.relu(s6_conv3, name='relu_s6_conv3', logger=self.logger) 150 | 151 | # Stage 7 152 | s7_deconv1 = tf_utils.deconv2d(x=s6_conv3, output_dim=self.conv_dims[13], k_h=2, k_w=2, initializer='He', 153 | name='s7_deconv1', logger=self.logger) 154 | s7_deconv1 = tf_utils.relu(s7_deconv1, name='relu_s7_deconv1', logger=self.logger) 155 | # Cropping 156 | h1, w1 = s3_conv2.get_shape().as_list()[1:3] 157 | h2, w2 = s7_deconv1.get_shape().as_list()[1:3] 158 | s3_conv2_crop = tf.image.crop_to_bounding_box(image=s3_conv2, 159 | offset_height=int(0.5 * (h1 - h2)), 160 | offset_width=int(0.5 * (w1 - w2)), 161 | target_height=h2, 162 | target_width=w2) 163 | tf_utils.print_activations(s3_conv2_crop, logger=self.logger) 164 | 165 | s7_concat = 
tf_utils.concat(values=[s3_conv2_crop, s7_deconv1], axis=3, name='s7_concat', logger=self.logger) 166 | s7_conv2 = tf_utils.conv2d(x=s7_concat, output_dim=self.conv_dims[14], k_h=3, k_w=3, d_h=1, d_w=1, 167 | padding='VALID', initializer='He', name='s7_conv2', logger=self.logger) 168 | s7_conv2 = tf_utils.relu(s7_conv2, name='relu_s7_conv2', logger=self.logger) 169 | s7_conv3 = tf_utils.conv2d(x=s7_conv2, output_dim=self.conv_dims[15], k_h=3, k_w=3, d_h=1, d_w=1, 170 | padding='VALID', initializer='He', name='s7_conv3', logger=self.logger) 171 | s7_conv3 = tf_utils.relu(s7_conv3, name='relu_s7_conv3', logger=self.logger) 172 | 173 | # Stage 8 174 | s8_deconv1 = tf_utils.deconv2d(x=s7_conv3, output_dim=self.conv_dims[16], k_h=2, k_w=2, initializer='He', 175 | name='s8_deconv1', logger=self.logger) 176 | s8_deconv1 = tf_utils.relu(s8_deconv1, name='relu_s8_deconv1', logger=self.logger) 177 | # Cropping 178 | h1, w1 = s2_conv2.get_shape().as_list()[1:3] 179 | h2, w2 = s8_deconv1.get_shape().as_list()[1:3] 180 | s2_conv2_crop = tf.image.crop_to_bounding_box(image=s2_conv2, 181 | offset_height=int(0.5 * (h1 - h2)), 182 | offset_width=int(0.5 * (w1 - w2)), 183 | target_height=h2, 184 | target_width=w2) 185 | tf_utils.print_activations(s2_conv2_crop, logger=self.logger) 186 | 187 | s8_concat = tf_utils.concat(values=[s2_conv2_crop, s8_deconv1], axis=3, name='s8_concat', logger=self.logger) 188 | s8_conv2 = tf_utils.conv2d(x=s8_concat, output_dim=self.conv_dims[17], k_h=3, k_w=3, d_h=1, d_w=1, 189 | padding='VALID', initializer='He', name='s8_conv2', logger=self.logger) 190 | s8_conv2 = tf_utils.relu(s8_conv2, name='relu_s8_conv2', logger=self.logger) 191 | s8_conv3 = tf_utils.conv2d(x=s8_conv2, output_dim=self.conv_dims[18], k_h=3, k_w=3, d_h=1, d_w=1, 192 | padding='VALID', initializer='He', name='s8_conv3', logger=self.logger) 193 | s8_conv3 = tf_utils.relu(s8_conv3, name='relu_s8_conv3', logger=self.logger) 194 | 195 | # Stage 9 196 | s9_deconv1 =
tf_utils.deconv2d(x=s8_conv3, output_dim=self.conv_dims[19], k_h=2, k_w=2, initializer='He', 197 | name='s9_deconv1', logger=self.logger) 198 | s9_deconv1 = tf_utils.relu(s9_deconv1, name='relu_s9_deconv1', logger=self.logger) 199 | # Cropping 200 | h1, w1 = s1_conv2.get_shape().as_list()[1:3] 201 | h2, w2 = s9_deconv1.get_shape().as_list()[1:3] 202 | s1_conv2_crop = tf.image.crop_to_bounding_box(image=s1_conv2, 203 | offset_height=int(0.5 * (h1 - h2)), 204 | offset_width=int(0.5 * (w1 - w2)), 205 | target_height=h2, 206 | target_width=w2) 207 | tf_utils.print_activations(s1_conv2_crop, logger=self.logger) 208 | 209 | s9_concat = tf_utils.concat(values=[s1_conv2_crop, s9_deconv1], axis=3, name='s9_concat', logger=self.logger) 210 | s9_conv2 = tf_utils.conv2d(x=s9_concat, output_dim=self.conv_dims[20], k_h=3, k_w=3, d_h=1, d_w=1, 211 | padding='VALID', initializer='He', name='s9_conv2', logger=self.logger) 212 | s9_conv2 = tf_utils.relu(s9_conv2, name='relu_s9_conv2', logger=self.logger) 213 | s9_conv3 = tf_utils.conv2d(x=s9_conv2, output_dim=self.conv_dims[21], k_h=3, k_w=3, d_h=1, d_w=1, 214 | padding='VALID', initializer='He', name='s9_conv3', logger=self.logger) 215 | s9_conv3 = tf_utils.relu(s9_conv3, name='relu_s9_conv3', logger=self.logger) 216 | self.pred = tf_utils.conv2d(x=s9_conv3, output_dim=self.conv_dims[22], k_h=1, k_w=1, d_h=1, d_w=1, 217 | padding='SAME', initializer='He', name='output', logger=self.logger) 218 | 219 | def optimizer_fn(self,loss, name=None): 220 | with tf.variable_scope(name): 221 | global_step = tf.Variable(0, dtype=tf.float32, trainable=False) 222 | start_learning_rate = self.lr 223 | end_learning_rate = 0. 
224 | start_decay_step = self.start_decay_step 225 | decay_steps = self.decay_steps 226 | 227 | learning_rate = (tf.where(tf.greater_equal(global_step, start_decay_step), 228 | tf.train.polynomial_decay(start_learning_rate, 229 | global_step - start_decay_step, 230 | decay_steps, end_learning_rate, power=1.0), 231 | start_learning_rate)) 232 | self.tb_lr = tf.summary.scalar('learning_rate', learning_rate) 233 | 234 | learn_step = tf.train.AdamOptimizer(learning_rate, beta1=0.99).minimize(loss, global_step=global_step) 235 | 236 | return learn_step 237 | 238 | def _tensorboard(self): 239 | self.tb_total = tf.summary.scalar('Loss/total', self.total_loss) 240 | self.tb_data = tf.summary.scalar('Loss/avg_data', self.avg_data_loss) 241 | self.tb_weighted_data = tf.summary.scalar('Loss/weighted_data', self.weighted_data_loss) 242 | self.tb_reg = tf.summary.scalar('Loss/reg_term', self.reg_term) 243 | self.summary_op = tf.summary.merge( 244 | inputs=[self.tb_lr, self.tb_total, self.tb_data, self.tb_weighted_data, self.tb_reg]) 245 | 246 | self.val_acc_op = tf.summary.scalar('Acc', self.val_acc) 247 | 248 | def release_handles(self): 249 | utils.release_handles(self.logger, self.file_handler, self.stream_handler) 250 | 251 | -------------------------------------------------------------------------------- /src/preprocessing.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow U-Net Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import os 8 | import cv2 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | import tifffile as tiff 12 | 13 | def main(dataset_path=None, is_write=False): 14 | # train_imgs = tiff.imread(os.path.join(dataset_path, 'train-volume.tif')) 15 | train_labels = 
tiff.imread(os.path.join(dataset_path, 'train-labels.tif')) 16 | 17 | save_dir = 'wmap_imgs' 18 | if not os.path.isdir(save_dir): 19 | os.makedirs(save_dir) 20 | 21 | file_name = os.path.join(dataset_path, 'train-wmaps.npy') 22 | if is_write: 23 | wmaps = np.zeros_like(train_labels, dtype=np.float32) 24 | for idx in range(train_labels.shape[0]): 25 | print('Image index: {}'.format(idx)) 26 | img = train_labels[idx] 27 | cal_weight_map(label=img, wmaps=wmaps, save_dir=save_dir, iter_time=idx) 28 | 29 | np.save(file_name, wmaps) 30 | 31 | wmaps = np.load(file_name) 32 | plot_wmaps(wmaps) 33 | 34 | 35 | def plot_wmaps(wmaps, nrows=5, ncols=6, hspace=0.2, wspace=0.1, vmin=0., vmax=12., interpolation='nearest'): 36 | # Create figure with sub-plots. 37 | fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 10)) 38 | 39 | # Adjust vertical and horizontal spacing 40 | fig.subplots_adjust(hspace=hspace, wspace=wspace) 41 | 42 | im = None 43 | for i, ax in enumerate(axes.flat): 44 | if i < (nrows * ncols): 45 | # Plot image 46 | im = ax.imshow(wmaps[i], interpolation=interpolation, cmap=plt.cm.jet, vmin=vmin, vmax=vmax) 47 | 48 | # Show the classes as the label on the x-axis. 49 | xlabel = "Weight Map: {0}".format(str(i).zfill(2)) 50 | ax.set_xlabel(xlabel) 51 | 52 | # Remove ticks from the plot. 
53 | ax.set_xticks([]) 54 | ax.set_yticks([]) 55 | 56 | # Add colorbar 57 | # fig.colorbar(im) # cmap=plt.cm.jet, vmin=vmin, vmax=vmax, 58 | 59 | # fig.subplots_adjust(right=0.8) 60 | fig.colorbar(im, ax=axes.ravel().tolist()) 61 | 62 | plt.show() 63 | 64 | 65 | def cal_weight_map(label, wmaps, save_dir, iter_time, wc=1., w0=10., sigma=5, interval=500, vmin=0, vmax=12): 66 | _, contours, _ = cv2.findContours(label, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 67 | wmap = wc * np.ones_like(label, dtype=np.float32) 68 | img = label.copy() 69 | 70 | y_points, x_points = np.where(img == 0) 71 | for idx, (y_point, x_point) in enumerate(zip(y_points, x_points)): 72 | if np.mod(idx, interval) == 0: 73 | print('{} / {}'.format(idx, len(y_points))) 74 | 75 | point = np.array([x_point, y_point]).astype(np.float32) 76 | dis_arr = [] 77 | 78 | for i in range(len(contours)): 79 | cnt = (np.squeeze(contours[i])).astype(np.float32) 80 | dis_arr.append(np.amin(np.sqrt(np.sum(np.power(point - cnt, 2), axis=1)))) 81 | 82 | dis_arr.sort() # sorting 83 | wmap[y_point, x_point] += wc + w0 * np.exp(- np.power(np.sum(dis_arr[0:2]), 2) / (2 * sigma * sigma)) 84 | 85 | plt.imshow(wmap, cmap=plt.cm.jet, vmin=vmin, vmax=vmax) 86 | plt.axis('off') 87 | 88 | # To solve the multiple color-bar problem 89 | if iter_time == 0: 90 | plt.colorbar() 91 | 92 | plt.savefig(os.path.join(save_dir, str(iter_time).zfill(2) + '.png'), bbox_inches='tight') 93 | wmaps[iter_time] = wmap 94 | 95 | 96 | if __name__ == '__main__': 97 | main(dataset_path='../../Data/EMSegmentation', is_write=False) 98 | 99 | 100 | -------------------------------------------------------------------------------- /src/solver.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow U-Net Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 
| # --------------------------------------------------------- 7 | import os 8 | import cv2 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | import utils as utils 13 | 14 | 15 | class Solver(object): 16 | def __init__(self, sess, model, mean_value): 17 | self.sess = sess 18 | self.model = model 19 | self.mean_value = mean_value 20 | 21 | def train(self, x, y, wmap): 22 | feed = { 23 | self.model.inp_img: np.expand_dims(x, axis=3), 24 | self.model.out_img: y, 25 | self.model.weight_map: wmap, 26 | self.model.keep_prob: 0.5 27 | } 28 | 29 | train_op = self.model.train_op 30 | total_loss = self.model.total_loss 31 | avg_data_loss = self.model.avg_data_loss 32 | weighted_data_loss = self.model.weighted_data_loss 33 | reg_term = self.model.reg_term 34 | pred_cls = self.model.pred_cls 35 | summary = self.model.summary_op 36 | 37 | return self.sess.run([train_op, total_loss, avg_data_loss, weighted_data_loss, reg_term, summary, pred_cls], 38 | feed_dict=feed) 39 | 40 | def evalate(self, x, y, batch_size=4): 41 | print(' [*] Evaluation...') 42 | 43 | num_test = x.shape[0] 44 | avg_acc = 0.
45 | 46 | for i_start in range(0, num_test, batch_size): 47 | if i_start + batch_size < num_test: 48 | i_end = i_start + batch_size 49 | else: 50 | i_end = num_test # slice end is exclusive, so this keeps the last sample 51 | 52 | x_batch = x[i_start:i_end] 53 | y_batch = y[i_start:i_end] 54 | 55 | feed = { 56 | self.model.inp_img: np.expand_dims(x_batch, axis=3), 57 | self.model.out_img: y_batch, 58 | self.model.keep_prob: 1.0 59 | } 60 | 61 | acc_op = self.model.accuracy 62 | avg_acc += self.sess.run(acc_op, feed_dict=feed) 63 | 64 | avg_acc = np.float32(avg_acc / np.ceil(num_test / batch_size)) 65 | summary = self.sess.run(self.model.val_acc_op, feed_dict={self.model.val_acc: avg_acc}) 66 | 67 | return avg_acc, summary 68 | 69 | def test(self, x, iter_time, angle, test_dir, is_save=False): 70 | feed = { 71 | self.model.inp_img: np.expand_dims(x, axis=3), 72 | self.model.keep_prob: 1.0 73 | } 74 | 75 | preds = self.sess.run(self.model.pred, feed_dict=feed) 76 | pred = utils.merge_preds(preds, idx=iter_time, angle=angle, test_dir=test_dir, is_save=is_save) 77 | 78 | return pred 79 | 80 | def save_imgs(self, x_imgs, pred_imgs, y_imgs, iter_time, sample_dir=None, border=5): 81 | num_cols = 3 82 | _, H1, W1 = x_imgs.shape 83 | N, H2, W2 = pred_imgs.shape 84 | margin = int(0.5 * (H1 - H2)) 85 | 86 | canvas = np.zeros((N*H1+(N+1)*border, 1*W1+(num_cols-1)*W2+(num_cols+1)*border), dtype=np.uint8) 87 | for idx in range(N): 88 | canvas[(idx+1)*border+idx*H1:(idx+1)*border+(idx+1)*H1, border:border+W1] = \ 89 | x_imgs[idx] + self.mean_value 90 | canvas[(idx+1)*border+idx*H1+margin:(idx+1)*border+idx*H1+margin+H2, 2*border+W1:2*border+W1+W2] = \ 91 | pred_imgs[idx] * 255 92 | canvas[(idx+1)*border+idx*H1+margin:(idx+1)*border+idx*H1+margin+H2, 3*border+W1+W2:3*border+W1+2*W2] = \ 93 | y_imgs[idx] * 255 94 | 95 | cv2.imwrite(os.path.join(sample_dir, str(iter_time).zfill(5) + '.png'), canvas) 96 | 97 | def save_acc_record(self, acc): 98 | self.sess.run(self.model.save_best_acc_op, feed_dict={self.model.val_acc: acc}) 99 | 100
| def load_acc_record(self): 101 | best_acc = self.sess.run(self.model.best_acc_record) 102 | return best_acc 103 | 104 | def init(self): 105 | self.sess.run(tf.global_variables_initializer()) 106 | -------------------------------------------------------------------------------- /src/tensorflow_utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow Utils Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import numpy as np 8 | import tensorflow as tf 9 | from tensorflow.python.training import moving_averages 10 | 11 | 12 | def padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='pad2d'): 13 | if pad_type == 'REFLECT': 14 | return tf.pad(x, [[0, 0], [p_h, p_h], [p_w, p_w], [0, 0]], 'REFLECT', name=name) 15 | 16 | 17 | def conv2d(x, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, initializer=None, padding='SAME', name='conv2d', 18 | is_print=True, logger=None): 19 | with tf.variable_scope(name): 20 | if initializer is None: 21 | init_op = tf.truncated_normal_initializer(stddev=stddev) 22 | elif initializer == 'He': 23 | init_op = tf.initializers.he_normal() 24 | else: 25 | raise NotImplementedError 26 | 27 | w = tf.get_variable('w', [k_h, k_w, x.get_shape()[-1], output_dim], initializer=init_op) 28 | conv = tf.nn.conv2d(x, w, strides=[1, d_h, d_w, 1], padding=padding) 29 | 30 | biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) 31 | # conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) 32 | conv = tf.nn.bias_add(conv, biases) 33 | 34 | if is_print: 35 | print_activations(conv, logger) 36 | 37 | return conv 38 | 39 | 40 | def deconv2d(x, output_dim, k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, initializer=None, padding_='SAME', 41 | 
output_size=None, name='deconv2d', with_w=False, is_print=True, logger=None): 42 | with tf.variable_scope(name): 43 | input_shape = x.get_shape().as_list() 44 | 45 | # calculate output size 46 | h_output, w_output = None, None 47 | if not output_size: 48 | h_output, w_output = input_shape[1] * 2, input_shape[2] * 2 49 | # output_shape = [input_shape[0], h_output, w_output, k] # error when not define batch_size 50 | output_shape = [tf.shape(x)[0], h_output, w_output, output_dim] 51 | 52 | # conv2d transpose 53 | if initializer is None: 54 | init_op = tf.random_normal_initializer(stddev=stddev) 55 | elif initializer == 'He': 56 | init_op = tf.initializers.he_normal() 57 | else: 58 | raise NotImplementedError 59 | w = tf.get_variable('w', [k_h, k_w, output_dim, input_shape[3]], initializer=init_op) 60 | deconv = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=[1, d_h, d_w, 1], padding=padding_) 61 | 62 | biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 63 | deconv = tf.nn.bias_add(deconv, biases) 64 | 65 | if is_print: 66 | print_activations(deconv, logger) 67 | 68 | if with_w: 69 | return deconv, w, biases 70 | else: 71 | return deconv 72 | 73 | 74 | def concat(values, axis, name='concat', is_print=True, logger=None): 75 | output = tf.concat(values=values, axis=axis, name=name) 76 | 77 | if is_print: 78 | print_activations(output, logger) 79 | 80 | return output 81 | 82 | 83 | def upsampling2d(x, size=(2, 2), name='upsampling2d'): 84 | with tf.name_scope(name): 85 | shape = x.get_shape().as_list() 86 | return tf.image.resize_nearest_neighbor(x, size=(size[0] * shape[1], size[1] * shape[2])) 87 | 88 | def flatten(x, name='flatten', data_format='channels_last', is_print=True, logger=None): 89 | try: 90 | output = tf.layers.flatten(inputs=x, name=name, data_format=data_format) 91 | except(RuntimeError, TypeError, NameError): 92 | print('[*] Catch the flatten function Error!') 93 | output = 
tf.contrib.layers.flatten(inputs=x, scope=name) 94 | 95 | if is_print: 96 | print_activations(output, logger) 97 | return output 98 | 99 | 100 | def linear(x, output_size, bias_start=0.0, with_w=False, name='fc', is_print=True, logger=None): 101 | shape = x.get_shape().as_list() 102 | 103 | with tf.variable_scope(name): 104 | matrix = tf.get_variable(name="matrix", shape=[shape[1], output_size], 105 | dtype=tf.float32, initializer=tf.contrib.layers.xavier_initializer()) 106 | bias = tf.get_variable(name="bias", shape=[output_size], 107 | initializer=tf.constant_initializer(bias_start)) 108 | output = tf.matmul(x, matrix) + bias 109 | 110 | if is_print: 111 | print_activations(output, logger) 112 | 113 | if with_w: 114 | return output, matrix, bias 115 | else: 116 | return output 117 | 118 | 119 | def norm(x, name, _type, _ops, is_train=True): 120 | if _type == 'batch': 121 | return batch_norm(x, name=name, _ops=_ops, is_train=is_train) 122 | elif _type == 'instance': 123 | return instance_norm(x, name=name) 124 | else: 125 | raise NotImplementedError 126 | 127 | 128 | def batch_norm(x, name, _ops, is_train=True): 129 | """Batch normalization.""" 130 | with tf.variable_scope(name): 131 | params_shape = [x.get_shape()[-1]] 132 | 133 | beta = tf.get_variable('beta', params_shape, tf.float32, 134 | initializer=tf.constant_initializer(0.0, tf.float32)) 135 | gamma = tf.get_variable('gamma', params_shape, tf.float32, 136 | initializer=tf.constant_initializer(1.0, tf.float32)) 137 | 138 | if is_train is True: 139 | mean, variance = tf.nn.moments(x, [0, 1, 2], name='moments') 140 | 141 | moving_mean = tf.get_variable('moving_mean', params_shape, tf.float32, 142 | initializer=tf.constant_initializer(0.0, tf.float32), 143 | trainable=False) 144 | moving_variance = tf.get_variable('moving_variance', params_shape, tf.float32, 145 | initializer=tf.constant_initializer(1.0, tf.float32), 146 | trainable=False) 147 | 148 | 
_ops.append(moving_averages.assign_moving_average(moving_mean, mean, 0.9)) 149 | _ops.append(moving_averages.assign_moving_average(moving_variance, variance, 0.9)) 150 | else: 151 | mean = tf.get_variable('moving_mean', params_shape, tf.float32, 152 | initializer=tf.constant_initializer(0.0, tf.float32), trainable=False) 153 | variance = tf.get_variable('moving_variance', params_shape, tf.float32, 154 | initializer=tf.constant_initializer(1.0, tf.float32), trainable=False) 155 | 156 | # epsilon used to be 1e-5. Maybe 0.001 solves NaN problem in deeper net. 157 | y = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 1e-5) 158 | y.set_shape(x.get_shape()) 159 | 160 | return y 161 | 162 | 163 | def instance_norm(x, name='instance_norm', mean=1.0, stddev=0.02, epsilon=1e-5): 164 | with tf.variable_scope(name): 165 | depth = x.get_shape()[3] 166 | scale = tf.get_variable( 167 | 'scale', [depth], tf.float32, 168 | initializer=tf.random_normal_initializer(mean=mean, stddev=stddev, dtype=tf.float32)) 169 | offset = tf.get_variable('offset', [depth], initializer=tf.constant_initializer(0.0)) 170 | 171 | # calculate per-instance mean and variance over the spatial axes 172 | mean, variance = tf.nn.moments(x, axes=[1, 2], keep_dims=True) 173 | 174 | # normalization 175 | inv = tf.rsqrt(variance + epsilon) 176 | normalized = (x - mean) * inv 177 | 178 | return scale * normalized + offset 179 | 180 | 181 | def n_res_blocks(x, _ops=None, norm_='instance', is_train=True, num_blocks=6, is_print=True, logger=None): 182 | output = None 183 | for idx in range(1, num_blocks+1): 184 | output = res_block(x, x.get_shape()[3], _ops=_ops, norm_=norm_, is_train=is_train, 185 | name='res{}'.format(idx)) 186 | x = output 187 | 188 | if is_print: 189 | print_activations(output, logger) 190 | 191 | return output 192 | 193 | 194 | # norm(x, name, _type, _ops, is_train=True) 195 | def res_block(x, k, _ops=None, norm_='instance', is_train=True, pad_type=None, name=None): 196 | with tf.variable_scope(name): 197 |
conv1, conv2 = None, None 198 | 199 | # 3x3 Conv-Batch-Relu S1 200 | with tf.variable_scope('layer1'): 201 | if pad_type is None: 202 | conv1 = conv2d(x, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv') 203 | elif pad_type == 'REFLECT': 204 | padded1 = padding2d(x, p_h=1, p_w=1, pad_type='REFLECT', name='padding') 205 | conv1 = conv2d(padded1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv') 206 | normalized1 = norm(conv1, name='norm', _type=norm_, _ops=_ops, is_train=is_train) 207 | relu1 = tf.nn.relu(normalized1) 208 | 209 | # 3x3 Conv-Batch S1 210 | with tf.variable_scope('layer2'): 211 | if pad_type is None: 212 | conv2 = conv2d(relu1, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='SAME', name='conv') 213 | elif pad_type == 'REFLECT': 214 | padded2 = padding2d(relu1, p_h=1, p_w=1, pad_type='REFLECT', name='padding') 215 | conv2 = conv2d(padded2, k, k_h=3, k_w=3, d_h=1, d_w=1, padding='VALID', name='conv') 216 | normalized2 = norm(conv2, name='norm', _type=norm_, _ops=_ops, is_train=is_train) 217 | 218 | # sum layer1 and layer2 219 | output = x + normalized2 220 | return output 221 | 222 | 223 | def identity(x, name='identity', is_print=True, logger=None): 224 | output = tf.identity(x, name=name) 225 | if is_print: 226 | print_activations(output, logger) 227 | 228 | return output 229 | 230 | 231 | def max_pool(x, name='max_pool', ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], is_print=True, logger=None): 232 | output = tf.nn.max_pool(value=x, ksize=ksize, strides=strides, padding='SAME', name=name) 233 | if is_print: 234 | print_activations(output, logger) 235 | return output 236 | 237 | def dropout(x, keep_prob=0.5, seed=None, name='dropout', is_print=True, logger=None): 238 | try: 239 | output = tf.nn.dropout(x=x, 240 | rate=1. - keep_prob, # rate is the drop probability, i.e. 1 - keep_prob 241 | seed=seed, 242 | name=name) 243 | except(RuntimeError, TypeError, NameError): 244 | print('[*] Catch the dropout function Error!') 245 | output = tf.nn.dropout(x=x, 246 |
keep_prob=keep_prob, 247 | seed=seed, 248 | name=name) 249 | 250 | if is_print: 251 | print_activations(output, logger) 252 | 253 | return output 254 | 255 | def sigmoid(x, name='sigmoid', is_print=True, logger=None): 256 | output = tf.nn.sigmoid(x, name=name) 257 | if is_print: 258 | print_activations(output, logger) 259 | 260 | return output 261 | 262 | 263 | def tanh(x, name='tanh', is_print=True, logger=None): 264 | output = tf.nn.tanh(x, name=name) 265 | if is_print: 266 | print_activations(output, logger) 267 | 268 | return output 269 | 270 | 271 | def relu(x, name='relu', is_print=True, logger=None): 272 | output = tf.nn.relu(x, name=name) 273 | if is_print: 274 | print_activations(output, logger) 275 | 276 | return output 277 | 278 | 279 | def lrelu(x, leak=0.2, name='lrelu', is_print=True, logger=None): 280 | output = tf.maximum(x, leak*x, name=name) 281 | if is_print: 282 | print_activations(output, logger) 283 | 284 | return output 285 | 286 | 287 | def elu(x, name='elu', is_print=True, logger=None): 288 | output = tf.nn.elu(x, name=name) 289 | if is_print: 290 | print_activations(output, logger) 291 | 292 | return output 293 | 294 | 295 | def xavier_init(in_dim): 296 | xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
297 | return xavier_stddev 298 | 299 | 300 | def print_activations(t, logger=None): 301 | if logger is None: 302 | print(t.op.name, t.get_shape().as_list()) 303 | else: 304 | logger.info(t.op.name + ' {}'.format(t.get_shape().as_list())) 305 | 306 | 307 | def show_all_variables(logger=None): 308 | total_count = 0 309 | 310 | for idx, op in enumerate(tf.trainable_variables()): 311 | shape = op.get_shape() 312 | count = np.prod(shape) 313 | 314 | if logger is None: 315 | print("[%2d] %s %s = %s" % (idx, op.name, shape, count)) 316 | else: 317 | logger.info("[%2d] %s %s = %s" % (idx, op.name, shape, count)) 318 | 319 | total_count += int(count) 320 | 321 | if logger is None: 322 | print("[Total] variable size: %s" % "{:,}".format(total_count)) 323 | else: 324 | logger.info("[Total] variable size: %s" % "{:,}".format(total_count)) 325 | 326 | 327 | def batch_convert2int(images): 328 | # images: 4D float tensor (batch_size, image_size, image_size, depth) 329 | return tf.map_fn(convert2int, images, dtype=tf.uint8) 330 | 331 | 332 | def convert2int(image): 333 | # transform from float tensor ([-1.,1.]) to int image ([0,255]) 334 | return tf.image.convert_image_dtype((image + 1.0) / 2.0, tf.uint8) 335 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------- 2 | # Tensorflow U-Net Implementation 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Cheng-Bin Jin 5 | # Email: sbkim0407@gmail.com 6 | # --------------------------------------------------------- 7 | import os 8 | import sys 9 | import cv2 10 | import logging 11 | import elasticdeform 12 | import numpy as np 13 | from sklearn.metrics import confusion_matrix 14 | from scipy.ndimage import rotate 15 | 16 | 17 | def init_logger(log_dir, name, is_train): 18 | logger = logging.getLogger(__name__) # logger
19 | logger.setLevel(logging.INFO) 20 | 21 | file_handler, stream_handler = None, None 22 | if is_train: 23 | formatter = logging.Formatter('%(asctime)s:%(name)s:%(message)s') 24 | 25 | # file handler 26 | file_handler = logging.FileHandler(os.path.join(log_dir, name + '.log')) 27 | file_handler.setFormatter(formatter) 28 | file_handler.setLevel(logging.INFO) 29 | 30 | # stream handler 31 | stream_handler = logging.StreamHandler() 32 | stream_handler.setFormatter(formatter) 33 | 34 | # add handlers 35 | logger.addHandler(file_handler) 36 | logger.addHandler(stream_handler) 37 | 38 | return logger, file_handler, stream_handler 39 | 40 | 41 | def release_handles(logger, file_handler, stream_handler): 42 | file_handler.close() 43 | stream_handler.close() 44 | logger.removeHandler(file_handler) 45 | logger.removeHandler(stream_handler) 46 | 47 | 48 | def make_folders(is_train=True, cur_time=None): 49 | if is_train: 50 | model_dir = os.path.join('model', '{}'.format(cur_time)) 51 | log_dir = os.path.join('logs', '{}'.format(cur_time)) 52 | sample_dir = os.path.join('sample', '{}'.format(cur_time)) 53 | test_dir = os.path.join('test', '{}'.format(cur_time)) 54 | 55 | if not os.path.isdir(model_dir): 56 | os.makedirs(model_dir) 57 | 58 | if not os.path.isdir(log_dir): 59 | os.makedirs(log_dir) 60 | 61 | if not os.path.isdir(sample_dir): 62 | os.makedirs(sample_dir) 63 | 64 | else: 65 | model_dir = os.path.join('model', '{}'.format(cur_time)) 66 | log_dir = os.path.join('logs', '{}'.format(cur_time)) 67 | sample_dir = os.path.join('sample', '{}'.format(cur_time)) 68 | test_dir = os.path.join('test', '{}'.format(cur_time)) 69 | 70 | if not os.path.isdir(test_dir): 71 | os.makedirs(test_dir) 72 | 73 | return model_dir, log_dir, sample_dir, test_dir 74 | 75 | 76 | def imshow(img, label, wmap, idx, alpha=0.6, delay=1, log_dir=None, show=False): 77 | img_dir = os.path.join(log_dir, 'img') 78 | if not os.path.isdir(img_dir): 79 | os.makedirs(img_dir) 80 | 81 | if len(img.shape) 
== 2: 82 | img = np.dstack((img, img, img)) 83 | 84 | # Convert to pseudo color map from gray-scale image 85 | pseudo_label = None 86 | if len(label.shape) == 2: 87 | pseudo_label = pseudoColor(label) 88 | 89 | beta = 1. - alpha 90 | overlap = cv2.addWeighted(src1=img, 91 | alpha=alpha, 92 | src2=pseudo_label, 93 | beta=beta, 94 | gamma=0.0) 95 | 96 | # Weight-map 97 | wmap_color = cv2.applyColorMap(normalize_uint8(wmap), cv2.COLORMAP_JET) 98 | 99 | canvas = np.hstack((img, pseudo_label, overlap, wmap_color)) 100 | cv2.imwrite(os.path.join(img_dir, 'GT_' + str(idx).zfill(2) + '.png'), canvas) 101 | 102 | if show: 103 | cv2.imshow('Show', canvas) 104 | 105 | if cv2.waitKey(delay) & 0xFF == 27: 106 | sys.exit('Esc clicked!') 107 | 108 | def normalize_uint8(x, x_min=0, x_max=12, fit=255): 109 | x_norm = np.uint8(fit * (x - x_min) / (x_max - x_min)) 110 | return x_norm 111 | 112 | 113 | def pseudoColor(label, thickness=3): 114 | img = label.copy() 115 | img, contours, hierarchy = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 116 | img = np.dstack((img, img, img)) 117 | 118 | for i in range(len(contours)): 119 | cnt = contours[i] 120 | cv2.drawContours(img, [cnt], contourIdx=-1, color=(0, 255, 0), thickness=thickness) 121 | cv2.fillPoly(img, [cnt], color=randomColors(i)) 122 | 123 | return img 124 | 125 | 126 | def randomColors(idx): 127 | Sky = [128, 128, 128] 128 | Building = [128, 0, 0] 129 | Pole = [192, 192, 128] 130 | Road = [128, 64, 128] 131 | Pavement = [60, 40, 222] 132 | Tree = [128, 128, 0] 133 | SignSymbol = [192, 128, 128] 134 | Fence = [64, 64, 128] 135 | Car = [64, 0, 128] 136 | Pedestrian = [64, 64, 0] 137 | Bicyclist = [0, 128, 192] 138 | DarkRed = [0, 0, 139] 139 | PaleVioletRed = [147, 112, 219] 140 | Orange = [0, 165, 255] 141 | Teal = [128, 128, 0] 142 | 143 | color_dict = [Sky, Building, Pole, Road, Pavement, 144 | Tree, SignSymbol, Fence, Car, Pedestrian, 145 | Bicyclist, DarkRed, PaleVioletRed, Orange, Teal] 146 | 147 | return
color_dict[idx % len(color_dict)] 148 | 149 | 150 | def test_augmentation(img, label, wmap, idx, margin=10, log_dir=None): 151 | img_dir = os.path.join(log_dir, 'img') 152 | if not os.path.isdir(img_dir): 153 | os.makedirs(img_dir) 154 | 155 | img_tran, label_tran, wmap_tran = aug_translate(img, label, wmap) # random translation 156 | img_flip, label_flip, wmap_flip = aug_flip(img, label, wmap) # random horizontal and vertical flip 157 | img_rota, label_rota, wmap_rota = aug_rotate(img, label, wmap) # random rotation 158 | img_defo, label_defo, wmap_defo = aug_elastic_deform(img, label, wmap) # random elastic deformation 159 | # img_pert, label_pert, wmap_pert = aug_perturbation(img, label, wmap) # random intensity perturbation 160 | 161 | # Arrange the images in a canvas and save them into the log file 162 | imgs = [img, img_tran, img_flip, img_rota, img_defo] 163 | labels = [label, label_tran, label_flip, label_rota, label_defo] 164 | wmaps = [wmap, wmap_tran, wmap_flip, wmap_rota, wmap_defo] 165 | h, w = img.shape 166 | canvas = np.zeros((3 * h + 4 * margin, len(imgs) * w + (len(imgs) + 1) * margin), dtype=np.uint8) 167 | 168 | for i, (img, label, wmap) in enumerate(zip(imgs, labels, wmaps)): 169 | canvas[1*margin:1*margin+h, (i+1) * margin + i * w:(i+1) * margin + (i + 1) * w] = img 170 | canvas[2*margin+1*h:2*margin+2*h, (i+1) * margin + i * w:(i+1) * margin + (i + 1) * w] = label 171 | canvas[3*margin+2*h:3*margin+3*h, (i+1) * margin + i * w:(i+1) * margin + (i + 1) * w] = normalize_uint8(wmap) 172 | 173 | cv2.imwrite(os.path.join(img_dir, 'augmentation_' + str(idx).zfill(2) + '.png'), canvas) 174 | 175 | 176 | def test_cropping(img, label, wmap, idx, input_size, output_size, log_dir=None, 177 | white=(255, 255, 255), blue=(255, 141, 47), red=(91, 70, 246), thickness=2, margin=10): 178 | img_dir = os.path.join(log_dir, 'img') 179 | if not os.path.isdir(img_dir): 180 | os.makedirs(img_dir) 181 | 182 | img_crop, label_crop, wmap_crop, img_pad, rand_pos_h, 
rand_pos_w = cropping(
        img, label, wmap, input_size, output_size, is_extend=True)
    border_size = int((input_size - output_size) * 0.5)

    # Convert gray images to BGR images
    img_pad = np.dstack((img_pad, img_pad, img_pad))
    label_show = np.dstack((label, label, label))
    wmap_show = normalize_uint8(np.dstack((wmap, wmap, wmap)))

    # Draw the boundary lines between the original image and the mirrored border
    img_pad = cv2.line(img=img_pad,
                       pt1=(0, border_size),
                       pt2=(img_pad.shape[1]-1, border_size),
                       color=white,
                       thickness=thickness)
    img_pad = cv2.line(img=img_pad,
                       pt1=(0, img_pad.shape[0]-1-border_size),
                       pt2=(img_pad.shape[1]-1, img_pad.shape[0]-1-border_size),
                       color=white,
                       thickness=thickness)
    img_pad = cv2.line(img=img_pad,
                       pt1=(border_size, 0),
                       pt2=(border_size, img_pad.shape[0]-1),
                       color=white,
                       thickness=thickness)
    img_pad = cv2.line(img=img_pad,
                       pt1=(img_pad.shape[1]-1-border_size, 0),
                       pt2=(img_pad.shape[1]-1-border_size, img_pad.shape[0]-1),
                       color=white,
                       thickness=thickness)

    # Draw the output ROI (red) and the input ROI (blue) on the padded image,
    # and the output ROI (red) on the label and the weight map
    img_pad = cv2.rectangle(img=img_pad,
                            pt1=(rand_pos_w+border_size, rand_pos_h+border_size),
                            pt2=(rand_pos_w+border_size+output_size, rand_pos_h+border_size+output_size),
                            color=red,
                            thickness=thickness+1)
    img_pad = cv2.rectangle(img=img_pad,
                            pt1=(rand_pos_w, rand_pos_h),
                            pt2=(rand_pos_w+input_size, rand_pos_h+input_size),
                            color=blue,
                            thickness=thickness+1)
    label_show = cv2.rectangle(img=label_show,
                               pt1=(rand_pos_w, rand_pos_h),
                               pt2=(rand_pos_w+output_size, rand_pos_h+output_size),
                               color=red,
                               thickness=thickness+1)
    wmap_show = cv2.rectangle(img=wmap_show,
                              pt1=(rand_pos_w, rand_pos_h),
                              pt2=(rand_pos_w+output_size, rand_pos_h+output_size),
                              color=red,
                              thickness=thickness+1)

    # Frame the cropped patches with the matching colors
    img_crop = cv2.rectangle(img=np.dstack((img_crop, img_crop, img_crop)),
                             pt1=(2, 2),
                             pt2=(img_crop.shape[1]-2, img_crop.shape[0]-2),
                             color=blue,
                             thickness=thickness+1)
    label_crop = cv2.rectangle(img=np.dstack((label_crop, label_crop, label_crop)),
                               pt1=(2, 2),
                               pt2=(label_crop.shape[1]-2, label_crop.shape[0]-2),
                               color=red,
                               thickness=thickness+1)
    wmap_crop = cv2.rectangle(img=normalize_uint8(np.dstack((wmap_crop, wmap_crop, wmap_crop))),
                              pt1=(2, 2),
                              pt2=(wmap_crop.shape[1]-2, wmap_crop.shape[0]-2),
                              color=red,
                              thickness=thickness+1)

    canvas = np.zeros((img_pad.shape[0] + label_show.shape[0] + wmap_show.shape[0] + 4 * margin,
                       img_pad.shape[1] + img_crop.shape[1] + 3 * margin, 3), dtype=np.uint8)

    # Copy img_pad
    h_start = margin
    w_start = margin
    canvas[h_start:h_start + img_pad.shape[0], w_start:w_start + img_pad.shape[1], :] = img_pad

    # Copy label_show
    h_start = 2 * margin + img_pad.shape[0]
    w_start = margin
    canvas[h_start:h_start + label_show.shape[0], w_start:w_start + label_show.shape[1], :] = label_show

    # Copy wmap_show
    h_start = 3 * margin + img_pad.shape[0] + label_show.shape[0]
    w_start = margin
    canvas[h_start:h_start + wmap_show.shape[0], w_start:w_start + wmap_show.shape[1], :] = wmap_show

    # Draw connections between the left and right images
    # Four connections for the input images
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w, margin+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1], margin),
                      color=blue,
                      thickness=thickness+1)
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w+input_size, margin+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1]+img_crop.shape[1], margin),
                      color=blue,
                      thickness=thickness+1)
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w, margin+rand_pos_h+input_size),
                      pt2=(2*margin+img_pad.shape[1], margin+input_size),
                      color=blue,
                      thickness=thickness+1)
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w+input_size, margin+rand_pos_h+input_size),
                      pt2=(2*margin+img_pad.shape[1]+img_crop.shape[1], margin+input_size),
                      color=blue,
                      thickness=thickness+1)

    # Four connections for the label images
    # (x offsets use img_pad.shape[1], y offsets use img_pad.shape[0])
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w, 2*margin+img_pad.shape[0]+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1], 2*margin+img_pad.shape[0]),
                      color=red,
                      thickness=thickness+1)
    canvas = cv2.line(img=canvas,
                      pt1=(margin+output_size+rand_pos_w, 2*margin+img_pad.shape[0]+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1]+output_size, 2*margin+img_pad.shape[0]),
                      color=red,
                      thickness=thickness+1)
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w, 2*margin+img_pad.shape[0]+output_size+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1], 2*margin+img_pad.shape[0]+output_size),
                      color=red,
                      thickness=thickness+1)
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w+output_size, 2*margin+img_pad.shape[0]+output_size+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1]+output_size, 2*margin+img_pad.shape[0]+output_size),
                      color=red,
                      thickness=thickness+1)

    # Four connections for the weight-map images
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w, 3*margin+img_pad.shape[0]+label_show.shape[0]+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1], 3*margin+img_pad.shape[0]+label_show.shape[0]),
                      color=red,
                      thickness=thickness+1)
    canvas = cv2.line(img=canvas,
                      pt1=(margin+output_size+rand_pos_w, 3*margin+img_pad.shape[0]+label_show.shape[0]+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1]+output_size, 3*margin+img_pad.shape[0]+label_show.shape[0]),
                      color=red,
                      thickness=thickness+1)
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w, 3*margin+img_pad.shape[0]+label_show.shape[0]+output_size+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1], 3*margin+img_pad.shape[0]+label_show.shape[0]+output_size),
                      color=red,
                      thickness=thickness+1)
    canvas = cv2.line(img=canvas,
                      pt1=(margin+rand_pos_w+output_size, 3*margin+img_pad.shape[0]+label_show.shape[0]+output_size+rand_pos_h),
                      pt2=(2*margin+img_pad.shape[1]+output_size, 3*margin+img_pad.shape[0]+label_show.shape[0]+output_size),
                      color=red,
                      thickness=thickness+1)

    # Copy img_crop
    h_start = margin
    w_start = 2 * margin + img_pad.shape[1]
    canvas[h_start:h_start + img_crop.shape[0], w_start:w_start + img_crop.shape[1], :] = img_crop

    # Copy label_crop
    h_start = 2 * margin + img_pad.shape[0]
    w_start = 2 * margin + img_pad.shape[1]
    canvas[h_start:h_start + label_crop.shape[0], w_start:w_start + label_crop.shape[1], :] = label_crop

    # Copy wmap_crop
    h_start = 3 * margin + img_pad.shape[0] + label_show.shape[0]
    w_start = 2 * margin + img_pad.shape[1]
    canvas[h_start:h_start + wmap_crop.shape[0], w_start:w_start + wmap_crop.shape[1], :] = wmap_crop

    cv2.imwrite(os.path.join(img_dir, 'crop_' + str(idx).zfill(2) + '.png'), canvas)


def aug_translate(img, label, wmap, max_factor=1.2):
    assert len(img.shape) == 2 and len(label.shape) == 2 and len(wmap.shape) == 2

    # Enlarge the original image by a random factor in [1, max_factor)
    resize_factor = np.random.uniform(low=1., high=max_factor)
    img_bigger = cv2.resize(src=img.copy(), dsize=None, fx=resize_factor, fy=resize_factor,
                            interpolation=cv2.INTER_LINEAR)
    label_bigger = cv2.resize(src=label.copy(), dsize=None, fx=resize_factor, fy=resize_factor,
                              interpolation=cv2.INTER_NEAREST)
    wmap_bigger = cv2.resize(src=wmap.copy(), dsize=None, fx=resize_factor, fy=resize_factor,
                             interpolation=cv2.INTER_NEAREST)

    # Generate random positions for the vertical and horizontal axes
    # (randint excludes `high`, so "+ 1" keeps the last valid offset reachable)
    h_bigger, w_bigger = img_bigger.shape
    h_star = np.random.randint(low=0, high=h_bigger - img.shape[0] + 1)
    w_star = np.random.randint(low=0, high=w_bigger - img.shape[1] + 1)

    # Crop a patch of the original size (rows = shape[0], columns = shape[1]) from the bigger image
    img_crop = img_bigger[h_star:h_star+img.shape[0], w_star:w_star+img.shape[1]]
    label_crop = label_bigger[h_star:h_star+img.shape[0], w_star:w_star+img.shape[1]]
    wmap_crop = wmap_bigger[h_star:h_star+img.shape[0], w_star:w_star+img.shape[1]]

    return img_crop, label_crop, wmap_crop


def aug_flip(img, label, wmap):
    assert len(img.shape) == 2 and len(label.shape) == 2 and len(wmap.shape) == 2

    # Random horizontal flip (flipCode=1 mirrors left-right around the vertical axis)
    if np.random.uniform(low=0., high=1.) > 0.5:
        img_hflip = cv2.flip(src=img, flipCode=1)
        label_hflip = cv2.flip(src=label, flipCode=1)
        wmap_hflip = cv2.flip(src=wmap, flipCode=1)
    else:
        img_hflip = img.copy()
        label_hflip = label.copy()
        wmap_hflip = wmap.copy()

    # Random vertical flip (flipCode=0 mirrors upside down around the horizontal axis)
    if np.random.uniform(low=0., high=1.) > 0.5:
        img_vflip = cv2.flip(src=img_hflip, flipCode=0)
        label_vflip = cv2.flip(src=label_hflip, flipCode=0)
        wmap_vflip = cv2.flip(src=wmap_hflip, flipCode=0)
    else:
        img_vflip = img_hflip.copy()
        label_vflip = label_hflip.copy()
        wmap_vflip = wmap_hflip.copy()

    return img_vflip, label_vflip, wmap_vflip


def test_rotate(img, iter_time, test_dir=None, margin=5, start=0, stop=360, num=7):
    h, w = img.shape
    canvas = np.zeros((2*margin+h, num*w+(num+1)*margin), dtype=np.uint8)

    # Rotate the test image at `num` evenly spaced angles in [start, stop)
    # (360 / 7, i.e. about 51.4 degrees apart, with the default arguments)
    for i, angle in enumerate(np.linspace(start=start, stop=stop, num=num, endpoint=False)):
        # print('angle: {:.3f}'.format(angle))
        img_rotate = rotate(input=img, angle=angle, axes=(0, 1), reshape=False, order=3, mode='reflect')
        canvas[margin:margin+h, (i+1)*margin+i*w:(i+1)*margin+(i+1)*w] = img_rotate

    cv2.imwrite(os.path.join(test_dir, 'Rotate_' + str(iter_time).zfill(2) + '.png'), canvas)


def aug_rotate(img, label, wmap):
    assert len(img.shape) == 2 and len(label.shape) == 2 and len(wmap.shape) == 2

    # Rotate all three maps by the same random angle in [0, 360): cubic interpolation
    # for the image, nearest neighbor for the label and the weight map
    angle = np.random.randint(low=0, high=360, size=None)
    img_rotate = rotate(input=img, angle=angle, axes=(0, 1), reshape=False, order=3, mode='reflect')
    label_rotate = rotate(input=label, angle=angle, axes=(0, 1), reshape=False, order=0, mode='reflect')
    wmap_rotate = rotate(input=wmap, angle=angle, axes=(0, 1), reshape=False, order=0, mode='reflect')

    # Re-binarize the label map to {0, 255}
    ret, label_rotate = cv2.threshold(src=label_rotate, thresh=127.5, maxval=255, type=cv2.THRESH_BINARY)

    return img_rotate, label_rotate, wmap_rotate


def aug_elastic_deform(img, label, wmap):
    assert len(img.shape) == 2 and len(label.shape) == 2 and len(wmap.shape) == 2

    # Apply the same random deformation (3 x 3 control-point grid, sigma=10) to the image,
    # the label, and the weight map, with a different interpolation order for each input
    # (2 for the image, 0 for the label and the weight map)
    img_distort, label_distort, wmap_distort = elasticdeform.deform_random_grid(X=[img, label, wmap],
                                                                                sigma=10,
                                                                                points=3,
                                                                                order=[2, 0, 0],
                                                                                mode='mirror')

    return img_distort, label_distort, wmap_distort


def aug_perturbation(img, label, wmap, low=0.8, high=1.2):
    # Scale every pixel by its own random factor in [low, high); label and wmap are unchanged
    pertur_map = np.random.uniform(low=low, high=high, size=img.shape)
    # Clip before casting so that values above 255 saturate instead of wrapping around
    img_en = np.clip(np.round(img * pertur_map), a_min=0, a_max=255).astype(np.uint8)
    return img_en, label, wmap


def test_imshow(img, idx, input_size=572, output_size=388, test_dir=None, margin=5, white=(255, 255, 255), thickness=2):
    border_size = int((input_size - output_size) * 0.5)
    img_pad = cv2.copyMakeBorder(img, border_size, border_size, border_size, border_size, cv2.BORDER_REFLECT_101)

    # Draw the boundary lines between the original image and the mirrored border
    img_pad = cv2.line(img=img_pad,
                       pt1=(0, border_size),
                       pt2=(img_pad.shape[1] - 1, border_size),
                       color=white,
                       thickness=thickness)
    img_pad = cv2.line(img=img_pad,
                       pt1=(0, img_pad.shape[0] - 1 - border_size),
                       pt2=(img_pad.shape[1] - 1, img_pad.shape[0] - 1 - border_size),
                       color=white,
                       thickness=thickness)
    img_pad = cv2.line(img=img_pad,
                       pt1=(border_size, 0),
                       pt2=(border_size, img_pad.shape[0] - 1),
                       color=white,
                       thickness=thickness)
    img_pad = cv2.line(img=img_pad,
                       pt1=(img_pad.shape[1] - 1 - border_size, 0),
                       pt2=(img_pad.shape[1] - 1 - border_size, img_pad.shape[0] - 1),
                       color=white,
                       thickness=thickness)

    # Crop 4 corners from the padded image
    img0 = img_pad[:input_size, :input_size]
    img1 = img_pad[:input_size, -input_size:]
    img2 = img_pad[-input_size:, :input_size]
    img3 = img_pad[-input_size:, -input_size:]

    canvas = np.zeros((2*input_size+3*margin, 2*input_size+img_pad.shape[1]+4*margin), dtype=np.uint8)
    canvas[margin:margin+img_pad.shape[0], margin:margin+img_pad.shape[1]] = img_pad
    canvas[margin:margin+input_size, 2*margin+img_pad.shape[1]:2*margin+img_pad.shape[1]+input_size] = img0
    canvas[margin:margin+input_size, 3*margin+img_pad.shape[1]+input_size:3*margin+img_pad.shape[1]+2*input_size] = img1
    canvas[2*margin+input_size:2*margin+2*input_size, 2*margin+img_pad.shape[1]:2*margin+img_pad.shape[1]+input_size] = img2
    canvas[2*margin+input_size:2*margin+2*input_size, 3*margin+img_pad.shape[1]+input_size:3*margin+img_pad.shape[1]+2*input_size] = img3

    cv2.imwrite(os.path.join(test_dir, 'GT_' + str(idx).zfill(2) + '.png'), canvas)


def test_data_cropping(img, input_size, output_size, num_blocks=4):
    x_batchs = np.zeros((num_blocks, input_size, input_size), dtype=np.float32)

    border_size = int((input_size - output_size) * 0.5)
    img_pad = cv2.copyMakeBorder(img, border_size, border_size, border_size, border_size, cv2.BORDER_REFLECT_101)

    # Crop 4 corners from the padded image
    x_batchs[0] = img_pad[:input_size, :input_size]
    x_batchs[1] = img_pad[:input_size, -input_size:]
    x_batchs[2] = img_pad[-input_size:, :input_size]
    x_batchs[3] = img_pad[-input_size:, -input_size:]
    return x_batchs


def cropping(img, label, wmap, input_size, output_size, is_extend=False):
    border_size = int((input_size - output_size) * 0.5)
    # Random top-left corner of the output window ("+ 1" keeps the last valid position reachable)
    rand_pos_h = np.random.randint(low=0, high=img.shape[0] - output_size + 1)
    rand_pos_w = np.random.randint(low=0, high=img.shape[1] - output_size + 1)

    # Cutting the input tile from the mirror-padded image at the same position centers it
    # on the output window, with border_size extra pixels of context on every side
    img_pad = cv2.copyMakeBorder(img, border_size, border_size, border_size, border_size, cv2.BORDER_REFLECT_101)
    img_crop = img_pad[rand_pos_h:rand_pos_h+input_size, rand_pos_w:rand_pos_w+input_size].copy()
    label_crop = label[rand_pos_h:rand_pos_h+output_size, rand_pos_w:rand_pos_w+output_size].copy()
    wmap_crop = wmap[rand_pos_h:rand_pos_h+output_size, rand_pos_w:rand_pos_w+output_size].copy()

    if is_extend:
        return img_crop, label_crop, wmap_crop, img_pad, rand_pos_h, rand_pos_w
    else:
        return img_crop, label_crop, wmap_crop


def merge_rotated_preds(preds, img, iter_time, start, stop, num, test_dir, margin=5, alpha=0.6, is_save=False):
    inv_preds = np.zeros((num, *preds[0].shape), dtype=np.float32)

    # Rotate every prediction back to the original orientation
    for i, angle in enumerate(np.linspace(start=start, stop=stop, num=num, endpoint=False)):
        pred = preds[i].copy()
        inv_angle = 360. - angle
        inv_preds[i] = rotate(input=pred, angle=inv_angle, axes=(0, 1), reshape=False, order=0, mode='constant', cval=0.)

    # Sum the class scores of all orientations, then pick the arg-max class per pixel
    y_pred = np.zeros_like(inv_preds[0])
    for i in range(num):
        y_pred += inv_preds[i]

    y_pred_cls = np.uint8(np.argmax(y_pred, axis=2) * 255.)

    if is_save:
        h, w = preds[0].shape[0], preds[0].shape[1]
        canvas = np.zeros((3*margin+2*h, (num+2)*margin+(num+1)*w), dtype=np.uint8)

        for i in range(num):
            canvas[margin:margin+h, (i+1)*margin+i*w:(i+1)*margin+(i+1)*w] = np.uint8(np.argmax(preds[i], axis=2) * 255.)
            canvas[2*margin+h:2*margin+2*h, (i+1)*margin+i*w:(i+1)*margin+(i+1)*w] = np.uint8(np.argmax(inv_preds[i], axis=2) * 255.)

        canvas[2*margin+h:2*margin+2*h, (num+1)*margin+num*w:(num+1)*margin+(num+1)*w] = y_pred_cls
        cv2.imwrite(os.path.join(test_dir, 'Merge_' + str(iter_time).zfill(2) + '.png'), canvas)

        # Pseudo-color representation
        pseudo_label = pseudoColor(y_pred_cls)
        beta = 1. - alpha

        img_3c = np.dstack((img, img, img))
        overlap = cv2.addWeighted(src1=img_3c,
                                  alpha=alpha,
                                  src2=pseudo_label,
                                  beta=beta,
                                  gamma=0.0)

        canvas01 = np.hstack((img_3c, overlap))
        cv2.imwrite(os.path.join(test_dir, 'Pred1_' + str(iter_time).zfill(2) + '.png'), canvas01)

        canvas02 = np.hstack((img, y_pred_cls))
        cv2.imwrite(os.path.join(test_dir, 'Pred2_' + str(iter_time).zfill(2) + '.png'), canvas02)

    return y_pred_cls


def merge_preds(preds, idx, ori_size=512, output_size=388, angle=0., test_dir=None, is_save=False, margin=5):
    result = np.zeros((ori_size, ori_size, 2), dtype=np.float32)

    # Paste the four corner predictions; scores are summed where the tiles overlap
    result[:output_size, :output_size] += preds[0]
    result[:output_size, -output_size:] += preds[1]
    result[-output_size:, :output_size] += preds[2]
    result[-output_size:, -output_size:] += preds[3]

    if is_save:
        # Score to class
        result_cls = np.argmax(result, axis=2)

        canvas = np.zeros((2*output_size+3*margin, 2*output_size+ori_size+4*margin), dtype=np.uint8)

        # Score to prediction class
        preds_cls = np.argmax(preds, axis=3)

        # Copy five results
        canvas[margin:margin+output_size, margin:margin+output_size] = np.uint8(preds_cls[0] * 255.)
        canvas[margin:margin+output_size, 2*margin+output_size:2*margin+2*output_size] = np.uint8(preds_cls[1] * 255.)
        canvas[2*margin+output_size:2*margin+2*output_size, margin:margin+output_size] = np.uint8(preds_cls[2] * 255.)
        canvas[2*margin+output_size:2*margin+2*output_size, 2*margin+output_size:2*margin+2*output_size] = np.uint8(preds_cls[3] * 255.)
        canvas[margin:margin+ori_size, 3*margin+2*output_size:3*margin+2*output_size+ori_size] = np.uint8(result_cls * 255.)
        # Save results
        cv2.imwrite(os.path.join(test_dir, 'Merge_' + str(idx).zfill(2) + '_' + str(int(angle)).zfill(3) + '.png'), canvas)

    return result


def pre_bilaterFilter(img, d=3, sigmaColor=75, simgaSpace=75):
    # Edge-preserving smoothing with OpenCV's bilateral filter
    pre_img = cv2.bilateralFilter(src=img, d=d, sigmaColor=sigmaColor, sigmaSpace=simgaSpace)
    return pre_img


def acc_measure(true_arr, pred_arr):
    # Pixel accuracy from a 2-class confusion matrix: correctly classified pixels / all pixels
    cm = confusion_matrix(true_arr, pred_arr)
    acc = 1. * (cm[0, 0] + cm[1, 1]) / np.sum(cm)
    return acc
--------------------------------------------------------------------------------
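The overlap-tile geometry used by `cropping`, `test_data_cropping`, and `merge_preds` (572 x 572 input tiles, 388 x 388 outputs, 92-pixel mirrored border) can be checked without TensorFlow or OpenCV. The following is a minimal NumPy-only sketch under stated assumptions: `corner_tiles`, `fake_unet`, and `merge_corners` are hypothetical helpers, `np.pad(..., mode='reflect')` stands in for `cv2.BORDER_REFLECT_101` (both mirror without repeating the edge pixel), and the "network" simply returns the central 388 x 388 window of each tile to imitate the shrinkage caused by unpadded convolutions. If the geometry is right, the four output tiles cover the 512 x 512 image completely and pasting them back reproduces it.

```python
import numpy as np

# Hypothetical helpers illustrating the 572 -> 388 overlap-tile geometry on a 512 x 512 image.
INPUT_SIZE, OUTPUT_SIZE = 572, 388


def corner_tiles(img, input_size=INPUT_SIZE, output_size=OUTPUT_SIZE):
    """Mirror-pad img and cut the four corner input tiles, as test_data_cropping does."""
    border = (input_size - output_size) // 2  # 92 pixels per side
    # NumPy 'reflect' mirrors without repeating the edge pixel, like cv2.BORDER_REFLECT_101
    pad = np.pad(img, border, mode='reflect')
    return np.stack([pad[:input_size, :input_size],
                     pad[:input_size, -input_size:],
                     pad[-input_size:, :input_size],
                     pad[-input_size:, -input_size:]])


def fake_unet(tile, output_size=OUTPUT_SIZE):
    """Hypothetical stand-in for the network: keep only the central output window."""
    b = (tile.shape[0] - output_size) // 2
    return tile[b:b + output_size, b:b + output_size]


def merge_corners(outs, ori_size=512, output_size=OUTPUT_SIZE):
    """Paste the four output tiles into the image corners and average the overlaps."""
    acc = np.zeros((ori_size, ori_size), dtype=np.float64)
    cnt = np.zeros_like(acc)
    first, last = slice(None, output_size), slice(-output_size, None)
    for out, (rows, cols) in zip(outs, [(first, first), (first, last), (last, first), (last, last)]):
        acc[rows, cols] += out
        cnt[rows, cols] += 1
    return acc / cnt  # every pixel is covered at least once, because 2 * 388 >= 512


img = np.arange(512 * 512, dtype=np.float64).reshape(512, 512)
tiles = corner_tiles(img)
merged = merge_corners([fake_unet(t) for t in tiles])
print(tiles.shape, np.array_equal(merged, img))  # prints: (4, 572, 572) True
```

Unlike `merge_preds`, which sums the class scores where tiles overlap, this sketch divides by a per-pixel tile count. For the arg-max class decision the two are equivalent: dividing both class scores at a pixel by the same positive count does not change which one is larger.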