├── .gitignore
├── LICENSE
├── Net.py
├── README.md
├── check_tfrecords.py
├── doc
│   ├── comparision.png
│   ├── net.png
│   └── schematic.png
├── fn_backbone.py
├── fn_head.py
├── fn_loss.py
├── main.py
├── postprocessing.py
├── predict.py
├── prepare_U2OScell_tfrecords.py
├── prepare_cvppp_tfrecords.py
├── preprocess.py
├── train_cell.sh
├── train_leaf.sh
└── utils
    ├── __init__.py
    ├── center.py
    ├── data_dep
    │   ├── evaluation.py
    │   ├── retrieval_dir.py
    │   └── visulize.py
    ├── evaluation.py
    ├── img_io.py
    ├── process.py
    ├── tfrecord_creation.py
    ├── tfrecord_parse.py
    ├── tfrecord_type.py
    └── tfrecords_convert.py

-------------------------------------------------------------------------------- /.gitignore: --------------------------------------------------------------------------------

1 | __pycache__
2 | utils/__pycache__
3 | tfrecords_check
4 | tfrecords
5 | model*
6 | test
7 | backup

-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------

1 | MIT License
2 | 
3 | Copyright (c) 2019 long.chen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /Net.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import time 4 | from datetime import datetime 5 | import tensorflow as tf 6 | from fn_loss import build_embedding_loss, build_dist_loss 7 | from fn_head import build_embedding_head, build_dist_head 8 | from fn_backbone import build_doubleHead 9 | from preprocess import extract_fn 10 | 11 | import sys 12 | import numpy as np 13 | import cv2 14 | 15 | MAX_IMAGE_SUMMARY = 1 16 | 17 | class LocalDisNet(object): 18 | 19 | def __init__(self, sess, flags): 20 | self.sess = sess 21 | self.backbone_fn = build_doubleHead 22 | 23 | self.flags = flags 24 | self.dtype = tf.float32 25 | 26 | self.checkpoint_dir = os.path.join(self.flags.model_dir, "checkpoint") 27 | self.summary_dir = os.path.join(self.flags.model_dir, "summary") 28 | 29 | self.image_w = 512 30 | self.image_h = 512 31 | 32 | def build_test(self): 33 | self.input = tf.placeholder(tf.float32, 34 | (None, self.image_h, self.image_w, self.flags.image_channels)) 35 | img_normalized = tf.image.per_image_standardization(self.input) 36 | features1, features2 = self.backbone_fn(inputs=img_normalized) 37 | self.embedding = build_embedding_head(features1, self.flags.embedding_dim) 38 | print("embedding branch built.") 39 | if self.flags.dist_branch: 40 | self.dist = build_dist_head(features2) 41 | print("distance regression branch built.") 42 | self.saver = tf.train.Saver(max_to_keep=2, name='checkpoint') 43 | 44 | def train(self, batch_size, training_epoches, train_dir, val_dir=None): 45 | 46 | ###################### 47 | #### prepare data #### 48 | ###################### 49 | 50 | preprocess_f = lambda sample: extract_fn(sample, 51 | image_channels=self.flags.image_channels, 52 | image_depth=self.flags.image_depth, 53 | dist_map=self.flags.dist_branch) 54 | # config training dataset 55 | train_tf = [os.path.join(train_dir, f) for f in os.listdir(train_dir)] 56 | train_ds = tf.data.TFRecordDataset(train_tf) 57 | train_ds = train_ds.map(preprocess_f) 58 | train_ds = train_ds.shuffle(buffer_size=100) 59 | train_ds = train_ds.repeat(training_epoches) 60 | train_ds = train_ds.batch(batch_size) 61 | train_iterator = train_ds.make_one_shot_iterator() 62 | train_handle = self.sess.run(train_iterator.string_handle()) 63 | # config validation dataset 64 | if val_dir is not None: 65 | val_tf = [os.path.join(val_dir, f) for f in os.listdir(val_dir)] 66 | val_ds = tf.data.TFRecordDataset(val_tf) 67 | val_ds = val_ds.map(preprocess_f) 68 | val_ds = val_ds.batch(batch_size) 69 | val_iterator = val_ds.make_initializable_iterator() 70 | val_handle = self.sess.run(val_iterator.string_handle()) 71 | # make iterator 72 | handle = tf.placeholder(tf.string, shape=[]) 73 | iterator = tf.data.Iterator.from_string_handle( 74 | handle, train_ds.output_types, train_ds.output_shapes) 75 | sample = iterator.get_next() 76 | 77 | ######################################## 78 | #### build the network and training #### 79 | ######################################## 80 | 81 | # prepare aux and summary training data 82 | self._make_aux() 83 | img_normalized = tf.image.per_image_standardization(sample['image/image']) 84 | tf.summary.image('input_image', img_normalized, max_outputs=MAX_IMAGE_SUMMARY) 85 | tf.summary.image('ground_truth', tf.cast(sample['image/label'] * 10, dtype=tf.uint8), max_outputs=MAX_IMAGE_SUMMARY) 86 | if self.flags.dist_branch: 87 | 
tf.summary.image('distance_map', sample['image/dist_map']*255, max_outputs=MAX_IMAGE_SUMMARY) 88 | features1, features2 = self.backbone_fn(inputs=img_normalized) 89 | # build embedding branch 90 | embedding = build_embedding_head(features1, self.flags.embedding_dim) 91 | embedding_loss = build_embedding_loss(embedding, sample['image/label'], sample['image/neighbor'], include_bg=self.flags.include_bg) 92 | tf.summary.scalar('loss_embedding', embedding_loss) 93 | tf.summary.image('emb_dim1-3', embedding[:, :, :, 0:3], max_outputs=MAX_IMAGE_SUMMARY) 94 | # build distance regression branch 95 | if self.flags.dist_branch: 96 | dist = build_dist_head(features2) 97 | dist_loss = build_dist_loss(dist, sample['image/dist_map']) 98 | train_loss = embedding_loss + dist_loss 99 | tf.summary.scalar('loss_dist', dist_loss) 100 | tf.summary.image('output_dist', dist, max_outputs=MAX_IMAGE_SUMMARY) 101 | else: 102 | train_loss = embedding_loss 103 | tf.summary.scalar('loss', train_loss) 104 | 105 | # build optimizer 106 | global_step = tf.Variable(0, trainable=False) 107 | lr = tf.train.exponential_decay(self.flags.lr, global_step, 5000, 0.9, staircase=True) 108 | tf.summary.scalar('lr', lr) 109 | opt = tf.train.AdamOptimizer(learning_rate=lr).minimize(train_loss, global_step=global_step, name="opt") 110 | 111 | # summary and checkpoint 112 | summary = tf.summary.merge_all() 113 | train_writer = tf.summary.FileWriter( 114 | os.path.join(self.summary_dir, 'train'), graph=self.sess.graph) 115 | if val_dir is not None: 116 | val_writer = tf.summary.FileWriter( 117 | os.path.join(self.summary_dir, 'val'), graph=self.sess.graph) 118 | summary_proto = tf.Summary() 119 | 120 | ######################## 121 | #### start training #### 122 | ######################## 123 | 124 | self.saver = tf.train.Saver(max_to_keep=5, name='checkpoint') 125 | t_step = self.restore_weights() 126 | if t_step <= 0: 127 | self.sess.run(tf.global_variables_initializer()) 128 | logging.info("{}: Init new training".format(datetime.now())) 129 | try: 130 | t_time = time.time() 131 | while True: 132 | t_step = t_step + 1 133 | if t_step % self.flags.summary_steps == 0 or t_step == 1: 134 | loss, _, c_summary = self.sess.run([train_loss, opt, summary], feed_dict={handle: train_handle}) 135 | train_writer.add_summary(c_summary, t_step) 136 | time_periter = (time.time() - t_time) / self.flags.summary_steps 137 | logging.info("{}: Iteration_{} ({:.4f}s/iter)".format(datetime.now(), t_step, time_periter)) 138 | t_time = time.time() 139 | else: 140 | loss, _ = self.sess.run([train_loss, opt], feed_dict={handle: train_handle}) 141 | logging.info("Training step {} loss: {}".format(t_step, loss)) 142 | 143 | # save checkpoint 144 | if t_step % self.flags.save_steps == 0: 145 | self.saver.save(self.sess, os.path.join(self.checkpoint_dir, 'model'), 146 | global_step=t_step) 147 | logging.info("{}: Iteration_{} Saved checkpoint".format(datetime.now(), t_step)) 148 | 149 | if val_dir is not None and t_step % self.flags.validation_steps == 0: 150 | v_step = 0 151 | self.sess.run(val_iterator.initializer) 152 | losses = [] 153 | while True: 154 | v_step = v_step + 1 155 | try: 156 | l = self.sess.run([train_loss], feed_dict={handle: val_handle}) 157 | losses.append(l) 158 | logging.info("Validation step {} loss: {}".format(v_step, l)) 159 | except Exception as e: 160 | val_summary = tf.Summary(value=[ 161 | tf.Summary.Value(tag="loss_val", simple_value=np.mean(losses))]) 162 | val_writer.add_summary(val_summary, t_step) 163 | break 164 | 165 | except 
Exception as e:
166 |             logging.info(e)
167 |             logging.info("{}: Done training".format(datetime.now()))
168 | 
169 |     def restore_model(self, ckp_dir=None):
170 |         self.build_test()
171 |         return self.restore_weights(ckp_dir)
172 | 
173 |     def restore_weights(self, ckp_dir=None):
174 |         if ckp_dir is None:
175 |             ckp_dir = self.checkpoint_dir
176 |         latest_checkpoint = tf.train.latest_checkpoint(ckp_dir)
177 |         if latest_checkpoint:
178 |             step_num = int(os.path.basename(latest_checkpoint).split("-")[1])
179 |             assert step_num > 0, "Please ensure checkpoint format is model-*.*."
180 |             self.saver.restore(self.sess, latest_checkpoint)
181 |             logging.info("{}: Restore model from step {}. Loaded checkpoint {}"
182 |                          .format(datetime.now(), step_num, latest_checkpoint))
183 |             return step_num
184 |         else:
185 |             return 0
186 | 
187 |     def _make_aux(self):
188 |         if not os.path.exists(self.summary_dir):
189 |             os.makedirs(self.summary_dir)
190 |         if not os.path.exists(self.checkpoint_dir):
191 |             os.makedirs(self.checkpoint_dir)
192 | 
193 |         log_file = self.flags.model_dir + "/log.log"
194 |         logging.basicConfig(format='%(asctime)s [%(levelname)s] %(message)s',
195 |                             filename=log_file,
196 |                             level=logging.DEBUG,
197 |                             filemode='w')
198 |         logging.getLogger().addHandler(logging.StreamHandler())
199 | 
200 | 
201 |     def segment_from_seed(self, imgs, seed_thres=0.5, similarity_thres=0.7, resize=True):
202 | 
203 |         '''
204 |         imgs: list of images (numpy array)
205 |         seed_thres / similarity_thres: thresholds for seed detection and for the region growing
206 |         resize: resize segments back to the original size of imgs
207 |         '''
208 |         import postprocessing as pp
209 |         from skimage.filters import gaussian
210 | 
211 |         imgs_input = []
212 |         for i in range(len(imgs)):
213 |             img = np.squeeze(imgs[i])
214 |             if img.shape[0:2] != (self.image_h, self.image_w):
215 |                 imgs_input.append(cv2.resize(img, (self.image_w, self.image_h)))  # cv2.resize expects dsize as (width, height)
216 |             else:
217 |                 imgs_input.append(img)
218 |         imgs_input = np.array(imgs_input)
219 | 
220 |         if len(imgs_input.shape) == 3:
221 |             imgs_input = np.expand_dims(imgs_input, axis=-1)
222 | 
223 |         embs, dist = self.sess.run([self.embedding, self.dist], feed_dict={self.input: imgs_input})
224 | 
225 |         segs = []
226 |         for i in range(len(embs)):
227 |             # get seeds (work on a per-image copy so the batch output 'dist' is not overwritten)
228 |             dist_i = np.squeeze(gaussian(dist[i], sigma=3))
229 |             seeds = pp.get_seeds(dist_i, thres=seed_thres)
230 |             # seed to instance mask
231 |             emb = pp.smooth_emb(embs[i], radius=3)
232 |             # emb = embs[i]
233 |             seg = pp.mask_from_seeds(emb, seeds, similarity_thres=similarity_thres)
234 |             # remove noise
235 |             seg = pp.remove_noise(seg, dist_i, min_size=10, min_intensity=0.1)
236 |             segs.append(seg)
237 | 
238 |         if resize:
239 |             for i in range(len(segs)):
240 |                 segs[i] = cv2.resize(segs[i], (imgs[i].shape[1], imgs[i].shape[0]), interpolation=cv2.INTER_NEAREST)  # dsize is (width, height)
241 | 
242 |         return segs
243 | 
244 | 
245 | 

-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

1 | # Instance Segmentation with Pixel Embeddings
2 | 
3 | [Institute of Imaging & Computer Vision, RWTH Aachen University](https://www.lfb.rwth-aachen.de/en/)
4 | 
5 | This repository (InstSegv1) contains the implementation of the instance segmentation approach described in the paper:
6 | 
7 | - Long Chen, Martin Strauch and Dorit Merhof.
8 | Instance Segmentation of Biomedical Images with an Object-Aware Embedding Learned with Local Constraints \[[Paper](https://www.researchgate.net/publication/336396370_Instance_Segmentation_of_Biomedical_Images_with_an_Object-Aware_Embedding_Learned_with_Local_Constraints)\] 9 | International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI) 2019. 10 | 11 | Please [cite the paper(s)](#how-to-cite) if you are using this code in your research. 12 | 13 | Check [instSeg](https://github.com/looooongChen/instSeg) for reconstructed code (tensorflow 2.x) and improved work. 14 | 15 | ## Overview: 16 | 17 |
![schematic of instance segmentation with pixel embeddings](./doc/schematic.png)
18 | 19 |

20 | 
21 | To decouple the embedding branch and the distance regression branch, we construct a decoder path for each. Compared to the structure in the paper (two branches on the same feature map), training with decoupled decoder paths is more robust (see the sketch below).
22 | 
![network with a shared encoder and two decoder paths](./doc/net.png)
24 | 25 |

26 | 27 | # Prerequisites 28 | 29 | ## Python dependencies: 30 | 31 | - tensorflow (programmed with tf 1.14) 32 | - scikit-image 33 | - opencv 34 | 35 | ## Data dependencies: 36 | 37 | In the paper, we tested the network with two datasets: 38 | - BBBC006: a hunman U2OS cell dataset [bbbc006](https://data.broadinstitute.org/bbbc/BBBC006/) 39 | - CVPPP: leaf segmentation dataset [cvppp2017](https://competitions.codalab.org/competitions/18405#learn_the_details) 40 | 41 | 42 | # Usage 43 | 44 | ## convert your dataset to tfrecords 45 | 46 | You can use the function ```def create_tf_record()``` in ```utils/tfrecords_convert.py``` to convert your dataset to tfrecords. Images will be resized to the same, distance map and neighbor relationship will be computed and saved in tfrecord files. 47 | 48 | The function requirs two python dictionary ```image_dict``` and ```gt_dict``` as inputs. Dictionary values are the path of input images and label images, respectively. Dictionary keys are only used to determine which label image corresponds to which input image, so any kind of identifier can be used as the key. 49 | 50 | Other arguments: 51 | - neighbor_distance_in_percent: [0, 1], the distance to determine neighborhood, in percentage of image width 52 | - resize: to form a training batch, images should be resized to the same 53 | - dist_map: boolen, compute distance map or not 54 | - gt_type: 'label' for label map / 'indexed' for indexed png 55 | 56 | Example scripts for converting tfrecords are provided: 57 | - BBBC006 dataset: ```prepare_U2OScell_tfrecords.py``` 58 | - CVPPP2017 dataset: ```prepare_cvppp_tfrecords.py``` 59 | 60 | Note: 61 | - label images are saved as uint16: objects in one image should not more than 216 - 1 = 65535 (0 is reserved for background) 62 | - the distance map is normalized per object and saved as uint8 63 | 64 | ## train the model 65 | 66 | To train your own model, run: 67 | ``` python train.py --phase=train ``` 68 | 69 | other options of ```main.py``` are provided to config the training, refer to ```main.py``` for details. 70 | 71 | ## prediction 72 | 73 | ``` python train.py --phase=test --test_dir=../.. --test_res=../..``` 74 | 75 | Images in ```test_dir``` will be read and segmented, with the segmentation mask saved under ```test_res```. Segmentations will be also saved as indexed png for visualization purpose. 76 | 77 | ## evaluation 78 | ``` python train.py --phase=evaluation ``` 79 | Since the file structure varies from dataset to dataset, we only provide a ```Evaluator``` class, which can report precision and recall under different IoU/F-score. You can use it to implement your own evaluation easily. Refer to ```utils/evaluation.py``` for details. 80 | 81 | # Results 82 | 83 | ## Comparision with other methods 84 |
81 | # Results
82 | 
83 | ## Comparison with other methods
84 | ![comparison with other methods](./doc/comparision.png)
85 | 86 |

87 | 88 | 89 | 90 | ## How to cite 91 | ``` 92 | 93 | @inproceedings{LongMACCAIInstance, 94 | author = {Long Chen, Martin Strauch, Dorit Merhof}, 95 | title = {Instance Segmentation of Biomedical Images with an Object-Aware Embedding Learned with Local Constraints}, 96 | booktitle = {MICCAI 2019}, 97 | year = {2019}, 98 | } 99 | 100 | ``` -------------------------------------------------------------------------------- /check_tfrecords.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import cv2 4 | import os 5 | import glob 6 | import time 7 | import shutil 8 | from preprocess import extract_fn 9 | from utils.img_io import save_indexed_png 10 | 11 | dist_map_included = True 12 | 13 | # dataset_dir = "./tfrecords/U2OScell/train" 14 | # image_channels = 1 15 | # image_depth = 'uint16' 16 | 17 | dataset_dir = "./tfrecords/CVPPP2017/train" 18 | image_channels = 3 19 | image_depth = 'uint8' 20 | 21 | test_dir = "./tfrecords_check" 22 | 23 | if os.path.exists(test_dir): 24 | shutil.rmtree(test_dir) 25 | os.mkdir(test_dir) 26 | time.sleep(1) 27 | 28 | tfrecords = [os.path.join(dataset_dir, f) 29 | for f in os.listdir(dataset_dir) if os.path.isfile(os.path.join(dataset_dir, f))] 30 | dataset = tf.data.TFRecordDataset(tfrecords) 31 | dataset = dataset.map(lambda x: extract_fn(x, image_channels=image_channels, image_depth=image_depth, dist_map=dist_map_included)) 32 | iterator = dataset.make_one_shot_iterator() 33 | next_element = iterator.get_next() 34 | 35 | with tf.Session() as sess: 36 | for i in range(100): 37 | sample = sess.run(next_element) 38 | print(sample['image/filename'].decode("utf-8")+": height {}, width {}".format(sample['image/height'], sample['image/width'])) 39 | print("objects in total: {}".format(sample['image/obj_count'])) 40 | 41 | cv2.imwrite(os.path.join(test_dir, 'image'+str(i)+'.tif'), sample['image/image']) 42 | save_indexed_png(os.path.join(test_dir, 'label'+str(i)+'.png'), sample['image/label'].astype(np.uint8)) 43 | if dist_map_included: 44 | cv2.imwrite(os.path.join(test_dir, 'dist'+str(i)+'.png'), sample['image/dist_map']*255) 45 | 46 | # print(sample['image/neighbor'][:,0:10]) 47 | -------------------------------------------------------------------------------- /doc/comparision.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/looooongChen/instance_segmentation_with_pixel_embeddings/113683182342db8233bd883a6a4ee33b870e06f4/doc/comparision.png -------------------------------------------------------------------------------- /doc/net.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/looooongChen/instance_segmentation_with_pixel_embeddings/113683182342db8233bd883a6a4ee33b870e06f4/doc/net.png -------------------------------------------------------------------------------- /doc/schematic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/looooongChen/instance_segmentation_with_pixel_embeddings/113683182342db8233bd883a6a4ee33b870e06f4/doc/schematic.png -------------------------------------------------------------------------------- /fn_backbone.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | DROP_RATE=0.5 4 | 5 | def build_doubleHead(inputs, features=32, drop_rate=0.2, name="doubleHead_UNet"): 6 | 7 | with 
tf.variable_scope(name): 8 | with tf.variable_scope("Conv1"): 9 | conv1 = tf.keras.layers.Conv2D(features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(inputs) 10 | conv1 = tf.keras.layers.Conv2D(features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv1) 11 | pool1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv1) 12 | with tf.variable_scope("Conv2"): 13 | conv2 = tf.keras.layers.Conv2D(2*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(pool1) 14 | conv2 = tf.keras.layers.Conv2D(2*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv2) 15 | pool2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv2) 16 | with tf.variable_scope("Conv3"): 17 | conv3 = tf.keras.layers.Conv2D(4*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(pool2) 18 | conv3 = tf.keras.layers.Conv2D(4*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv3) 19 | pool3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(conv3) 20 | with tf.variable_scope("Conv4"): 21 | conv4 = tf.keras.layers.Conv2D(8*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(pool3) 22 | conv4 = tf.keras.layers.Conv2D(8*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv4) 23 | drop4 = tf.keras.layers.Dropout(drop_rate)(conv4) 24 | pool4 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(drop4) 25 | with tf.variable_scope("Conv5"): 26 | conv5 = tf.keras.layers.Conv2D(16*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(pool4) 27 | conv5 = tf.keras.layers.Conv2D(16*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv5) 28 | drop5 = tf.keras.layers.Dropout(drop_rate)(conv5) 29 | 30 | with tf.variable_scope("Conv6_1"): 31 | up6_1 = tf.keras.layers.Conv2D(8*features, 2, activation='relu',padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')( 32 | tf.keras.layers.UpSampling2D(size=(2, 2))(drop5)) 33 | merge6_1 = tf.keras.layers.concatenate([drop4, up6_1], axis=3) 34 | conv6_1 = tf.keras.layers.Conv2D(8*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(merge6_1) 35 | conv6_1 = tf.keras.layers.Conv2D(8*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv6_1) 36 | with tf.variable_scope("Conv6_2"): 37 | up6_2 = tf.keras.layers.Conv2D(8*features, 2, activation='relu',padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')( 38 | tf.keras.layers.UpSampling2D(size=(2, 2))(drop5)) 39 | merge6_2 = tf.keras.layers.concatenate([drop4, up6_2], axis=3) 40 | conv6_2 = tf.keras.layers.Conv2D(8*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(merge6_2) 41 | conv6_2 = tf.keras.layers.Conv2D(8*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', 
bias_initializer='zeros')(conv6_2) 42 | 43 | 44 | with tf.variable_scope("Conv7_1"): 45 | up7_1 = tf.keras.layers.Conv2D(4*features, 2, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')( 46 | tf.keras.layers.UpSampling2D(size=(2, 2))(conv6_1)) 47 | merge7_1 = tf.keras.layers.concatenate([conv3, up7_1], axis=3) 48 | conv7_1 = tf.keras.layers.Conv2D(4*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(merge7_1) 49 | conv7_1 = tf.keras.layers.Conv2D(4*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv7_1) 50 | with tf.variable_scope("Conv7_2"): 51 | up7_2 = tf.keras.layers.Conv2D(4*features, 2, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')( 52 | tf.keras.layers.UpSampling2D(size=(2, 2))(conv6_2)) 53 | merge7_2 = tf.keras.layers.concatenate([conv3, up7_2], axis=3) 54 | conv7_2 = tf.keras.layers.Conv2D(4*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(merge7_2) 55 | conv7_2 = tf.keras.layers.Conv2D(4*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv7_2) 56 | 57 | with tf.variable_scope("Conv8_1"): 58 | up8_1 = tf.keras.layers.Conv2D(2*features, 2, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')( 59 | tf.keras.layers.UpSampling2D(size=(2, 2))(conv7_1)) 60 | merge8_1 = tf.keras.layers.concatenate([conv2, up8_1], axis=3) 61 | conv8_1 = tf.keras.layers.Conv2D(2*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(merge8_1) 62 | conv8_1 = tf.keras.layers.Conv2D(2*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv8_1) 63 | with tf.variable_scope("Conv8_2"): 64 | up8_2 = tf.keras.layers.Conv2D(2*features, 2, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')( 65 | tf.keras.layers.UpSampling2D(size=(2, 2))(conv7_2)) 66 | merge8_2 = tf.keras.layers.concatenate([conv2, up8_2], axis=3) 67 | conv8_2 = tf.keras.layers.Conv2D(2*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(merge8_2) 68 | conv8_2 = tf.keras.layers.Conv2D(2*features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv8_2) 69 | 70 | with tf.variable_scope("Conv9_1"): 71 | up9_1 = tf.keras.layers.Conv2D(features, 2, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')( 72 | tf.keras.layers.UpSampling2D(size=(2, 2))(conv8_1)) 73 | merge9_1 = tf.keras.layers.concatenate([conv1, up9_1], axis=3) 74 | conv9_1 = tf.keras.layers.Conv2D(features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(merge9_1) 75 | conv9_1 = tf.keras.layers.Conv2D(features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv9_1) 76 | with tf.variable_scope("Conv9_2"): 77 | up9_2 = tf.keras.layers.Conv2D(features, 2, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', 
bias_initializer='zeros')( 78 | tf.keras.layers.UpSampling2D(size=(2, 2))(conv8_2)) 79 | merge9_2 = tf.keras.layers.concatenate([conv1, up9_2], axis=3) 80 | conv9_2 = tf.keras.layers.Conv2D(features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(merge9_2) 81 | conv9_2 = tf.keras.layers.Conv2D(features, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(conv9_2) 82 | 83 | return conv9_1, conv9_2 -------------------------------------------------------------------------------- /fn_head.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def build_embedding_head(features, embedding_dim=16, name='embedding_branch'): 4 | 5 | with tf.variable_scope(name): 6 | features = tf.keras.layers.Dropout(rate=0.5)(features) 7 | emb = tf.keras.layers.Conv2D(embedding_dim, 3, activation='linear',padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(features) 8 | emb = tf.nn.l2_normalize(emb, axis=-1) 9 | 10 | return emb 11 | 12 | def build_dist_head(features, name='dist_regression'): 13 | with tf.variable_scope(name): 14 | features = tf.keras.layers.Dropout(rate=0.5)(features) 15 | dist = tf.keras.layers.Conv2D(1, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(features) 16 | 17 | return dist 18 | 19 | # def build_embedding_head(features, embedding_dim=16, name='embedding_branch'): 20 | 21 | # with tf.variable_scope(name): 22 | # e_conv = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(features) 23 | # e_conv = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(e_conv) 24 | # e_conv = tf.keras.layers.Dropout(rate=0.3)(e_conv) 25 | # emb = tf.keras.layers.Conv2D(embedding_dim, 3, activation='linear',padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(e_conv) 26 | # emb = tf.nn.l2_normalize(emb, axis=-1) 27 | 28 | # return emb 29 | 30 | # def build_dist_head(features, name='dist_regression'): 31 | # with tf.variable_scope(name): 32 | # d_conv = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(features) 33 | # d_conv = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(d_conv) 34 | # d_conv = tf.keras.layers.Dropout(rate=0.3)(d_conv) 35 | # dist = tf.keras.layers.Conv2D(1, 3, activation='relu', padding='same', use_bias=True, kernel_initializer='he_normal', bias_initializer='zeros')(d_conv) 36 | 37 | # return dist 38 | -------------------------------------------------------------------------------- /fn_loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def weight_fg(label): 4 | """ 5 | label: [B W H 1] 6 | """ 7 | pos = tf.greater(label, 0) 8 | neg = tf.equal(label, 0) 9 | num_pos = tf.count_nonzero(pos, axis=[1,2,3], keepdims=True, dtype=tf.float32) 10 | num_neg = tf.count_nonzero(neg, axis=[1,2,3], keepdims=True, dtype=tf.float32) 11 | total = num_neg + num_pos 12 | return tf.cast(pos, dtype=tf.float32)*total/(2*num_pos) \ 13 | + tf.cast(neg, 
dtype=tf.float32)*total/(2*num_neg)
14 | 
15 | def build_dist_loss(dist, dist_gt, name='dist_reg_loss'):
16 | 
17 |     with tf.variable_scope(name):
18 |         weights = weight_fg(dist_gt)
19 |         dist_gt = dist_gt * 10
20 |         loss = tf.square(dist-dist_gt)*weights
21 |         # loss = tf.square(dist-dist_gt)
22 | 
23 |     return tf.reduce_mean(loss)
24 | 
25 | def build_embedding_loss(embedding, label_map, neighbor, include_bg=True, name='emb_loss'):
26 |     """
27 |     :param embedding: [B W H C]
28 |     :param label_map: [B W H 1]
29 |     :param neighbor: neighbor list
30 |     :param include_bg: whether to take the background as an independent object
31 |     """
32 | 
33 |     with tf.variable_scope(name):
34 | 
35 |         def cond(loss, embedding, label_map, neighbor, i):
36 |             return tf.less(i, tf.shape(embedding)[0])
37 | 
38 |         def body(loss, embedding, label_map, neighbor, i):
39 |             loss_single = embedding_loss_single_example(embedding[i],
40 |                                                         label_map[i],
41 |                                                         neighbor[i],
42 |                                                         include_bg)
43 | 
44 |             loss = loss.write(i, loss_single)
45 | 
46 |             return loss, embedding, label_map, neighbor, i+1
47 | 
48 |         loss = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
49 |         loss, _, _, _, _ = tf.while_loop(cond, body, [loss, embedding, label_map, neighbor, 0])
50 | 
51 |         loss = loss.stack()
52 |         loss = tf.reduce_mean(loss)
53 |     return loss
54 | 
55 | 
56 | def embedding_loss_single_example(embedding,
57 |                                   label_map,
58 |                                   neighbor,
59 |                                   include_bg=True):
60 |     """
61 |     build embedding loss
62 |     :param embedding: 3 dim tensor, should be normalized
63 |     :param label_map: 3 dim tensor with 1 channel
64 |     :param neighbor: row N is the neighbors of object N, N starts with 1, 0 indicates the background
65 |     :param include_bg: whether to take the background as an independent object
66 |     """
67 | 
68 |     # flatten the tensors
69 |     label_flat = tf.reshape(label_map, [-1])
70 |     embedding_flat = tf.reshape(embedding, [-1, tf.shape(embedding)[-1]])
71 |     embedding_flat = tf.nn.l2_normalize(embedding_flat, axis=1)
72 |     # weight_flat = tf.reshape(weight_fg(tf.expand_dims(label_map, axis=0)), [-1, 1])
73 | 
74 |     # if the background is not included, mask out background pixels
75 |     if not include_bg:
76 |         label_mask = tf.greater(label_flat, 0)
77 |         label_flat = tf.boolean_mask(label_flat, label_mask)
78 |         embedding_flat = tf.boolean_mask(embedding_flat, label_mask)
79 |         # weight_flat = tf.boolean_mask(weight_flat, label_mask)
80 | 
81 |     # grouping based on labels
82 |     unique_labels, unique_id, counts = tf.unique_with_counts(label_flat)
83 |     counts = tf.reshape(tf.cast(counts, tf.float32), (-1, 1))
84 |     segmented_sum = tf.unsorted_segment_sum(embedding_flat, unique_id, tf.size(unique_labels))
85 |     # mean embedding of each instance
86 |     mu = tf.nn.l2_normalize(segmented_sum/counts, axis=1)
87 |     mu_expand = tf.gather(mu, unique_id)
88 | 
89 |     ##########################
90 |     #### inner class loss ####
91 |     ##########################
92 | 
93 |     loss_inner = tf.losses.cosine_distance(mu_expand, embedding_flat,
94 |                                            axis=1,
95 |                                            # weights=weight_flat,
96 |                                            reduction=tf.losses.Reduction.MEAN)
97 | 
98 |     ##########################
99 |     #### inter class loss ####
100 |     ##########################
101 | 
102 |     # repeat mu
103 |     instance_num = tf.size(unique_labels)
104 |     mu_interleave = tf.tile(mu, [instance_num, 1])
105 |     mu_rep = tf.tile(mu, [1, instance_num])
106 |     mu_rep = tf.reshape(mu_rep, (instance_num*instance_num, -1))
107 | 
108 |     # get inter loss for each pair
109 |     loss_inter = tf.losses.cosine_distance(mu_interleave, mu_rep,
110 |                                            axis=1,
111 |                                            reduction=tf.losses.Reduction.NONE)
112 |     loss_inter = tf.abs(1-loss_inter)  # cosine_distance = 1 - cos, so this is |cos| between instance mean embeddings
113 | 
114 |     # compute adjacent indicator
115 |     # indicator: bg(0) is adjacent to any object
116 |     # 0 1 1 1 1 ...
117 |     # 1 x x x x ...
118 |     # 1 x x x x ...
119 |     # ...
120 |     bg = tf.zeros([tf.shape(neighbor)[0], 1], dtype=tf.int32)
121 |     neighbor = tf.concat([bg, neighbor], axis=1)
122 |     dep = instance_num if include_bg else instance_num + 1
123 | 
124 |     adj_indicator = tf.one_hot(neighbor, depth=dep, dtype=tf.float32)
125 |     adj_indicator = tf.reduce_sum(adj_indicator, axis=1)
126 |     adj_indicator = tf.cast(adj_indicator > 0, tf.float32)
127 | 
128 |     bg_indicator = tf.one_hot(0, depth=dep, on_value=0.0, off_value=1.0, dtype=tf.float32)
129 |     bg_indicator = tf.reshape(bg_indicator, [1, -1])
130 |     indicator = tf.concat([bg_indicator, adj_indicator], axis=0)
131 | 
132 |     # reorder the rows and columns in the same order as unique_labels
133 |     # if the background (0) is not included, the first row and column will be ignored, since 0 is not in unique_labels
134 |     indicator = tf.gather(indicator, unique_labels, axis=0)
135 |     indicator = tf.gather(indicator, unique_labels, axis=1)
136 |     inter_mask = tf.reshape(indicator, [-1, 1])
137 | 
138 |     loss_inter = tf.reduce_sum(loss_inter*inter_mask)/(tf.reduce_sum(inter_mask)+1e-12)
139 | 
140 |     return loss_inner+loss_inter
141 | 

-------------------------------------------------------------------------------- /main.py: --------------------------------------------------------------------------------

1 | import os
2 | import tensorflow as tf
3 | import numpy as np
4 | from Net import LocalDisNet
5 | from skimage.io import imread, imsave
6 | from preprocess import extract_fn
7 | from utils.img_io import save_indexed_png
8 | import cv2
9 | from utils.evaluation import Evaluator
10 | 
11 | 
12 | def main(_):
13 |     tf_flags = tf.app.flags.FLAGS
14 | 
15 |     if tf_flags.phase == "train":
16 |         with tf.Session() as sess:
17 |             model = LocalDisNet(sess, tf_flags)
18 |             if tf_flags.validation:
19 |                 val_dir = os.path.join(tf_flags.val_dir)
20 |             else:
21 |                 val_dir = None
22 |             model.train(tf_flags.batch_size,
23 |                         tf_flags.training_epoches,
24 |                         os.path.join(tf_flags.train_dir),
25 |                         val_dir)
26 |     elif tf_flags.phase == 'prediction':
27 | 
28 |         if not os.path.exists(tf_flags.test_res):
29 |             os.makedirs(tf_flags.test_res)
30 | 
31 |         img_path = {f: os.path.join(tf_flags.test_dir, f) for f in os.listdir(tf_flags.test_dir)}
32 | 
33 | 
34 | 
35 | 
36 |         with tf.Session() as sess:
37 |             model = LocalDisNet(sess, tf_flags)
38 |             model.restore_model()
39 |             for f_name, f_path in img_path.items():
40 |                 img = imread(f_path)
41 |                 print("Processing: ", f_path)
42 |                 segs = model.segment_from_seed([img], seed_thres=0.7, similarity_thres=0.7, resize=True)
43 |                 save_indexed_png(os.path.join(tf_flags.test_res, os.path.splitext(f_name)[0]+'_seg.png'), segs[0].astype(np.uint8))
44 | 
45 |     elif tf_flags.phase == 'evaluation':
46 |         e = Evaluator(gt_type="mask")
47 |         # implement your own evaluation for your dataset with the Evaluator
48 |         pass
49 | 
50 | if __name__ == '__main__':
51 | 
52 |     tf.app.flags.DEFINE_string("phase", "train",
53 |                                "model phase: train/prediction/evaluation")
54 | 
55 |     # architecture config
56 |     tf.app.flags.DEFINE_boolean("dist_branch", True,
57 |                                 "whether train dist regression branch or not")
58 |     tf.app.flags.DEFINE_boolean("include_bg", True,
59 |                                 "whether include background as an independent object")
60 |     tf.app.flags.DEFINE_integer("embedding_dim", 16,
61 |                                 "dimension of the embedding")
62 | 
63 |     # training config
64 |     tf.app.flags.DEFINE_string("train_dir", "./tfrecords/U2OScell/train",
65 |                                "train dataset directory")
66 |     tf.app.flags.DEFINE_boolean("validation", True,
67 |                                 "run validation during training or not, if False, --val_dir will be ignored")
68 |     tf.app.flags.DEFINE_string("val_dir", "./tfrecords/U2OScell/val",
69 |                                "validation dataset directory")
70 |     tf.app.flags.DEFINE_string("image_depth", "uint16",
71 |                                "depth of image: uint8/uint16")
72 |     tf.app.flags.DEFINE_integer("image_channels", 3, "number of image channels")
73 |     tf.app.flags.DEFINE_string("model_dir", "./model_CVPPP2017",
74 |                                "checkpoint and summary directory.")
75 |     tf.app.flags.DEFINE_float("lr", 0.0001,
76 |                               "Learning Rate.")
77 |     tf.app.flags.DEFINE_integer("batch_size", 4,
78 |                                 "batch size for training.")
79 |     tf.app.flags.DEFINE_integer("training_epoches", 500,
80 |                                 "total training epochs.")
81 |     tf.app.flags.DEFINE_integer("summary_steps", 100,
82 |                                 "summary period.")
83 |     tf.app.flags.DEFINE_integer("save_steps", 2000,
84 |                                 "checkpoint period.")
85 |     tf.app.flags.DEFINE_integer("validation_steps", 200,
86 |                                 "validation period.")
87 | 
88 |     # test config
89 |     tf.app.flags.DEFINE_string("test_dir", "./test/cvppp_test",
90 |                                "test image directory")
91 |     tf.app.flags.DEFINE_string("test_res", "./test/cvppp_res",
92 |                                "directory to save the segmentation results")
93 | 
94 |     tf.app.run(main=main)
95 | 

-------------------------------------------------------------------------------- /postprocessing.py: --------------------------------------------------------------------------------

1 | import numpy as np
2 | from skimage.measure import regionprops, label
3 | from skimage.measure import label as assign_label
4 | from skimage.morphology import erosion as im_erosion
5 | from skimage.morphology import dilation as im_dilation
6 | from skimage.morphology import square as mor_square
7 | from skimage.feature import peak_local_max
8 | 
9 | from skimage.segmentation import slic
10 | import skimage.future.graph as gf
11 | 
12 | 
13 | def smooth_emb(emb, radius):
14 |     from scipy import ndimage
15 |     from skimage.morphology import disk
16 |     emb = emb.copy()
17 |     w = disk(radius)/np.sum(disk(radius))
18 |     for i in range(emb.shape[-1]):
19 |         emb[:, :, i] = ndimage.convolve(emb[:, :, i], w, mode='reflect')
20 |     emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)
21 |     return emb
22 | 
23 | 
24 | def get_seeds(dist_map, thres=0.7):
25 |     c = np.squeeze(dist_map)
26 |     mask = peak_local_max(c, min_distance=10, threshold_abs=thres * c.max(), indices=False)
27 |     # mask = c > thres * c.max()
28 |     return mask
29 | 
30 | 
31 | def mask_from_seeds(embedding, seeds, similarity_thres=0.7):
32 |     embedding = np.squeeze(embedding)
33 |     seeds = label(seeds)
34 |     props = regionprops(seeds)
35 | 
36 |     mean = {}
37 |     for p in props:
38 |         row, col = p.coords[:, 0], p.coords[:, 1]
39 |         emb_mean = np.mean(embedding[row, col], axis=0)
40 |         emb_mean = emb_mean/np.linalg.norm(emb_mean)
41 |         mean[p.label] = emb_mean
42 | 
43 |     while True:
44 |         dilated = im_dilation(seeds, mor_square(3))
45 | 
46 |         front_r, front_c = np.nonzero(seeds != dilated)
47 | 
48 |         similarity = [np.dot(embedding[r, c, :], mean[dilated[r, c]])
49 |                       for r, c in zip(front_r, front_c)]
50 | 
51 |         # bg = seeds[front_r, front_c] == 0
52 |         # add_ind = np.logical_and([s > similarity_thres for s in similarity], bg)
53 |         add_ind = np.array([s > similarity_thres for s in similarity])
54 | 
55 |         if np.all(add_ind == False):
56 |             break
57 | 
58 |         seeds[front_r[add_ind], front_c[add_ind]] = dilated[front_r[add_ind], front_c[add_ind]]
59 | 
60 |     return seeds
61 | 
62 | def remove_noise(l_map, d_map, min_size=10, min_intensity=0.1):
63 |     max_intensity = d_map.max()
64 |     props = regionprops(l_map, intensity_image=d_map)
65 |     for p in props:
66 |         if p.area < min_size:
67 |             l_map[l_map==p.label] = 0
68 |         if p.mean_intensity/max_intensity < min_intensity:
69 |             l_map[l_map==p.label] = 0
70 |     return label(l_map)
71 | 
72 | 
73 | 

-------------------------------------------------------------------------------- /predict.py: --------------------------------------------------------------------------------

1 | import os
2 | import tensorflow as tf
3 | import numpy as np
4 | from Net import LocalDisNet
5 | from fn_backbone import build_unet, build_unet_d7  # NOTE: legacy script that appears to predate main.py; fn_backbone.py now only provides build_doubleHead
6 | # from data_utils import extract_fn
7 | # import cv2
8 | from utils.evaluation import Evaluator  # needed by the 'test' phase below
9 | 
10 | 
11 | def main(_):
12 |     tf_flags = tf.app.flags.FLAGS
13 |     # gpu config.
14 |     # config = tf.ConfigProto()
15 |     # config.gpu_options.per_process_gpu_memory_fraction = 0.5
16 |     # config.gpu_options.allow_growth = True
17 | 
18 |     if tf_flags.architecture == 'd7':
19 |         backbone_net = build_unet_d7
20 |     else:
21 |         backbone_net = build_unet
22 | 
23 |     if tf_flags.phase == "train":
24 |         with tf.Session() as sess:
25 |             train_model = LocalDisNet(sess, backbone_net, tf_flags)
26 |             train_model.train(tf_flags.batch_size,
27 |                               tf_flags.training_epoches,
28 |                               os.path.join(tf_flags.train_dir),
29 |                               os.path.join(tf_flags.val_dir))
30 |     elif tf_flags.phase == 'test':
31 |         # import skimage as ski
32 |         from skimage import morphology, measure
33 |         from skimage.io import imsave
34 | 
35 |         if not os.path.exists(tf_flags.res_dir):
36 |             os.makedirs(tf_flags.res_dir)
37 | 
38 |         e = Evaluator()
39 | 
40 |         # load dataset from tfrecords
41 |         val_dir = os.path.join(tf_flags.dataset_dir, 'test')
42 |         val_tf = [os.path.join(val_dir, f) for f in os.listdir(val_dir)]
43 |         # build dataset
44 |         val_ds = tf.data.TFRecordDataset(val_tf)
45 |         val_ds = val_ds.map(lambda sample:
46 |                             extract_fn(sample, [512, 512], augmentation=False, return_raw=True)).batch(1)
47 |         val_iterator = val_ds.make_one_shot_iterator()
48 |         val_example = val_iterator.get_next()
49 | 
50 |         with tf.Session() as sess:
51 |             # test on an image pair.
52 | import csv 53 | csvfile = open('./res_cells.csv', 'w', newline='') 54 | csvwriter = csv.writer(csvfile, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL) 55 | 56 | model = net.DiscrimitiveNet(sess, backbone_net, tf_flags) 57 | 58 | while True: 59 | try: 60 | img, gt, dist, _, example_raw = sess.run(val_example) 61 | _, fname = os.path.split(example_raw['image/filename'][0].decode('utf-8')) 62 | 63 | # h, w = example_raw['image/height'][0], example_raw['image/width'][0] 64 | pred, emb = model._segment_emb(img, dist) 65 | 66 | aps = e.add_example(np.squeeze(pred), np.squeeze(gt)) 67 | print(['%.2f' % ap for ap in aps]) 68 | 69 | aps_b = ['%.4f' % ap for ap in aps] 70 | csvwriter.writerow([fname] + aps_b) 71 | e.save_last_as_image(os.path.join(tf_flags.res_dir, fname), 72 | img[0, :, : , :], thres=0.6, isBGR=False) 73 | 74 | # center = np.squeeze(center) 75 | # cv2.imwrite(os.path.join(tf_flags.res_dir, "c_"+fname), (center*25).astype(np.uint8)) 76 | # from scipy.io import savemat 77 | # savemat(os.path.join(tf_flags.res_dir, fname[:-3]+'mat'), dict(emb=emb)) 78 | 79 | except Exception as exc: 80 | print(exc) 81 | csvfile.close() 82 | print(exc) 83 | print(e.score()) 84 | break 85 | 86 | 87 | if __name__ == '__main__': 88 | 89 | tf.app.flags.DEFINE_string("phase", "train", 90 | "model phase: train/test/sparsity_tune.") 91 | 92 | # architecture config 93 | tf.app.flags.DEFINE_string("architecture", "d9", 94 | "architecture of the backbone network, d7/d9") 95 | tf.app.flags.DEFINE_boolean("include_bg", True, 96 | "whether include background as an independent object") 97 | tf.app.flags.DEFINE_integer("embedding_dim", 8, 98 | "dimension of the embedding") 99 | 100 | # training config 101 | tf.app.flags.DEFINE_string("train_dir", "./tfrecords/U2OScell/train", 102 | "dataset directory") 103 | tf.app.flags.DEFINE_string("val_dir", "./tfrecords/U2OScell/val", 104 | "dataset directory") 105 | tf.app.flags.DEFINE_string("image_depth", "uint8", 106 | "depth of image: uint8/uint16") 107 | tf.app.flags.DEFINE_integer("image_channels", 1, "number of image channels") 108 | tf.app.flags.DEFINE_string("model_dir", "./model_U2OScell", 109 | "checkpoint and summary directory.") 110 | tf.app.flags.DEFINE_string("checkpoint_prefix", "model", 111 | "checkpoint name for restoring.") 112 | tf.app.flags.DEFINE_float("lr", 0.0001, 113 | "Learning Rate.") 114 | tf.app.flags.DEFINE_integer("batch_size", 4, 115 | "batch size for training.") 116 | tf.app.flags.DEFINE_integer("training_epoches", 500, 117 | "total training steps.") 118 | tf.app.flags.DEFINE_integer("summary_steps", 100, 119 | "summary period.") 120 | tf.app.flags.DEFINE_integer("save_steps", 2000, 121 | "checkpoint period.") 122 | tf.app.flags.DEFINE_integer("validation_steps", 200, 123 | "validation period.") 124 | 125 | # test config 126 | tf.app.flags.DEFINE_string("res_dir", "./model_orthogonal", 127 | "result directory") 128 | tf.app.flags.DEFINE_boolean("keep_size", False, 129 | "resize to original size or not when testing") 130 | 131 | tf.app.run(main=main) 132 | -------------------------------------------------------------------------------- /prepare_U2OScell_tfrecords.py: -------------------------------------------------------------------------------- 1 | from utils.tfrecords_convert import create_tf_record 2 | import os 3 | import random 4 | 5 | img_dir = 'd:/Datasets/BBBC006_U2OScell/images' 6 | gt_dir = 'd:/Datasets/BBBC006_U2OScell/ground_truth' 7 | val_ratio = 0.2 8 | output_dir = './tfrecords/U2OScell' 9 | 10 | 
neighbor_distance_in_percent = 0.02  # neighborhood radius, as a fraction of the image width
11 | resize = (512, 512)                  # target size; images must share one size to form batches
12 | dist_map = True                      # compute and store per-object distance maps
13 | gt_type = 'label'                    # ground truth given as label maps
14 | max_neighbor = 32
15 | 
16 | assert os.path.exists(img_dir)
17 | assert os.path.exists(gt_dir)
18 | if not os.path.exists(os.path.join(output_dir, 'train')):
19 |     os.makedirs(os.path.join(output_dir, 'train'))
20 | if not os.path.exists(os.path.join(output_dir, 'val')):
21 |     os.makedirs(os.path.join(output_dir, 'val'))
22 | 
23 | img_dict = {os.path.splitext(f)[0]: os.path.join(img_dir, f) for f in os.listdir(img_dir) if os.path.isfile(os.path.join(img_dir, f))}
24 | gt_dict = {os.path.splitext(f)[0]: os.path.join(gt_dir, f) for f in os.listdir(gt_dir) if os.path.isfile(os.path.join(gt_dir, f))}
25 | 
26 | for k in list(gt_dict.keys()):  # iterate over a copy: deleting from a dict while iterating over it raises a RuntimeError in Python 3
27 |     if k not in img_dict.keys():
28 |         del gt_dict[k]
29 | 
30 | keys = list(gt_dict.keys())
31 | random.shuffle(keys)
32 | split = int((1-val_ratio) * len(keys))
33 | 
34 | img_dict_train = {k: img_dict[k] for k in keys[0:split]}
35 | gt_dict_train = {k: gt_dict[k] for k in keys[0:split]}
36 | 
37 | img_dict_val = {k: img_dict[k] for k in keys[split:]}
38 | gt_dict_val = {k: gt_dict[k] for k in keys[split:]}
39 | 
40 | # print(img_dict)
41 | 
42 | create_tf_record(img_dict_train,
43 |                  gt_dict_train,
44 |                  os.path.join(output_dir, 'train', 'train'),
45 |                  neighbor_distance_in_percent=neighbor_distance_in_percent,
46 |                  resize=resize,
47 |                  dist_map=dist_map,
48 |                  gt_type=gt_type,
49 |                  max_neighbor=max_neighbor)
50 | 
51 | create_tf_record(img_dict_val,
52 |                  gt_dict_val,
53 |                  os.path.join(output_dir, 'val', 'val'),
54 |                  neighbor_distance_in_percent=neighbor_distance_in_percent,
55 |                  resize=resize,
56 |                  dist_map=dist_map,
57 |                  gt_type=gt_type,
58 |                  max_neighbor=max_neighbor)
59 | 
60 | 

-------------------------------------------------------------------------------- /prepare_cvppp_tfrecords.py: --------------------------------------------------------------------------------

1 | from utils.tfrecords_convert import create_tf_record
2 | import os
3 | import random
4 | 
5 | image_dir = 'D:/Datasets/CVPPP2017_CodaLab/training_images'
6 | gt_dir = 'D:/Datasets/CVPPP2017_CodaLab/training_truth'
7 | examples = ['A1', 'A2', 'A3', 'A4']
8 | val_ratio = 0.2
9 | output_dir = './tfrecords/CVPPP2017_val'
10 | # val_ratio = 0
11 | # output_dir = './tfrecords/CVPPP2017'
12 | 
13 | neighbor_distance_in_percent = 0.02
14 | resize = (512, 512)
15 | dist_map = True
16 | gt_type = 'label'
17 | max_neighbor = 32
18 | 
19 | 
20 | img_dict = {}
21 | gt_dict = {}
22 | for g in examples:
23 |     for f in os.listdir(os.path.join(image_dir, g)):
24 |         b, _ = os.path.splitext(f)
25 |         img_dict[g+'_'+b] = os.path.join(image_dir, g, f)
26 |     for f in os.listdir(os.path.join(gt_dir, g)):
27 |         b, _ = os.path.splitext(f)
28 |         gt_dict[g+'_'+b] = os.path.join(gt_dir, g, f)
29 | 
30 | keys = list(img_dict.keys())
31 | random.shuffle(keys)
32 | split = int((1-val_ratio) * len(keys))
33 | 
34 | if not os.path.exists(os.path.join(output_dir, 'train')):
35 |     os.makedirs(os.path.join(output_dir, 'train'))
36 | if split == len(keys):
37 |     create_tf_record(img_dict,
38 |                      gt_dict,
39 |                      os.path.join(output_dir, 'train', 'train'),
40 |                      neighbor_distance_in_percent=neighbor_distance_in_percent,
41 |                      resize=resize,
42 |                      dist_map=dist_map,
43 |                      gt_type=gt_type,
44 |                      max_neighbor=max_neighbor)
45 | else:
46 |     img_dict_train = {k: img_dict[k] for k in keys[0:split]}
47 |     gt_dict_train = {k: gt_dict[k] for k in keys[0:split]}
48 | 
49 |     img_dict_val = {k: img_dict[k] for k in keys[split:]}
50 |     gt_dict_val = {k: gt_dict[k] for k in 
keys[split:]} 51 | 52 | create_tf_record(img_dict_train, 53 | gt_dict_train, 54 | os.path.join(output_dir, 'train', 'train'), 55 | neighbor_distance_in_percent=neighbor_distance_in_percent, 56 | resize=resize, 57 | dist_map=dist_map, 58 | gt_type=gt_type, 59 | max_neighbor=max_neighbor) 60 | 61 | if not os.path.exists(os.path.join(output_dir, 'val')): 62 | os.makedirs(os.path.join(output_dir, 'val')) 63 | create_tf_record(img_dict_val, 64 | gt_dict_val, 65 | os.path.join(output_dir, 'val', 'val'), 66 | neighbor_distance_in_percent=neighbor_distance_in_percent, 67 | resize=resize, 68 | dist_map=dist_map, 69 | gt_type=gt_type, 70 | max_neighbor=max_neighbor) 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from utils.tfrecord_parse import extract_fn_local_dis 3 | 4 | 5 | MAX_INSTANCE = 500 6 | 7 | def extract_fn(data_record, 8 | image_channels, 9 | image_depth='uint16', 10 | dist_map=False): 11 | 12 | sample = extract_fn_local_dis(data_record, image_depth=image_depth, dist_map=dist_map) 13 | 14 | sample['image/image'].set_shape([None, None, image_channels]) 15 | # sample['image/image'] = tf.image.per_image_standardization(sample['image/image']) # moved to Net.py 16 | sample['image/label'] = tf.cast(sample['image/label'], dtype=tf.int32) 17 | if dist_map: 18 | sample['image/dist_map'] = sample['image/dist_map']/255 19 | 20 | # recommend not to use tensorflow resize function !!! not !!! 21 | # if resize is not None: 22 | # img = tf.image.resize_images(img, resize, align_corners=False) 23 | # label = tf.image.resize_images(label, resize, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, align_corners=False) 24 | # if dist_map: 25 | # d_map = tf.image.resize_images(d_map, resize, align_corners=False) 26 | 27 | # pad neighbor list to the same size (to form batches, the same size is required) 28 | neighbor = sample['image/neighbor'] 29 | padding = [[0, MAX_INSTANCE - tf.shape(neighbor)[0]], [0, 0]] 30 | sample['image/neighbor'] = tf.pad(neighbor, padding, 'CONSTANT', constant_values=0) 31 | 32 | return sample 33 | 34 | if __name__ == "__main__": 35 | tf.enable_eager_execution() 36 | dataset = tf.data.TFRecordDataset(['d:/Datasets/DSB2018/tfrecords/stage1_train/DSB2018.record-00000-of-00005']) 37 | dataset = dataset.map(lambda example: extract_fn(example, [512, 512])) 38 | iterator = dataset.make_one_shot_iterator() 39 | img, label, dist, neigbro = iterator.get_next() 40 | 41 | import numpy as np 42 | l = np.min(label) 43 | print(l) 44 | -------------------------------------------------------------------------------- /train_cell.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | phase="train" 4 | dist_branch=True 5 | include_bg=True 6 | embedding_dim=16 7 | 8 | train_dir="./tfrecords/U2OScell/train" 9 | validation=True 10 | val_dir="./tfrecords/U2OScell/val" 11 | image_depth="uint16" 12 | image_channels=1 13 | model_dir="./model_U2OScell" 14 | 15 | lr=0.0001 16 | batch_size=4 17 | training_epoches=300 18 | 19 | 20 | cd /work/scratch/chen/instance_segmentation_with_pixel_embeddings 21 | 22 | /home/staff/chen/miniconda3/envs/tf/bin/python /work/scratch/chen/instance_segmentation_with_pixel_embeddings/main.py \ 23 | --phase="$phase" \ 24 | --dist_branch="$dist_branch" \ 25 | --include_bg="$include_bg" \ 26 | --embedding_dim="$embedding_dim" \ 27 | 
--train_dir="$train_dir" \ 28 | --validation="$validation" \ 29 | --val_dir="$val_dir"\ 30 | --image_depth="$image_depth" \ 31 | --image_channels="$image_channels" \ 32 | --model_dir="$model_dir" \ 33 | --lr="$lr" \ 34 | --batch_size="$batch_size" \ 35 | --training_epoches="$training_epoches" -------------------------------------------------------------------------------- /train_leaf.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | phase="train" 4 | dist_branch=True 5 | include_bg=True 6 | embedding_dim=16 7 | 8 | train_dir="./tfrecords/CVPPP2017/train" 9 | validation=False 10 | # val_dir="./tfrecords/CVPPP2017_val/val" 11 | image_depth="uint8" 12 | image_channels=3 13 | model_dir="./model_CVPPP2017" 14 | 15 | lr=0.0001 16 | batch_size=4 17 | training_epoches=300 18 | 19 | 20 | cd /work/scratch/chen/instance_segmentation_with_pixel_embeddings 21 | 22 | /home/staff/chen/miniconda3/envs/tf/bin/python /work/scratch/chen/instance_segmentation_with_pixel_embeddings/main.py \ 23 | --phase="$phase" \ 24 | --dist_branch="$dist_branch" \ 25 | --include_bg="$include_bg" \ 26 | --embedding_dim="$embedding_dim" \ 27 | --train_dir="$train_dir" \ 28 | --validation="$validation" \ 29 | --val_dir="$val_dir"\ 30 | --image_depth="$image_depth" \ 31 | --image_channels="$image_channels" \ 32 | --model_dir="$model_dir" \ 33 | --lr="$lr" \ 34 | --batch_size="$batch_size" \ 35 | --training_epoches="$training_epoches" -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/looooongChen/instance_segmentation_with_pixel_embeddings/113683182342db8233bd883a6a4ee33b870e06f4/utils/__init__.py -------------------------------------------------------------------------------- /utils/center.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import cv2 4 | from skimage import io 5 | from skimage.util import pad 6 | from skimage.draw import circle 7 | 8 | 9 | def mask2contour(img): 10 | """ 11 | Input: Single channel uint8 image 12 | Ouput: Contour list, non-approximated 13 | Step: Unconnect adjacent cell masks then use cv2.findContours 14 | """ 15 | img_dilate = cv2.dilate(img, np.ones((3,3),np.uint8), iterations=1) 16 | img_unconnected = np.where(img==0, 0, 255).astype('uint8') 17 | img_unconnected = np.where(img_dilate-img != 0, 0, img_unconnected).astype('uint8') 18 | 19 | ret = cv2.findContours(img_unconnected, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE) 20 | contours = ret[-2] 21 | print(len(contours), 'contours found!') 22 | return contours 23 | 24 | 25 | def ROI(image, contourlist, padded=True): 26 | """ 27 | Crop ROI with same label 28 | Return list ROIs and list of BBOXes (minX, maxX, minY, maxY) 29 | """ 30 | roilist = [] 31 | BBOXlist = [] 32 | 33 | for ct in contourlist: 34 | bbox = cv2.boundingRect(ct) 35 | bbox = (bbox[1],bbox[0],bbox[3],bbox[2]) 36 | coordiBBOX = (bbox[0], bbox[0]+bbox[2], bbox[1], bbox[1]+bbox[3]) 37 | roi = image[bbox[0]:bbox[0]+bbox[2], bbox[1]:bbox[1]+bbox[3]] 38 | 39 | (values,counts) = np.unique(roi,return_counts=True) 40 | label = values[np.argmax(counts)] 41 | roi = np.where(roi == label, roi, 0) 42 | if padded: 43 | roi = pad(roi, pad_width=1, mode="constant") 44 | roilist.append(roi) 45 | BBOXlist.append(coordiBBOX) 46 | return roilist, BBOXlist 47 | 48 | 49 | def checkConti(arrayC): 50 | """ 51 | 
Check if the labelled pixels are continuous 52 | along the vector. 53 | 00111111100 -> True 54 | 00111001100 -> False 55 | """ 56 | diff = arrayC - np.roll(arrayC, 1) 57 | changeFlag = np.count_nonzero(diff) 58 | if changeFlag > 2: 59 | return False 60 | else: 61 | return True 62 | 63 | 64 | def roi2origin(x, y, coordiBBOX, padded=True): 65 | """ 66 | Coordinates transformation 67 | from ROI to original large image 68 | """ 69 | if padded: 70 | x_origin = x + coordiBBOX[0] - 1 71 | y_origin = y + coordiBBOX[2] - 1 72 | else: 73 | x_origin = x + coordiBBOX[0] 74 | y_origin = y + coordiBBOX[2] 75 | return x_origin, y_origin 76 | 77 | 78 | 79 | def main(filename): 80 | t = time.time() 81 | img = io.imread(filename) 82 | 83 | ROIlist, BBOXlist = ROI(img, mask2contour(img)) 84 | for roi, coordiBBOX in zip(ROIlist, BBOXlist): 85 | contiX = [] 86 | contiY = [] 87 | midX_origin = [] 88 | midY_origin = [] 89 | 90 | for x in range(roi.shape[0]): 91 | if checkConti(roi[x,:]): 92 | contiX.append(x) 93 | for y in range(roi.shape[1]): 94 | if checkConti(roi[:,y].T): 95 | contiY.append(y) 96 | midX = contiX[ len(contiX)//2 ] 97 | midY = contiY[ len(contiY)//2 ] 98 | 99 | x_origin, y_origin = roi2origin(midX, midY, coordiBBOX, padded=True) 100 | midX_origin.append(x_origin) 101 | midY_origin.append(y_origin) 102 | 103 | for x, y in zip(midX_origin, midY_origin): 104 | rr, cc = circle(x, y, 3) 105 | img[rr, cc] = 0 106 | 107 | print(time.time() - t) 108 | io.imshow(img) 109 | io.show() 110 | 111 | if __name__ == "__main__": main(filename="img1.png") -------------------------------------------------------------------------------- /utils/data_dep/evaluation.py: -------------------------------------------------------------------------------- 1 | import skimage as ski 2 | import numpy as np 3 | import os 4 | from .visulize import visulize_mask 5 | 6 | 7 | class Evaluator(object): 8 | 9 | def __init__(self, thres=None): 10 | # self.type = type 11 | self.APs = [] 12 | self.mAP = None 13 | if thres is None: 14 | self.thres = np.arange(start=0.5, stop=0.95, step=0.05) # 0.50, 0.55, ..., 0.90 (stop is exclusive) 15 | else: 16 | self.thres = thres 17 | self.ap_dict = {i: [] for i, _ in enumerate(self.thres)} 18 | self.examples = [] 19 | 20 | def add_example(self, pred, gt): 21 | e = Example(pred, gt) 22 | self.examples.append(e) 23 | 24 | aps = [] 25 | for i, t in enumerate(self.thres): 26 | ap_ = e.get_ap(t) 27 | self.ap_dict[i].append(ap_) 28 | aps.append(ap_) 29 | return aps 30 | 31 | def save_last_as_image(self, fname, bg_image, thres=0.5, isBGR=False): 32 | self.examples[-1].save_as_image(fname, bg_image, thres=thres, isBGR=isBGR) 33 | 34 | def score(self): 35 | for i, _ in enumerate(self.thres): 36 | self.APs.append(np.mean(self.ap_dict[i])) 37 | self.mAP = np.mean(self.APs) 38 | return self.mAP, self.APs 39 | 40 | 41 | class Example(object): 42 | 43 | """ 44 | class for a prediction-ground truth pair 45 | """ 46 | 47 | def __init__(self, pred, gt): 48 | self.pred = pred 49 | self.gt = gt 50 | self.gt_num = len(np.unique(gt)) - 1 51 | self.IoU_dict = {} # (prediction label)-(IoU) 52 | self.match_dict = {} # (prediction label)-(matched gt label) 53 | 54 | self.match_non_overlap(pred, gt) 55 | 56 | def match_non_overlap(self, pred, gt): 57 | pred_area = self.get_area_dict(pred) 58 | gt_area = self.get_area_dict(gt) 59 | unique = np.unique(pred) 60 | 61 | for label in unique: 62 | if label == 0: 63 | continue 64 | u, c = np.unique(gt[pred == label], return_counts=True) 65 | ind = np.argsort(c, kind='mergesort') 66 | if len(u) == 1 and u[ind[-1]] == 0: 67 | # only 
contain background 68 | self.IoU_dict[label] = 0 69 | self.match_dict[label] = None 70 | else: 71 | # take the gt label with the largest overlap 72 | i = ind[-2] if u[ind[-1]] == 0 else ind[-1] 73 | intersect = c[i] 74 | union = pred_area[label] + gt_area[u[i]] - c[i] 75 | self.IoU_dict[label] = intersect/union 76 | self.match_dict[label] = u[i] 77 | 78 | def get_area_dict(self, label_map): 79 | props = ski.measure.regionprops(label_map) 80 | return {p.label: p.area for p in props} 81 | 82 | def get_ap(self, thres): 83 | """ 84 | compute AP at a given IoU threshold 85 | :param thres: IoU threshold 86 | :return: ap 87 | """ 88 | tp = 0 89 | match_gt = [] 90 | for k, value in self.IoU_dict.items(): 91 | if value > thres: 92 | tp = tp + 1 93 | match_gt.append(self.match_dict[k]) 94 | tp_fp = len(self.IoU_dict) 95 | fn = self.gt_num - len(np.unique(match_gt)) # count each matched gt object only once 96 | return tp/(tp_fp+fn) # = tp / (tp + fp + fn) 97 | 98 | def save_as_image(self, fname, bg_image, thres, isBGR=False): 99 | """ 100 | save a visualization image, plot match in blue, non-match in red 101 | :param fname: path to save the image 102 | :param bg_image: original image 103 | :param thres: the IoU threshold that determines match/non-match 104 | :param isBGR: whether bg_image is in BGR order (it will be flipped to RGB) 105 | :return: 106 | """ 107 | if len(bg_image.shape) == 3 and isBGR: 108 | bg_image = bg_image[:, :, ::-1] 109 | tp = self.pred.copy() 110 | fp = self.pred.copy() 111 | for k, value in self.IoU_dict.items(): 112 | if value > thres: 113 | fp[fp == k] = 0 114 | else: 115 | tp[tp == k] = 0 116 | res = visulize_mask(bg_image, tp, fill=True, color='blue') 117 | res = visulize_mask(res, fp, fill=True, color='red') 118 | print("Vis saved in: " + fname) 119 | ski.io.imsave(fname, res) 120 | return res 121 | 122 | 123 | if __name__ == "__main__": 124 | pred = ski.io.imread('./test/pre_1.png', as_gray=True) 125 | gt = ski.io.imread('./test/gt_1.png', as_gray=True) 126 | pred = ski.measure.label(pred) 127 | e = Evaluator() 128 | e.add_example(pred, gt) 129 | e.save_last_as_image('./test.png', gt, 0.5) 130 | # pred = ski.io.imread('./test/pre_2.png', as_gray=True) 131 | # gt = ski.io.imread('./test/gt_2.png', as_gray=True) 132 | # pred = ski.measure.label(pred) 133 | # e.evaluate_single(pred, gt) 134 | # print(e.score()) 135 | # # pred = gt 136 | # IoU_dict, match_dict = match_non_overlap(pred, gt) 137 | # print(IoU_dict) 138 | # print(match_dict) 139 | # print(get_ap_non_overlap(pred, gt, 0.8)) 140 | # print(evaluate_dir_no_overlap('./test', "pre", "gt", 0.8)) 141 | -------------------------------------------------------------------------------- /utils/data_dep/retrieval_dir.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import listdir 3 | from os.path import isfile, join 4 | 5 | 6 | def ls_files(dir, ext=None, key_with_ext=True): 7 | f_dict = {} 8 | if ext is not None: 9 | ext = '.'+ext if ext[0] != '.' 
else ext 10 | for f in os.listdir(dir): 11 | if f[0] == '.': 12 | continue 13 | 14 | b, e = os.path.splitext(f) 15 | if ext is not None: 16 | if e != ext: 17 | continue 18 | 19 | file_path = os.path.join(dir, f) 20 | if os.path.isfile(file_path): 21 | if key_with_ext: 22 | f_dict[f] = file_path 23 | else: 24 | f_dict[b] = file_path 25 | 26 | return f_dict 27 | 28 | 29 | # def ls_files_with_suffix(dir, suffix, ext=None, join_base=True): 30 | # files = ls_files(dir, ext=ext, join_base=False) 31 | # filtered_files = [] 32 | # for f in files: 33 | # name, _ = os.path.splitext(f) 34 | # if name.endswith(suffix): 35 | # file_path = os.path.join(dir, f) if join_base else f 36 | # filtered_files.append(file_path) 37 | # return filtered_files 38 | 39 | 40 | # def ls_files_with_prefix(dir, prefix, ext=None, join_base=True): 41 | # files = ls_files(dir, ext=ext, join_base=False) 42 | # filtered_files = [] 43 | # for f in files: 44 | # name, _ = os.path.splitext(f) 45 | # if name.startswith(prefix): 46 | # file_path = os.path.join(dir, f) if join_base else f 47 | # filtered_files.append(file_path) 48 | # return filtered_files 49 | 50 | 51 | # def ls_dirs(root_dir): 52 | # dirs = [] 53 | # # ids = [d for d in os.listdir(train_dir)] 54 | # for d in os.listdir(root_dir): 55 | # dir = os.path.join(root_dir, d) 56 | # if os.path.isdir(dir): 57 | # dirs.append(dir) 58 | # return dirs 59 | 60 | 61 | if __name__ == '__main__': 62 | # dir = 'D:/Datasets/DSB2018' 63 | # print(ls_dirs(dir)) 64 | # print(ls_files(dir, ext='.csv')) 65 | 66 | dir = r'D:\Datasets\BBBC006_U2OScell\ground_truth' 67 | d = ls_files(dir, ext='.png', key_with_ext=False) 68 | print(d) 69 | 70 | -------------------------------------------------------------------------------- /utils/data_dep/visulize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage as ski 3 | 4 | RED = (210, 20, 55) 5 | BLUE = (25, 100, 230) 6 | 7 | 8 | def mask_color_img(img, mask, color='red', alpha=0.3): 9 | if color == 'red': 10 | color = RED 11 | elif color == 'blue': 12 | color = BLUE 13 | elif isinstance(color, str): 14 | color = BLUE # unknown color name: fall back to blue 15 | # otherwise, color is assumed to be an (R, G, B) tuple 16 | 17 | 18 | if img.ndim != 3: 19 | img = ski.color.gray2rgb(img) 20 | mask = (mask > 0).astype(np.uint8) 21 | layer = img.copy() 22 | for i in range(3): 23 | layer[:, :, i] = np.multiply(layer[:, :, i], 1-mask)+color[i]*mask 24 | res = (1-alpha)*img + alpha*layer 25 | return res.astype(np.uint8) 26 | 27 | 28 | def visulize_mask(img, label_map, fill=False, color='blue'): 29 | b = get_boundary_from_label_map(label_map) 30 | overlayed = mask_color_img(img, b, color=color, alpha=0.8) 31 | if fill: 32 | mask = np.multiply((label_map > 0).astype(np.uint8), 1-b) 33 | overlayed = mask_color_img(overlayed, mask, color=color, alpha=0.2) 34 | return overlayed 35 | 36 | 37 | def get_boundary_from_label_map(map): 38 | map = map.copy() 39 | 40 | kernel = np.ones((3, 3), np.uint8) 41 | erosion = ski.morphology.erosion(map, kernel) 42 | dilation = ski.morphology.dilation(map, kernel) 43 | boundary = np.not_equal(erosion, dilation) 44 | return boundary 45 | 46 | 47 | if __name__ == "__main__": 48 | from skimage.io import imsave, imread 49 | import skimage as ski 50 | pred = imread('./test/pre_2.png', as_gray=True) 51 | gt = imread('./test/gt_2.png') 52 | pred = ski.measure.label(pred) 53 | 54 | im = visulize_mask(gt, pred, fill=True) 55 | imsave('./vis.png', im.astype(np.uint8)) 56 | 
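A minimal usage sketch for these overlay helpers (the input paths are hypothetical and the import line assumes the repo root is on sys.path):

# sketch: overlay predicted instance boundaries on a background image
import numpy as np
from skimage.io import imread, imsave
from skimage.measure import label as cc_label
from utils.data_dep.visulize import visulize_mask  # assumed import path

bg = imread('image.png')                                      # hypothetical background image
instances = cc_label(imread('labels.png', as_gray=True) > 0)  # integer label map
overlay = visulize_mask(bg, instances, fill=True, color='blue')
imsave('overlay.png', overlay.astype(np.uint8))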
-------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import skimage as ski 2 | from skimage.morphology import binary_dilation, disk 3 | import numpy as np 4 | import os 5 | from scipy.spatial import distance_matrix 6 | 7 | 8 | class Evaluator(object): 9 | 10 | def __init__(self, thres=None, gt_type="mask", line_match_thres=3): 11 | # self.type = type 12 | 13 | if thres is None: 14 | # self.thres = np.arange(start=0.5, stop=1, step=0.1) 15 | self.thres = [0.5, 0.6, 0.7, 0.8, 0.9] 16 | else: 17 | self.thres = thres 18 | 19 | self.gt_type = gt_type 20 | self.line_match_thres = line_match_thres 21 | 22 | self.examples = [] 23 | self.total_pred = 0 24 | self.total_gt = 0 25 | 26 | # self.IoU = [] # (prediction label)-(IoU) 27 | # self.recall = [] 28 | # self.precision = [] 29 | 30 | 31 | def add_example(self, pred, gt): 32 | e = Example(pred, gt, self.gt_type, self.line_match_thres) 33 | self.examples.append(e) 34 | 35 | self.total_pred += e.pred_num 36 | self.total_gt += e.gt_num 37 | print("example added, total: ", len(self.examples)) 38 | # self.IoU[0:0] = list(e.IoU.values()) 39 | # self.recall[0:0] = list(e.recall.values()) 40 | # self.precision[0:0] = list(e.precision.values()) 41 | 42 | def eval(self, metric='IoU'): 43 | 44 | res = {} 45 | for t in self.thres: 46 | pred_match = 0 47 | gt_match = 0 48 | for e in self.examples: 49 | p_m, g_m = e.return_match_num(t, metric) 50 | pred_match += p_m 51 | gt_match += g_m 52 | res[metric + '_' + str(t)] = [pred_match/self.total_pred, gt_match/self.total_gt] 53 | 54 | for k, v in res.items(): 55 | print(k, v) 56 | 57 | 58 | # def save_last_as_image(self, fname, bg_image, thres=0.5, isBGR=False): 59 | # self.examples[-1].save_as_image(fname, bg_image, thres=thres, isBGR=isBGR) 60 | 61 | # def score(self): 62 | # for i, _ in enumerate(self.thres): 63 | # self.APs.append(np.mean(self.ap_dict[i])) 64 | # self.mAP = np.mean(self.APs) 65 | # return self.mAP, self.APs 66 | 67 | 68 | class Example(object): 69 | 70 | """ 71 | class for a prediction-ground truth pair 72 | single_slide: faster when the object count is high, but cannot handle overlapping objects 73 | gt_type: "line" or "mask" 74 | """ 75 | 76 | def __init__(self, pred, gt, gt_type='mask', line_match_thres=3): 77 | self.gt_type = gt_type 78 | self.line_match_thres = line_match_thres 79 | 80 | pred = np.squeeze(pred) 81 | gt = np.squeeze(gt) 82 | 83 | if pred.ndim == 2 and gt.ndim == 2: 84 | self.single_slide = True 85 | self.pred = ski.measure.label(pred>0) 86 | self.gt = ski.measure.label(gt>0) 87 | self.gt_num = len(np.unique(self.gt)) - 1 88 | self.pred_num = len(np.unique(self.pred)) - 1 89 | else: 90 | self.single_slide = False 91 | self.pred = self.map2stack(pred) 92 | self.gt = self.map2stack(gt) 93 | self.gt_num = self.gt.shape[0] 94 | self.pred_num = self.pred.shape[0] 95 | 96 | self.match_dict = {} # (prediction label)-(matched gt label) 97 | self.IoU = {} # (prediction label)-(IoU) 98 | self.recall = {} 99 | self.precision = {} 100 | 101 | self._match_non_overlap() 102 | 103 | # print(len(self.match_dict), len(self.IoU), len(self.recall), len(self.precision), self.gt_num, self.pred_num) 104 | 105 | def _match_non_overlap(self): 106 | self.pred_area = self.get_area_dict(self.pred) 107 | self.gt_area = self.get_area_dict(self.gt) 108 | 109 | for label, pred_area in self.pred_area.items(): 110 | self.IoU[label] = 0 # defaults: no match; overwritten below if a matching gt object is found 111 | self.match_dict[label] = 0 112 | 
self.recall[label] = 0 113 | self.precision[label] = 0 114 | if self.gt_type == "mask": 115 | if self.single_slide: 116 | u, c = np.unique(self.gt[self.pred == label], return_counts=True) 117 | ind = np.argsort(c, kind='mergesort') 118 | if len(u) == 1 and u[ind[-1]] == 0: 119 | # only contain background 120 | self.IoU[label] = 0 121 | self.match_dict[label] = 0 122 | self.recall[label] = 0 123 | self.precision[label] = 0 124 | else: 125 | # take the gt label with the largest overlap 126 | i = ind[-2] if u[ind[-1]] == 0 else ind[-1] 127 | intersect = c[i] 128 | union = pred_area + self.gt_area[u[i]] - intersect 129 | self.IoU[label] = intersect/union 130 | self.match_dict[label] = u[i] 131 | self.recall[label] = intersect/self.gt_area[u[i]] 132 | self.precision[label] = intersect/pred_area 133 | else: 134 | intersect = np.multiply(self.gt, np.expand_dims(self.pred[label-1], axis=0)) 135 | intersect = np.sum(intersect, axis=(1,2)) 136 | ind = np.argsort(intersect, kind='mergesort') 137 | if intersect[ind[-1]] == 0: 138 | # no overlap with any object 139 | self.IoU[label] = 0 140 | self.match_dict[label] = 0 141 | self.recall[label] = 0 142 | self.precision[label] = 0 143 | else: 144 | # take the gt label with the largest overlap 145 | union = pred_area + self.gt_area[ind[-1]+1] - intersect[ind[-1]] 146 | self.IoU[label] = intersect[ind[-1]]/union 147 | self.match_dict[label] = ind[-1] + 1 148 | self.recall[label] = intersect[ind[-1]]/self.gt_area[ind[-1]+1] 149 | self.precision[label] = intersect[ind[-1]]/pred_area 150 | else: 151 | intersect = [] 152 | if self.single_slide: 153 | pts_pred = np.transpose(np.array(np.nonzero(self.pred==label))) 154 | for l in np.unique(self.gt): 155 | if l == 0: 156 | continue 157 | pts_gt = np.transpose(np.array(np.nonzero(self.gt==l))) 158 | bpGraph = distance_matrix(pts_pred, pts_gt) < self.line_match_thres 159 | matcher = GFG(bpGraph) 160 | intersect.append(matcher.maxBPM()) 161 | else: 162 | pts_pred = np.transpose(np.array(np.nonzero(self.pred[label-1]>0))) 163 | for g in self.gt: 164 | pts_gt = np.transpose(np.array(np.nonzero(g>0))) 165 | bpGraph = distance_matrix(pts_pred, pts_gt) < self.line_match_thres 166 | matcher = GFG(bpGraph) # distinct name: do not clobber the loop variable g 167 | intersect.append(matcher.maxBPM()) 168 | 169 | if len(intersect) != 0: 170 | intersect = np.array(intersect) 171 | ind = np.argsort(intersect, kind='mergesort') 172 | if intersect[ind[-1]] != 0: 173 | # take the gt label with the largest overlap 174 | union = pred_area + self.gt_area[ind[-1]+1] - intersect[ind[-1]] 175 | self.IoU[label] = intersect[ind[-1]]/union 176 | self.match_dict[label] = ind[-1] + 1 177 | self.recall[label] = intersect[ind[-1]]/self.gt_area[ind[-1]+1] 178 | self.precision[label] = intersect[ind[-1]]/pred_area 179 | 180 | def get_area_dict(self, label_map): 181 | if self.single_slide: 182 | props = ski.measure.regionprops(label_map) 183 | area_dict = {p.label: p.area for p in props} 184 | else: 185 | area_dict = {i+1: np.sum(label_map[i]>0) for i in range(label_map.shape[0])} 186 | if 0 in area_dict.keys(): 187 | del area_dict[0] 188 | return area_dict 189 | 190 | def map2stack(self, map): 191 | map = np.squeeze(map) 192 | if map.ndim == 2: 193 | stack = [] 194 | for l in np.unique(map): 195 | if l == 0: 196 | continue 197 | stack.append(map==l) 198 | return np.array(stack)>0 199 | else: 200 | return map>0 201 | 202 | def return_match_num(self, thres, metric='IoU'): 203 | match_label = np.array(list(self.match_dict.values())) 204 | if metric == 'F': # arithmetic mean of precision and recall, not the harmonic F1 205 | ind = (np.array(list(self.precision.values())) + 
np.array(list(self.recall.values())))/2 > thres 206 | else: 207 | ind = np.array(list(self.IoU.values())) > thres 208 | return np.sum(ind), len(np.unique(match_label[ind])) 209 | 210 | class GFG: 211 | # maximal Bipartite matching. 212 | def __init__(self,graph): 213 | 214 | # residual graph 215 | self.graph = graph 216 | self.ppl = len(graph) 217 | self.jobs = len(graph[0]) 218 | 219 | # A DFS based recursive function 220 | # that returns true if a matching 221 | # for vertex u is possible 222 | def bpm(self, u, matchR, seen): 223 | 224 | # Try every job one by one 225 | for v in range(self.jobs): 226 | 227 | # If applicant u is interested 228 | # in job v and v is not seen 229 | if self.graph[u][v] and seen[v] == False: 230 | 231 | # Mark v as visited 232 | seen[v] = True 233 | 234 | '''If job 'v' is not assigned to 235 | an applicant OR previously assigned 236 | applicant for job v (which is matchR[v]) 237 | has an alternate job available. 238 | Since v is marked as visited in the 239 | above line, matchR[v] in the following 240 | recursive call will not get job 'v' again''' 241 | if matchR[v] == -1 or self.bpm(matchR[v], 242 | matchR, seen): 243 | matchR[v] = u 244 | return True 245 | return False 246 | 247 | # Returns maximum number of matching 248 | def maxBPM(self): 249 | '''An array to keep track of the 250 | applicants assigned to jobs. 251 | The value of matchR[i] is the 252 | applicant number assigned to job i, 253 | the value -1 indicates nobody is assigned.''' 254 | matchR = [-1] * self.jobs 255 | 256 | # Count of jobs assigned to applicants 257 | result = 0 258 | for i in range(self.ppl): 259 | 260 | # Mark all jobs as not seen for next applicant. 261 | seen = [False] * self.jobs 262 | 263 | # Find if the applicant 'u' can get a job 264 | if self.bpm(i, matchR, seen): 265 | result += 1 266 | return result 267 | -------------------------------------------------------------------------------- /utils/img_io.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import math 4 | 5 | 6 | P = [252, 233, 79, 114, 159, 207, 239, 41, 41, 173, 127, 168, 138, 226, 52, 7 | 233, 185, 110, 252, 175, 62, 211, 215, 207, 196, 160, 0, 32, 74, 135, 164, 0, 0, 8 | 92, 53, 102, 78, 154, 6, 143, 89, 2, 206, 92, 0, 136, 138, 133, 237, 212, 0, 52, 9 | 101, 164, 204, 0, 0, 117, 80, 123, 115, 210, 22, 193, 125, 17, 245, 121, 0, 186, 10 | 189, 182, 85, 87, 83, 46, 52, 54, 238, 238, 236, 0, 0, 10, 252, 233, 89, 114, 159, 11 | 217, 239, 41, 51, 173, 127, 178, 138, 226, 62, 233, 185, 120, 252, 175, 72, 211, 215, 12 | 217, 196, 160, 10, 32, 74, 145, 164, 0, 10, 92, 53, 112, 78, 154, 16, 143, 89, 12, 13 | 206, 92, 10, 136, 138, 143, 237, 212, 10, 52, 101, 174, 204, 0, 10, 117, 80, 133, 115, 14 | 210, 32, 193, 125, 27, 245, 121, 10, 186, 189, 192, 85, 87, 93, 46, 52, 64, 238, 238, 246] 15 | 16 | P = P * math.floor(255*3/len(P)) 17 | l = int(255 - len(P)/3) 18 | P = P + P[3:(l+1)*3] 19 | P = [0,0,0] + P 20 | 21 | 22 | 23 | def read_indexed_png(fname): 24 | im = Image.open(fname) 25 | palette = im.getpalette() 26 | im = np.array(im) 27 | return im, palette 28 | 29 | 30 | def save_indexed_png(fname, label_map, palette=P): 31 | label_map = np.squeeze(label_map.astype(np.uint8)) 32 | im = Image.fromarray(label_map, 'P') 33 | im.putpalette(palette) 34 | im.save(fname, 'PNG') 35 | 36 | 37 | # if __name__ == "__main__": 38 | # im1 = 'D:/Datasets/CVPPP/CVPPP2017_LSC_training/train/plant0252_label.png' 39 | # im2 = 
'D:/Datasets/CVPPP/CVPPP2017_LSC_training/train/plant0160_label.png' 40 | # _, p1 = read_indexed_png(im1) 41 | # label, p2 = read_indexed_png(im2) 42 | # save_indexed_png('./test.png', label, p2) 43 | 44 | -------------------------------------------------------------------------------- /utils/process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from skimage.measure import regionprops 4 | from skimage import io as ski_io 5 | from scipy.ndimage.morphology import distance_transform_edt 6 | 7 | def relabel_map(map): 8 | map = map.copy() 9 | new_map = map.copy() 10 | index = 1 11 | for u in np.unique(map): 12 | if u == 0: 13 | continue 14 | new_map[map==u] = index 15 | index += 1 16 | return new_map 17 | 18 | def remove_small(map, size, relabel=True): 19 | map = map.copy() 20 | props = regionprops(map) 21 | for p in props: 22 | if p.area < size: 23 | map[map == p.label] = 0 24 | if relabel: 25 | map = relabel_map(map) 26 | return map.astype(np.int32) 27 | 28 | 29 | # def stack2map(s, close=0, remove_small=0): 30 | # s = np.squeeze(s) 31 | # assert len(s.shape) == 3 32 | # s = s > 0 33 | 34 | # map = np.zeros((s.shape[0], s.shape[1]), dtype=np.int32) 35 | 36 | # for i in range(s.shape[-1]): 37 | # obj = s[:, :, i] 38 | 39 | # if np.sum(obj) == 0: 40 | # continue 41 | 42 | # map[obj > 0] = i+1 43 | 44 | # if close != 0 or remove_small != 0: 45 | # map = process_map(map, close=close, remove_small=remove_small) 46 | 47 | # return map.astype(np.int32) 48 | 49 | 50 | # def map2stack(map, close=0, remove_small=0): 51 | # map = np.squeeze(map) 52 | # assert len(map.shape) == 2 53 | 54 | # if close != 0 or remove_small != 0: 55 | # map = process_map(map, close=close, remove_small=remove_small) 56 | 57 | # unique = np.unique(map) 58 | # s = np.zeros((map.shape[0], map.shape[1], len(unique)-1), dtype=bool) \ 59 | # if 0 in unique else np.zeros((map.shape[0], map.shape[1], len(unique))) 60 | 61 | # counter = 0 62 | # for i in range(len(unique)): 63 | # if unique[i] == 0: 64 | # continue 65 | # obj = map == unique[i] 66 | 67 | # s = s.astype(np.uint8) 68 | # s[:, :, counter] = obj 69 | # counter += 1 70 | # return s[:, :, 0:counter] 71 | 72 | 73 | # def read_stack_from_files(files): 74 | # if len(files) == 0: 75 | # return None 76 | 77 | # for i, f in enumerate(files): 78 | # obj = ski_io.imread(f) 79 | # obj = obj[:, :, 0] if len(obj.shape) == 3 else obj 80 | 81 | # if i == 0: 82 | # s = np.zeros((obj.shape[0], obj.shape[1], len(files)), dtype=np.uint8) 83 | # s[:, :, i] = (obj>0).astype(np.uint8) 84 | 85 | # return s 86 | 87 | 88 | def boundary_of_label_map(map): 89 | 90 | map = map.copy() 91 | 92 | kernel = np.ones((3, 3), np.uint8) 93 | erosion = cv2.erode(map.astype(np.uint16), kernel, iterations=1) 94 | dilation = cv2.dilate(map.astype(np.uint16), kernel, iterations=1) 95 | boundary = np.not_equal(erosion, dilation) 96 | return boundary 97 | 98 | def distance_map(map, normalize=False): 99 | 100 | map = map.copy() 101 | boundary = boundary_of_label_map(map) 102 | map = np.multiply(map, 1-boundary) 103 | 104 | dist_map = cv2.distanceTransform((map>0).astype(np.uint8), cv2.DIST_L2, cv2.DIST_MASK_PRECISE) 105 | 106 | if normalize: 107 | unique = np.unique(map) 108 | for u in unique: 109 | if u == 0: 110 | continue 111 | max_dist = np.max(dist_map[map==u]) 112 | if max_dist != 0: 113 | dist_map[map == u] = dist_map[map == u]/max_dist 114 | return dist_map 115 | 116 | 117 | # def centroid_map(label_map): 118 | # c_map = 
np.zeros(np.squeeze(label_map).shape) 119 | # for obj in regionprops(label_map): 120 | # c = obj['centroid'] 121 | # c_map[int(c[0]), int(c[1])]=1 122 | # return c_map 123 | 124 | 125 | def get_neighbor_by_distance(label_map, distance=10, max_neighbor=50): 126 | 127 | label_map = label_map.copy() 128 | 129 | def _adjust_size(x): 130 | if len(x) >= max_neighbor: 131 | return x[0:max_neighbor] 132 | else: 133 | return np.pad(x, (0, max_neighbor-len(x)), 'constant', constant_values=(0, 0)) 134 | 135 | unique = np.unique(label_map) 136 | assert unique[0] == 0 137 | # only one object 138 | if len(unique) <= 2: 139 | return None 140 | 141 | neighbor_indice = np.zeros((len(unique)-1, max_neighbor)) 142 | label_flat = label_map.reshape((-1)) 143 | kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (distance * 2 + 1, distance * 2 + 1)) 144 | for i, label in enumerate(unique[1:]): 145 | assert i+1 == label 146 | mask = label_map == label 147 | dilated_mask = cv2.dilate(mask.astype(np.uint8), kernel, iterations=1).reshape((-1)) 148 | neighbor_pixel_ind = np.logical_and(dilated_mask > 0, label_flat != 0) 149 | neighbor_pixel_ind = np.logical_and(neighbor_pixel_ind, label_flat != label) 150 | neighbors = np.unique(label_flat[neighbor_pixel_ind]) 151 | neighbor_indice[i,:] = _adjust_size(neighbors) 152 | 153 | return neighbor_indice.astype(np.int32) 154 | 155 | 156 | 157 | if __name__ == '__main__': 158 | # test 159 | import os 160 | im = ski_io.imread(r'D:\Datasets\BBBC006_U2OScell\ground_truth\mcf-z-stacks-03212011_a12_s1_w197a9b240-1624-42e2-86a3-d50f7b607ff6.png') 161 | 162 | # dist = get_neighbor_by_distance(im, distance=20, max_neighbor=5) 163 | # print(dist) 164 | dist = distance_map(im, normalize=True) 165 | ski_io.imsave('test.png', (dist*255).astype(np.uint8)) # save the computed distance map 166 | 167 | 168 | -------------------------------------------------------------------------------- /utils/tfrecord_creation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Utilities for creating TFRecords of TF examples. 16 | """ 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | import skimage.io as ski_io 23 | import os 24 | from . 
import tfrecord_type 25 | import numpy as np 26 | import cv2 27 | 28 | 29 | def inject_fn_img(img_path, resize=None): 30 | 31 | if not os.path.exists(img_path): 32 | return None 33 | 34 | # all images are saved with PNG encoding 35 | img = cv2.imread(img_path, -1) 36 | if resize is not None: 37 | img = cv2.resize(img, resize, interpolation=cv2.INTER_LINEAR) 38 | _, img_encoded = cv2.imencode('.png', img) 39 | img_encoded = img_encoded.tobytes() 40 | 41 | channels = 1 if len(img.shape) == 2 else img.shape[2] 42 | 43 | feature_dict = { 44 | 'image/height': tfrecord_type.int64_feature(img.shape[0]), 45 | 'image/width': tfrecord_type.int64_feature(img.shape[1]), 46 | 'image/channels': tfrecord_type.int64_feature(channels), 47 | 'image/filename': tfrecord_type.bytes_feature(os.path.basename(img_path).encode('utf8')), 48 | 'image/image': tfrecord_type.bytes_feature(img_encoded), 49 | 'image/format': tfrecord_type.bytes_feature(".png".encode('utf8')) 50 | } 51 | return feature_dict 52 | 53 | def open_sharded_output_tfrecords(exit_stack, base_path, num_shards): 54 | """Opens all TFRecord shards for writing and adds them to an exit stack. 55 | 56 | Args: 57 | exit_stack: A contextlib2.ExitStack used to automatically close the TFRecords 58 | opened in this function. 59 | base_path: The base path for all shards 60 | num_shards: The number of shards 61 | 62 | Returns: 63 | The list of opened TFRecords. Position k in the list corresponds to shard k. 64 | """ 65 | tf_record_output_filenames = ['{}-{:05d}-of-{:05d}'.format(base_path, idx, num_shards) 66 | for idx in range(num_shards)] 67 | tfrecords = [exit_stack.enter_context(tf.python_io.TFRecordWriter(file_name)) 68 | for file_name in tf_record_output_filenames] 69 | 70 | return tfrecords 71 | -------------------------------------------------------------------------------- /utils/tfrecord_parse.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def extract_fn_base(data_record, image_depth='uint8'): 5 | feature_dict = { 6 | 'image/height': tf.FixedLenFeature([], tf.int64), 7 | 'image/width': tf.FixedLenFeature([], tf.int64), 8 | 'image/channels': tf.FixedLenFeature([], tf.int64), 9 | 'image/filename': tf.FixedLenFeature([], tf.string), 10 | 'image/image': tf.FixedLenFeature([], tf.string), 11 | 'image/format': tf.FixedLenFeature([], tf.string) 12 | } 13 | sample = tf.parse_single_example(data_record, feature_dict) 14 | if image_depth == 'uint8': 15 | sample['image/image'] = tf.image.decode_png(sample['image/image'], dtype=tf.uint8) 16 | else: 17 | sample['image/image'] = tf.image.decode_png(sample['image/image'], dtype=tf.uint16) 18 | 19 | return sample 20 | 21 | 22 | def extract_fn_local_dis(data_record, image_depth='uint8', dist_map=False): 23 | 24 | sample = extract_fn_base(data_record, image_depth) 25 | 26 | feature_dict = { 27 | 'image/label': tf.FixedLenFeature([], tf.string), 28 | 'image/neighbor': tf.FixedLenFeature([], tf.string), 29 | 'image/obj_count': tf.FixedLenFeature([], tf.int64), 30 | 'image/max_neighbor': tf.FixedLenFeature([], tf.int64) 31 | } 32 | 33 | if dist_map: 34 | feature_dict['image/dist_map'] = tf.FixedLenFeature([], tf.string) 35 | 36 | sample_add = tf.parse_single_example(data_record, feature_dict) 37 | 38 | sample['image/label'] = tf.image.decode_png(sample_add['image/label'], dtype=tf.uint16) 39 | sample['image/neighbor'] = tf.reshape(tf.decode_raw(sample_add['image/neighbor'], tf.int32), 40 | tf.stack([sample_add['image/obj_count'], 
sample_add['image/max_neighbor']])) 41 | sample['image/obj_count'] = sample_add['image/obj_count'] 42 | sample['image/max_neighbor'] = sample_add['image/max_neighbor'] 43 | 44 | if dist_map: 45 | sample['image/dist_map'] = tf.image.decode_png(sample_add['image/dist_map']) 46 | 47 | return sample 48 | 49 | -------------------------------------------------------------------------------- /utils/tfrecord_type.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions for creating TFRecord data sets.""" 17 | 18 | import tensorflow as tf 19 | 20 | 21 | def int64_feature(value): 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 23 | 24 | 25 | def int64_list_feature(value): 26 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 27 | 28 | 29 | def bytes_feature(value): 30 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 31 | 32 | 33 | def bytes_list_feature(value): 34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=value)) 35 | 36 | 37 | def float_list_feature(value): 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 39 | 40 | 41 | 42 | def read_examples_list(path): 43 | """Read list of training or validation examples. 44 | 45 | The file is assumed to contain a single example per line where the first 46 | token in the line is an identifier that allows us to find the image and 47 | annotation xml for that example. 48 | 49 | For example, the line: 50 | xyz 3 51 | would allow us to find files xyz.jpg and xyz.xml (the 3 would be ignored). 52 | 53 | Args: 54 | path: absolute path to examples list file. 55 | 56 | Returns: 57 | list of example identifiers (strings). 58 | """ 59 | with tf.gfile.GFile(path) as fid: 60 | lines = fid.readlines() 61 | return [line.strip().split(' ')[0] for line in lines] 62 | 63 | 64 | -------------------------------------------------------------------------------- /utils/tfrecords_convert.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.join(os.path.dirname(__file__))) 4 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | import cv2 9 | 10 | from . import tfrecord_creation, tfrecord_type 11 | from . 
import process 12 | from img_io import read_indexed_png 13 | 14 | MAX_NEIGHBOR = 32 15 | 16 | def convert_to_tf_example(img_path, 17 | gt_path, 18 | neighbor_distance_in_percent=0.02, 19 | resize=None, 20 | dist_map=False, 21 | gt_type="label", 22 | max_neighbor=MAX_NEIGHBOR): 23 | 24 | characteristicLength = 3 25 | 26 | # inject images 27 | feature_dict = tfrecord_creation.inject_fn_img(img_path, resize) 28 | # read and process ground truth image 29 | if gt_type == "indexed": 30 | gt, _ = read_indexed_png(gt_path) 31 | else: 32 | gt = cv2.imread(gt_path, -1) 33 | if resize is not None: 34 | gt = cv2.resize(gt, resize, interpolation=cv2.INTER_NEAREST) 35 | label = process.remove_small(gt, size=25, relabel=True) 36 | # ignore images which contain only one object 37 | unique = np.unique(label) 38 | assert unique[0] == 0 39 | if len(unique) <= 2: 40 | print("Omit an image {} containing only one object".format(img_path)) 41 | return None 42 | # save label map as uint16 image 43 | label = label.astype(np.uint16) 44 | _, label_encoded = cv2.imencode('.png', label) 45 | label_encoded = label_encoded.tobytes() 46 | feature_dict['image/label'] = tfrecord_type.bytes_feature(label_encoded) 47 | feature_dict['image/obj_count'] = tfrecord_type.int64_feature(len(unique)-1) 48 | # save neighbor relationship 49 | neighbor_distance = int(label.shape[1] * neighbor_distance_in_percent) 50 | neighbors = process.get_neighbor_by_distance(label, distance=neighbor_distance, max_neighbor=max_neighbor) 51 | feature_dict['image/neighbor'] = tfrecord_type.bytes_feature(neighbors.reshape(-1).tobytes()) 52 | feature_dict['image/max_neighbor'] = tfrecord_type.int64_feature(neighbors.shape[1]) 53 | # save dist_map as uint8 image, resolution 1/255=0.0039 54 | if dist_map: 55 | d_map = process.distance_map(label, normalize=True) 56 | d_map = (((d_map-d_map.min())/(d_map.max()-d_map.min()))*255).astype(np.uint8) 57 | _, d_map_encoded = cv2.imencode('.png', d_map) 58 | d_map_encoded = d_map_encoded.tobytes() 59 | feature_dict['image/dist_map'] = tfrecord_type.bytes_feature(d_map_encoded) 60 | 61 | example = tf.train.Example(features=tf.train.Features(feature=feature_dict)) 62 | 63 | return example 64 | 65 | 66 | def create_tf_record(image_dict, 67 | gt_dict, 68 | output_file, 69 | neighbor_distance_in_percent=0.02, 70 | resize=None, 71 | dist_map=False, 72 | num_shards=5, 73 | gt_type="label", 74 | max_neighbor=MAX_NEIGHBOR): 75 | 76 | """Creates a TFRecord file from examples. 77 | 78 | Args: 79 | image_dict: dict of image paths 80 | gt_dict: dict of ground truth paths 81 | output_file: name of tfrecord files 82 | num_shards: number of tfrecord shards 83 | dist_map: generate a distance map or not 84 | max_neighbor: max. 
number of neighbors saved 85 | """ 86 | import contextlib2 87 | 88 | # image_dict = name_dict(image_list) 89 | # gt_dict = name_dict(gt_list) 90 | 91 | total = len(gt_dict) 92 | 93 | with contextlib2.ExitStack() as tf_record_close_stack: 94 | 95 | output_tfrecords = tfrecord_creation.open_sharded_output_tfrecords( 96 | tf_record_close_stack, output_file, num_shards) 97 | 98 | processed_count = 0 99 | count = 0 100 | for k, gt_path in gt_dict.items(): 101 | if k in image_dict.keys(): 102 | tf_example = convert_to_tf_example(image_dict[k], 103 | gt_path, 104 | neighbor_distance_in_percent=neighbor_distance_in_percent, 105 | resize=resize, 106 | dist_map=dist_map, 107 | gt_type=gt_type, 108 | max_neighbor=max_neighbor) 109 | if tf_example is not None: 110 | processed_count += 1 111 | shard_idx = processed_count % num_shards 112 | output_tfrecords[shard_idx].write(tf_example.SerializeToString()) 113 | 114 | count += 1 115 | 116 | if count % 10 == 0: 117 | print('On image {} of {}, processed images: {}'.format(count, total, processed_count)) 118 | 119 | # debug 120 | # if count == 50: 121 | # break 122 | --------------------------------------------------------------------------------
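For orientation, a minimal end-to-end sketch of TFRecord creation with the helpers above (directory paths are hypothetical; the repo's prepare_cvppp_tfrecords.py and prepare_U2OScell_tfrecords.py are the actual entry points):

# sketch: build sharded TFRecords from an image directory and a ground-truth directory
from utils.data_dep.retrieval_dir import ls_files
from utils.tfrecords_convert import create_tf_record

# keys of the two dicts are file stems and must match between images and ground truth
images = ls_files('./data/images', ext='.png', key_with_ext=False)
gts = ls_files('./data/ground_truth', ext='.png', key_with_ext=False)

create_tf_record(images, gts,
                 output_file='./tfrecords/train/train.record',
                 neighbor_distance_in_percent=0.02,  # neighbor search radius as a fraction of image width
                 resize=(512, 512),
                 dist_map=True,   # also store normalized distance maps for the distance branch
                 num_shards=5,
                 gt_type="label")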