├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── clustering.py ├── data └── tusimple_dataset_processing.py ├── datagenerator.py ├── doc ├── cluster_000000.png ├── cluster_001000.png ├── cluster_002000.png ├── cluster_012000.png ├── cluster_015000.png ├── cluster_016000.png ├── cluster_017000.png ├── cluster_018000.png └── training_pipeline.png ├── enet.py ├── inference.py ├── inference_test ├── images │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ └── 6.jpg └── results │ ├── cluster_0000.png │ ├── cluster_0001.png │ ├── cluster_0002.png │ ├── cluster_0003.png │ ├── cluster_0004.png │ └── cluster_0005.png ├── loss.py ├── pretrained_semantic_model ├── checkpoint ├── saved_model-24999.meta ├── saved_model-29999.data-00000-of-00001 ├── saved_model-29999.index └── saved_model-29999.meta ├── todo_semantic_segmentation ├── helper.py └── transfer_semantic.py ├── trained_model ├── checkpoint ├── model.ckpt-104999.data-00000-of-00001 ├── model.ckpt-104999.index └── model.ckpt-104999.meta ├── training.py ├── utils.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Kwotsin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Instance Segmentation with a Discriminative Loss Function 2 | 3 | TensorFlow implementation of [Semantic Instance Segmentation with a Discriminative Loss Function](https://arxiv.org/abs/1708.02551) trained on the [TuSimple dataset](http://benchmark.tusimple.ai/#/t/1) 4 | 5 | --- 6 | ### Files 7 | ├── __data__ the dataset should be stored here 8 | │        └── __tusimple_dataset_processing.py__ processes the TuSimple dataset 9 | ├── __doc__ documentation 10 | ├── __inference_test__ inference related data 11 | │        └── __images__ for testing the inference 12 | ├── __trained_model__ pretrained model for finetuning 13 | ├── __clustering.py__ mean-shift clustering 14 | ├── __datagenerator.py__ feeds data for training and evaluation 15 | ├── __enet.py__ [ENet architecture](https://github.com/kwotsin/TensorFlow-ENet) 16 | ├── __inference.py__ tests inference on images 17 | ├── __loss.py__ defines the discriminative loss function 18 | ├── __README.md__ 19 | ├── __training.py__ contains the training pipeline 20 | ├── __utils.py__ contains utility functions for building and initializing the graph 21 | └── __visualization.py__ contains visualization of the clustering and pixel embeddings 22 | 23 | 24 | ### Instructions 25 | 26 | #### Inference 27 | 1. To test inference with the trained model, execute: 28 | `python inference.py --modeldir trained_model --outdir inference_test/results` 29 | 30 | #### Training 31 | 32 | 1. Download the [TuSimple training dataset](http://benchmark.tusimple.ai/#/t/1) and extract its contents to the `data` folder. The folder structure should look like this: 33 | | data 34 | ├── train_set 35 | │     ├── clips 36 | │     ├── label_data_0313.json 37 | │     ├── label_data_0531.json 38 | │     ├── label_data_0601.json 39 | │     └── readme.md 40 | 2. Run the following script to prepare the images and labels (the positional argument is the source directory of the TuSimple data): 41 | `python data/tusimple_dataset_processing.py <path to train_set>` 42 | This should create the following `images` and `labels` folders: 43 | | data 44 | ├── train_set 45 | ├── images 46 | └── labels 47 | 3. To train on the dataset, execute: 48 | `python training.py` 49 | Alternatively, use the optional parameters (the defaults are shown in this example): 50 | `python training.py --srcdir data --modeldir pretrained_semantic_model --outdir saved_model --logdir log --epochs 50 --var 1.0 --dist 1.0 --reg 1.0 --dvar 0.5 --ddist 1.5` 51 | 52 | 4. To test the trained network, execute: 53 | `python inference.py --modeldir saved_model` 54 | 55 | ### Training Pipeline 56 | 57 | 58 | ### Training Visualization 59 | Feature space projection of one image for consecutive gradient steps. Each point represents one pixel embedding and each color represents an instance in the label. 60 | 61 | 62 | 63 | 64 | ### Results 65 | 66 | 67 | 68 | 69 | ### Todo 70 | - pip requirements 71 | - semantic segmentation code 72 | - visualization 73 | 74 | Built with TensorFlow version 1.2. 75 | 76 | ### Reference and Credits 77 | This application uses Open Source components.
We acknowledge and are grateful to these developers for their contributions to open source: 78 | - Project: TensorFlow-ENet https://github.com/kwotsin/TensorFlow-ENet 79 | - Project: TuSimple dataset http://benchmark.tusimple.ai 80 | - Project: Fast Scene Understanding https://github.com/DavyNeven/fastSceneUnderstanding 81 | 82 | ### Related work 83 | - Paper: [Towards End-to-End Lane Detection: an Instance Segmentation Approach](https://arxiv.org/abs/1802.05591) 84 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/__init__.py -------------------------------------------------------------------------------- /clustering.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.cluster import MeanShift, estimate_bandwidth 4 | import time 5 | import cv2 6 | 7 | COLOR=[np.array([255,0,0]), 8 | np.array([0,255,0]), 9 | np.array([0,0,255]), 10 | np.array([125,125,0]), 11 | np.array([0,125,125]), 12 | np.array([125,0,125]), 13 | np.array([50,100,50]), 14 | np.array([100,50,100])] 15 | 16 | def cluster(prediction, bandwidth): 17 | ms = MeanShift(bandwidth, bin_seeding=True) 18 | print ('Mean shift clustering, might take some time ...') 19 | tic = time.time() 20 | ms.fit(prediction) 21 | print ('time for clustering', time.time() - tic) 22 | labels = ms.labels_ 23 | cluster_centers = ms.cluster_centers_ 24 | 25 | num_clusters = cluster_centers.shape[0] 26 | 27 | return num_clusters, labels, cluster_centers 28 | 29 | def get_instance_masks(prediction, bandwidth): 30 | batch_size, h, w, feature_dim = prediction.shape 31 | 32 | instance_masks = [] 33 | for i in range(batch_size): 34 | num_clusters, labels, cluster_centers = cluster(prediction[i].reshape([h*w, feature_dim]), bandwidth) 35 | print ('Number of predicted clusters', num_clusters) 36 | labels = np.array(labels, dtype=np.uint8).reshape([h,w]) 37 | mask = np.zeros([h,w,3], dtype=np.uint8) 38 | 39 | num_clusters = min([num_clusters,8]) 40 | for mask_id in range(num_clusters): 41 | ind = np.where(labels==mask_id) 42 | mask[ind] = COLOR[mask_id] 43 | 44 | 45 | instance_masks.append(mask) 46 | 47 | return instance_masks 48 | 49 | 50 | def save_instance_masks(prediction,output_dir, bandwidth, count): 51 | batch_size, h, w, feature_dim = prediction.shape 52 | 53 | instance_masks = [] 54 | for i in range(batch_size): 55 | num_clusters, labels, cluster_centers = cluster(prediction[i].reshape([h*w, feature_dim]), bandwidth) 56 | print ('Number of predicted clusters', num_clusters) 57 | labels = np.array(labels, dtype=np.uint8).reshape([h,w]) 58 | mask = np.zeros([h,w,3], dtype=np.uint8) 59 | 60 | num_clusters = min([num_clusters,8]) 61 | for mask_id in range(num_clusters): 62 | mask = np.zeros([h,w,3], dtype=np.uint8) 63 | ind = np.where(labels==mask_id) 64 | mask[ind] = np.array([255,255,255]) 65 | output_file_name = os.path.join(output_dir, 'cluster_{}_{}.png'.format(str(count).zfill(4), str(mask_id))) 66 | cv2.imwrite(output_file_name, mask) 67 | 68 | 69 | instance_masks.append(mask) 70 | 71 | return instance_masks 72 | -------------------------------------------------------------------------------- /data/tusimple_dataset_processing.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import csv 4 | import glob 5 | import argparse 6 | import math 7 | import matplotlib.pyplot as plt 8 | from tqdm import tqdm 9 | import cv2 10 | import numpy as np 11 | 12 | def read_json(data_dir, json_string): 13 | json_paths = glob.glob(os.path.join(data_dir,json_string)) 14 | print (json_paths) 15 | data = [] 16 | for path in json_paths: 17 | with open(path) as f: 18 | d = (line.strip() for line in f) 19 | d_str = "[{0}]".format(','.join(d)) 20 | data.append(json.loads(d_str)) 21 | 22 | num_samples = 0 23 | for d in data: 24 | num_samples += len(d) 25 | print ('Number of labeled images:', num_samples) 26 | print ('data keys:', data[0][0].keys()) 27 | 28 | return data 29 | 30 | def read_image_strings(data, input_dir): 31 | img_paths = [] 32 | for datum in data: 33 | for d in datum: 34 | path = os.path.join(input_dir, d['raw_file']) 35 | img_paths.append(path) 36 | 37 | num_samples = 0 38 | for d in data: 39 | num_samples += len(d) 40 | assert len(img_paths)==num_samples, 'Number of samples do not match' 41 | print (img_paths[0:2]) 42 | 43 | return img_paths 44 | 45 | def save_input_images(output_dir, img_paths): 46 | for i, path in tqdm(enumerate(img_paths), total=len(img_paths)): 47 | img = cv2.imread(path) 48 | output_path = os.path.join(output_dir,'images', '{}.png'.format(str(i).zfill(4))) 49 | cv2.imwrite(output_path, img) 50 | 51 | def draw_lines(img, lanes, height, instancewise=False): 52 | for i, lane in enumerate(lanes): 53 | pts = [[x,y] for x, y in zip(lane, height) if (x!=-2 and y!=-2)] 54 | pts = np.array([pts]) 55 | if not instancewise: 56 | cv2.polylines(img, pts, False,255, thickness=7) 57 | else: 58 | cv2.polylines(img, pts, False,50*i+20, thickness=7) 59 | 60 | def draw_single_line(img, lane, height): 61 | pts = [[x,y] for x, y in zip(lane, height) if (x!=-2 and y!=-2)] 62 | pts = np.array([pts]) 63 | cv2.polylines(img, pts, False,255, thickness=15) 64 | 65 | def save_label_images(output_dir, data, instancewise=True): 66 | counter = 0 67 | 68 | for i in range(len(data)): 69 | for j in tqdm(range(len(data[i]))): 70 | img = np.zeros([720, 1280], dtype=np.uint8) 71 | lanes = data[i][j]['lanes'] 72 | height = data[i][j]['h_samples'] 73 | draw_lines(img, lanes, height, instancewise) 74 | output_path = os.path.join(output_dir,'labels', '{}.png'.format(str(counter).zfill(4))) 75 | cv2.imwrite(output_path, img) 76 | counter += 1 77 | 78 | 79 | if __name__=='__main__': 80 | 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('srcdir', help="Source directory of TuSimple dataset") 83 | parser.add_argument('-o', '--outdir', default='.', help="Output directory of extracted data") 84 | args = parser.parse_args() 85 | 86 | if not os.path.isdir(args.srcdir): 87 | raise IOError('Directory does not exist') 88 | if not os.path.isdir('images'): 89 | os.mkdir('images') 90 | if not os.path.isdir('labels'): 91 | os.mkdir('labels') 92 | 93 | json_string = 'label_data_*.json' 94 | data = read_json(args.srcdir, json_string) 95 | img_paths = read_image_strings(data, args.srcdir) 96 | save_input_images(args.outdir, img_paths) 97 | save_label_images(args.outdir, data) -------------------------------------------------------------------------------- /datagenerator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import scipy.misc 5 | import random 6 | from sklearn.utils 
import shuffle 7 | import shutil 8 | import time 9 | import tensorflow as tf 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | ### Mean and std deviation for whole training data set (RGB format) 15 | mean = np.array([92.14031982, 103.20146942, 103.47182465]) 16 | std = np.array([49.157, 54.9057, 59.4065]) 17 | 18 | INSTANCE_COLORS = [np.array([0,0,0]), 19 | np.array([20.,20.,20.]), 20 | np.array([70.,70.,70.]), 21 | np.array([120.,120.,120.]), 22 | np.array([170.,170.,170.]), 23 | np.array([220.,220.,220.]) 24 | ] 25 | 26 | def get_batches_fn(batch_size, image_shape, image_paths, label_paths): 27 | """ 28 | Create batches of training data 29 | :param batch_size: Batch Size 30 | :return: Batches of training data 31 | """ 32 | 33 | #print ('Number of total labels:', len(label_paths)) 34 | assert len(image_paths)==len(label_paths), 'Number of images and labels do not match' 35 | 36 | image_paths.sort() 37 | label_paths.sort() 38 | 39 | #image_paths = image_paths[:10] 40 | #label_paths = label_paths[:10] 41 | 42 | image_paths, label_paths = shuffle(image_paths, label_paths) 43 | for batch_i in range(0, len(image_paths), batch_size): 44 | images = [] 45 | gt_images = [] 46 | for image_file, gt_image_file in zip(image_paths[batch_i:batch_i+batch_size], label_paths[batch_i:batch_i+batch_size]): 47 | 48 | image = cv2.resize(cv2.imread(image_file), image_shape, interpolation=cv2.INTER_LINEAR) 49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 50 | #image = (image.astype(np.float32)-mean)/std 51 | 52 | gt_image = cv2.imread(gt_image_file, cv2.IMREAD_COLOR) 53 | gt_image = cv2.resize(gt_image[:,:,0], image_shape, interpolation=cv2.INTER_NEAREST) 54 | 55 | images.append(image) 56 | gt_images.append(gt_image) 57 | 58 | yield np.array(images), np.array(gt_images) 59 | 60 | 61 | def get_validation_batch(data_dir, image_shape): 62 | valid_image_paths = [os.path.join(data_dir,'images','0000.png')] 63 | 64 | valid_label_paths = [os.path.join(data_dir,'labels','0000.png')] 65 | 66 | images = [] 67 | gt_images = [] 68 | for image_file, gt_image_file in zip(valid_image_paths, valid_label_paths): 69 | 70 | image = cv2.resize(cv2.imread(image_file), image_shape, interpolation=cv2.INTER_LINEAR) 71 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 72 | #image = (image.astype(np.float32)-mean)/std 73 | 74 | gt_image = cv2.imread(gt_image_file, cv2.IMREAD_COLOR) 75 | gt_image = cv2.resize(gt_image[:,:,0], image_shape, interpolation=cv2.INTER_NEAREST) 76 | 77 | images.append(image) 78 | gt_images.append(gt_image) 79 | 80 | return np.array(images), np.array(gt_images) 81 | 82 | 83 | 84 | if __name__=="__main__": 85 | pass -------------------------------------------------------------------------------- /doc/cluster_000000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/doc/cluster_000000.png -------------------------------------------------------------------------------- /doc/cluster_001000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/doc/cluster_001000.png -------------------------------------------------------------------------------- /doc/cluster_002000.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/doc/cluster_002000.png -------------------------------------------------------------------------------- /doc/cluster_012000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/doc/cluster_012000.png -------------------------------------------------------------------------------- /doc/cluster_015000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/doc/cluster_015000.png -------------------------------------------------------------------------------- /doc/cluster_016000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/doc/cluster_016000.png -------------------------------------------------------------------------------- /doc/cluster_017000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/doc/cluster_017000.png -------------------------------------------------------------------------------- /doc/cluster_018000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/doc/cluster_018000.png -------------------------------------------------------------------------------- /doc/training_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/doc/training_pipeline.png -------------------------------------------------------------------------------- /enet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.layers.python.layers import initializers 3 | slim = tf.contrib.slim 4 | 5 | ''' 6 | ============================================================================ 7 | ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation 8 | ============================================================================ 9 | Based on the paper: https://arxiv.org/pdf/1606.02147.pdf 10 | ''' 11 | @slim.add_arg_scope 12 | def prelu(x, scope, decoder=False): 13 | ''' 14 | Performs the parametric relu operation. This implementation is based on: 15 | https://stackoverflow.com/questions/39975676/how-to-implement-prelu-activation-in-tensorflow 16 | 17 | For the decoder portion, prelu becomes just a normal prelu 18 | 19 | INPUTS: 20 | - x(Tensor): a 4D Tensor that undergoes prelu 21 | - scope(str): the string to name your prelu operation's alpha variable. 22 | - decoder(bool): if True, prelu becomes a normal relu. 23 | 24 | OUTPUTS: 25 | - pos + neg / x (Tensor): gives prelu output only during training; otherwise, just return x. 
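(Equivalently, the returned value pos + neg computes f(x) = max(0, x) + alpha * min(0, x), since alpha * (x - abs(x)) * 0.5 equals alpha * min(0, x); alpha is a trainable per-channel parameter initialized to zero, so the layer initially behaves like an ordinary relu.)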
26 | 27 | ''' 28 | #If decoder, then perform relu and just return the output 29 | if decoder: 30 | return tf.nn.relu(x, name=scope) 31 | 32 | alpha = tf.get_variable(scope + 'alpha', x.get_shape()[-1], 33 | initializer=tf.constant_initializer(0.0), 34 | dtype=tf.float32) 35 | pos = tf.nn.relu(x) 36 | neg = alpha * (x - abs(x)) * 0.5 37 | return pos + neg 38 | 39 | def spatial_dropout(x, p, seed, scope, is_training=True): 40 | ''' 41 | Performs a 2D spatial dropout that drops layers instead of individual elements in an input feature map. 42 | Note that p stands for the probability of dropping, but tf.nn.dropout uses the probability of keeping. 43 | 44 | ------------------ 45 | Technical Details 46 | ------------------ 47 | The noise shape must be of shape [batch_size, 1, 1, num_channels], with the height and width set to 1, because 48 | it will represent either a 1 or 0 for each layer, and these 1 or 0 integers will be broadcasted to the entire 49 | dimensions of each layer they interact with such that they can decide whether each layer should be entirely 50 | 'dropped'/set to zero or have its activations entirely kept. 51 | -------------------------- 52 | 53 | INPUTS: 54 | - x(Tensor): a 4D Tensor of the input feature map. 55 | - p(float): a float representing the probability of dropping a layer 56 | - seed(int): an integer for random seeding the random_uniform distribution that runs under tf.nn.dropout 57 | - scope(str): the string name for naming the spatial_dropout 58 | - is_training(bool): to turn on dropout only when training. Optional. 59 | 60 | OUTPUTS: 61 | - output(Tensor): a 4D Tensor that is of exactly the same size as the input x, 62 | with certain layers having their elements all set to 0 (i.e. dropped). 63 | ''' 64 | if is_training: 65 | keep_prob = 1.0 - p 66 | input_shape = x.get_shape().as_list() 67 | noise_shape = tf.constant(value=[input_shape[0], 1, 1, input_shape[3]]) 68 | output = tf.nn.dropout(x, keep_prob, noise_shape, seed=seed, name=scope) 69 | 70 | return output 71 | 72 | return x 73 | 74 | def unpool(updates, mask, k_size=[1, 2, 2, 1], output_shape=None, scope=''): 75 | ''' 76 | Unpooling function based on the implementation by Panaetius at https://github.com/tensorflow/tensorflow/issues/2169 77 | 78 | INPUTS: 79 | - updates(Tensor): a 4D tensor of shape [batch_size, height, width, num_channels] that represents the input block to be upsampled 80 | - mask(Tensor): a 4D tensor that represents the argmax values/pooling indices of the previously max-pooled layer 81 | - k_size(list): a list of values representing the dimensions of the unpooling filter. 82 | - output_shape(list): a list of values to indicate what the final output shape should be after unpooling 83 | - scope(str): the string name to name your scope 84 | 85 | OUTPUTS: 86 | - ret(Tensor): the returned 4D tensor that has the shape of output_shape.
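(Sketch of the mechanism, as implemented below: the argmax values in mask are flat indices into each [height, width, channels] slab; they are decoded into per-element (batch, y, x, channel) coordinates, and the values of updates are written to those coordinates of an all-zero tensor of shape output_shape via tf.scatter_nd, i.e. the inverse of the earlier max pooling.)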
87 | 88 | ''' 89 | with tf.variable_scope(scope): 90 | mask = tf.cast(mask, tf.int32) 91 | input_shape = tf.shape(updates, out_type=tf.int32) 92 | # calculation new shape 93 | if output_shape is None: 94 | output_shape = (input_shape[0], input_shape[1] * ksize[1], input_shape[2] * ksize[2], input_shape[3]) 95 | 96 | # calculation indices for batch, height, width and feature maps 97 | one_like_mask = tf.ones_like(mask, dtype=tf.int32) 98 | batch_shape = tf.concat([[input_shape[0]], [1], [1], [1]], 0) 99 | batch_range = tf.reshape(tf.range(output_shape[0], dtype=tf.int32), shape=batch_shape) 100 | b = one_like_mask * batch_range 101 | y = mask // (output_shape[2] * output_shape[3]) 102 | x = (mask // output_shape[3]) % output_shape[2] #mask % (output_shape[2] * output_shape[3]) // output_shape[3] 103 | feature_range = tf.range(output_shape[3], dtype=tf.int32) 104 | f = one_like_mask * feature_range 105 | 106 | # transpose indices & reshape update values to one dimension 107 | updates_size = tf.size(updates) 108 | indices = tf.transpose(tf.reshape(tf.stack([b, y, x, f]), [4, updates_size])) 109 | values = tf.reshape(updates, [updates_size]) 110 | ret = tf.scatter_nd(indices, values, output_shape) 111 | return ret 112 | 113 | @slim.add_arg_scope 114 | def initial_block(inputs, is_training=True, scope='initial_block'): 115 | ''' 116 | The initial block for Enet has 2 branches: The convolution branch and Maxpool branch. 117 | 118 | The conv branch has 13 layers, while the maxpool branch gives 3 layers corresponding to the RGB channels. 119 | Both output layers are then concatenated to give an output of 16 layers. 120 | 121 | NOTE: Does not need to store pooling indices since it won't be used later for the final upsampling. 122 | 123 | INPUTS: 124 | - inputs(Tensor): A 4D tensor of shape [batch_size, height, width, channels] 125 | 126 | OUTPUTS: 127 | - net_concatenated(Tensor): a 4D Tensor that contains the 128 | ''' 129 | #Convolutional branch 130 | net_conv = slim.conv2d(inputs, 13, [3,3], stride=2, activation_fn=None, scope=scope+'_conv') 131 | net_conv = slim.batch_norm(net_conv, is_training=is_training, fused=True, scope=scope+'_batchnorm') 132 | net_conv = prelu(net_conv, scope=scope+'_prelu') 133 | 134 | #Max pool branch 135 | net_pool = slim.max_pool2d(inputs, [2,2], stride=2, scope=scope+'_max_pool') 136 | 137 | #Concatenated output - does it matter max pool comes first or conv comes first? probably not. 138 | net_concatenated = tf.concat([net_conv, net_pool], axis=3, name=scope+'_concat') 139 | return net_concatenated 140 | 141 | @slim.add_arg_scope 142 | def bottleneck(inputs, 143 | output_depth, 144 | filter_size, 145 | regularizer_prob, 146 | projection_ratio=4, 147 | seed=0, 148 | is_training=True, 149 | downsampling=False, 150 | upsampling=False, 151 | pooling_indices=None, 152 | output_shape=None, 153 | dilated=False, 154 | dilation_rate=None, 155 | asymmetric=False, 156 | decoder=False, 157 | scope='bottleneck'): 158 | ''' 159 | The bottleneck module has three different kinds of variants: 160 | 161 | 1. A regular convolution which you can decide whether or not to downsample. 162 | 2. A dilated convolution, which requires you to have a dilation factor. 163 | 3. An asymmetric convolution that has a decomposed filter size of 5x1 and 1x5 separately. 164 | 165 | INPUTS: 166 | - inputs(Tensor): a 4D Tensor of the previous convolutional block of shape [batch_size, height, width, num_channels]. 
167 | - output_depth(int): an integer indicating the output depth of the output convolutional block. 168 | - filter_size(int): an integer that gives the height and width of the filter size to use for a regular/dilated convolution. 169 | - regularizer_prob(float): the float p that represents the prob of dropping a layer for spatial dropout regularization. 170 | - projection_ratio(int): the amount of depth to reduce for initial 1x1 projection. Depth is divided by projection ratio. Default is 4. 171 | - seed(int): an integer for the random seed used in the random normal distribution within dropout. 172 | - is_training(bool): a boolean value to indicate whether or not is training. Decides batch_norm and prelu activity. 173 | 174 | - downsampling(bool): if True, a max-pool2D layer is added to downsample the spatial sizes. 175 | - upsampling(bool): if True, the upsampling bottleneck is activated but requires pooling indices to upsample. 176 | - pooling_indices(Tensor): the argmax values that are obtained after performing tf.nn.max_pool_with_argmax. 177 | - output_shape(list): A list of integers indicating the output shape of the unpooling layer. 178 | - dilated(bool): if True, then dilated convolution is done, but requires a dilation rate to be given. 179 | - dilation_rate(int): the dilation factor for performing atrous convolution/dilated convolution. 180 | - asymmetric(bool): if True, then asymmetric convolution is done, and the only filter size used here is 5. 181 | - decoder(bool): if True, then all the prelus become relus according to ENet author. 182 | - scope(str): a string name that names your bottleneck. 183 | 184 | OUTPUTS: 185 | - net(Tensor): The convolution block output after a bottleneck 186 | - pooling_indices(Tensor): If downsample, then this tensor is produced for use in upooling later. 187 | - inputs_shape(list): The shape of the input to the downsampling conv block. For use in unpooling later. 188 | 189 | ''' 190 | #Calculate the depth reduction based on the projection ratio used in 1x1 convolution. 191 | reduced_depth = int(inputs.get_shape().as_list()[3] / projection_ratio) 192 | 193 | with slim.arg_scope([prelu], decoder=decoder): 194 | 195 | #=============DOWNSAMPLING BOTTLENECK==================== 196 | if downsampling: 197 | #=============MAIN BRANCH============= 198 | #Just perform a max pooling 199 | net_main, pooling_indices = tf.nn.max_pool_with_argmax(inputs, 200 | ksize=[1,2,2,1], 201 | strides=[1,2,2,1], 202 | padding='SAME', 203 | name=scope+'_main_max_pool') 204 | 205 | #First get the difference in depth to pad, then pad with zeros only on the last dimension. 
206 | inputs_shape = inputs.get_shape().as_list() 207 | depth_to_pad = abs(inputs_shape[3] - output_depth) 208 | paddings = tf.convert_to_tensor([[0,0], [0,0], [0,0], [0, depth_to_pad]]) 209 | net_main = tf.pad(net_main, paddings=paddings, name=scope+'_main_padding') 210 | 211 | #=============SUB BRANCH============== 212 | #First projection that has a 2x2 kernel and stride 2 213 | net = slim.conv2d(inputs, reduced_depth, [2,2], stride=2, scope=scope+'_conv1') 214 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm1') 215 | net = prelu(net, scope=scope+'_prelu1') 216 | 217 | #Second conv block 218 | net = slim.conv2d(net, reduced_depth, [filter_size, filter_size], scope=scope+'_conv2') 219 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm2') 220 | net = prelu(net, scope=scope+'_prelu2') 221 | 222 | #Final projection with 1x1 kernel 223 | net = slim.conv2d(net, output_depth, [1,1], scope=scope+'_conv3') 224 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm3') 225 | net = prelu(net, scope=scope+'_prelu3') 226 | 227 | #Regularizer 228 | net = spatial_dropout(net, p=regularizer_prob, seed=seed, scope=scope+'_spatial_dropout') 229 | 230 | #Finally, combine the two branches together via an element-wise addition 231 | net = tf.add(net, net_main, name=scope+'_add') 232 | net = prelu(net, scope=scope+'_last_prelu') 233 | 234 | #also return inputs shape for convenience later 235 | return net, pooling_indices, inputs_shape 236 | 237 | #============DILATION CONVOLUTION BOTTLENECK==================== 238 | #Everything is the same as a regular bottleneck except for the dilation rate argument 239 | elif dilated: 240 | #Check if dilation rate is given 241 | if not dilation_rate: 242 | raise ValueError('Dilation rate is not given.') 243 | 244 | #Save the main branch for addition later 245 | net_main = inputs 246 | 247 | #First projection with 1x1 kernel (dimensionality reduction) 248 | net = slim.conv2d(inputs, reduced_depth, [1,1], scope=scope+'_conv1') 249 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm1') 250 | net = prelu(net, scope=scope+'_prelu1') 251 | 252 | #Second conv block --- apply dilated convolution here 253 | net = slim.conv2d(net, reduced_depth, [filter_size, filter_size], rate=dilation_rate, scope=scope+'_dilated_conv2') 254 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm2') 255 | net = prelu(net, scope=scope+'_prelu2') 256 | 257 | #Final projection with 1x1 kernel (Expansion) 258 | net = slim.conv2d(net, output_depth, [1,1], scope=scope+'_conv3') 259 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm3') 260 | net = prelu(net, scope=scope+'_prelu3') 261 | 262 | #Regularizer 263 | net = spatial_dropout(net, p=regularizer_prob, seed=seed, scope=scope+'_spatial_dropout') 264 | net = prelu(net, scope=scope+'_prelu4') 265 | 266 | #Add the main branch 267 | net = tf.add(net_main, net, name=scope+'_add_dilated') 268 | net = prelu(net, scope=scope+'_last_prelu') 269 | 270 | return net 271 | 272 | #===========ASYMMETRIC CONVOLUTION BOTTLENECK============== 273 | #Everything is the same as a regular bottleneck except for a [5,5] kernel decomposed into two [5,1] then [1,5] 274 | elif asymmetric: 275 | #Save the main branch for addition later 276 | net_main = inputs 277 | 278 | #First projection with 1x1 kernel (dimensionality reduction) 279 | net = slim.conv2d(inputs, reduced_depth, [1,1], scope=scope+'_conv1') 280 | net = 
slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm1') 281 | net = prelu(net, scope=scope+'_prelu1') 282 | 283 | #Second conv block --- apply asymmetric conv here 284 | net = slim.conv2d(net, reduced_depth, [filter_size, 1], scope=scope+'_asymmetric_conv2a') 285 | net = slim.conv2d(net, reduced_depth, [1, filter_size], scope=scope+'_asymmetric_conv2b') 286 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm2') 287 | net = prelu(net, scope=scope+'_prelu2') 288 | 289 | #Final projection with 1x1 kernel 290 | net = slim.conv2d(net, output_depth, [1,1], scope=scope+'_conv3') 291 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm3') 292 | net = prelu(net, scope=scope+'_prelu3') 293 | 294 | #Regularizer 295 | net = spatial_dropout(net, p=regularizer_prob, seed=seed, scope=scope+'_spatial_dropout') 296 | net = prelu(net, scope=scope+'_prelu4') 297 | 298 | #Add the main branch 299 | net = tf.add(net_main, net, name=scope+'_add_asymmetric') 300 | net = prelu(net, scope=scope+'_last_prelu') 301 | 302 | return net 303 | 304 | #============UPSAMPLING BOTTLENECK================ 305 | #Everything is the same as a regular one, except convolution becomes transposed. 306 | elif upsampling: 307 | #Check if pooling indices is given 308 | if pooling_indices == None: 309 | raise ValueError('Pooling indices are not given.') 310 | 311 | #Check output_shape given or not 312 | if output_shape == None: 313 | raise ValueError('Output depth is not given') 314 | 315 | #=======MAIN BRANCH======= 316 | #Main branch to upsample. output shape must match with the shape of the layer that was pooled initially, in order 317 | #for the pooling indices to work correctly. However, the initial pooled layer was padded, so need to reduce dimension 318 | #before unpooling. In the paper, padding is replaced with convolution for this purpose of reducing the depth! 319 | net_unpool = slim.conv2d(inputs, output_depth, [1,1], scope=scope+'_main_conv1') 320 | net_unpool = slim.batch_norm(net_unpool, is_training=is_training, scope=scope+'batch_norm1') 321 | net_unpool = unpool(net_unpool, pooling_indices, output_shape=output_shape, scope='unpool') 322 | 323 | #======SUB BRANCH======= 324 | #First 1x1 projection to reduce depth 325 | net = slim.conv2d(inputs, reduced_depth, [1,1], scope=scope+'_conv1') 326 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm2') 327 | net = prelu(net, scope=scope+'_prelu1') 328 | 329 | #Second conv block -----------------------------> NOTE: using tf.nn.conv2d_transpose for variable input shape. 
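#(As set up below: the filter variable is created by hand with a Xavier initializer and the output_shape tensor is derived from the unpooled main branch, so that this sub branch can later be added to net_unpool element-wise.)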
330 | net_unpool_shape = net_unpool.get_shape().as_list() 331 | output_shape = [net_unpool_shape[0], net_unpool_shape[1], net_unpool_shape[2], reduced_depth] 332 | output_shape = tf.convert_to_tensor(output_shape) 333 | filter_size = [filter_size, filter_size, reduced_depth, reduced_depth] 334 | filters = tf.get_variable(shape=filter_size, initializer=initializers.xavier_initializer(), dtype=tf.float32, name=scope+'_transposed_conv2_filters') 335 | 336 | # net = slim.conv2d_transpose(net, reduced_depth, [filter_size, filter_size], stride=2, scope=scope+'_transposed_conv2') 337 | net = tf.nn.conv2d_transpose(net, filter=filters, strides=[1,2,2,1], output_shape=output_shape, name=scope+'_transposed_conv2') 338 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm3') 339 | net = prelu(net, scope=scope+'_prelu2') 340 | 341 | #Final projection with 1x1 kernel 342 | net = slim.conv2d(net, output_depth, [1,1], scope=scope+'_conv3') 343 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm4') 344 | net = prelu(net, scope=scope+'_prelu3') 345 | 346 | #Regularizer 347 | net = spatial_dropout(net, p=regularizer_prob, seed=seed, scope=scope+'_spatial_dropout') 348 | net = prelu(net, scope=scope+'_prelu4') 349 | 350 | #Finally, add the unpooling layer and the sub branch together 351 | net = tf.add(net, net_unpool, name=scope+'_add_upsample') 352 | net = prelu(net, scope=scope+'_last_prelu') 353 | 354 | return net 355 | 356 | #OTHERWISE, just perform a regular bottleneck! 357 | #==============REGULAR BOTTLENECK================== 358 | #Save the main branch for addition later 359 | net_main = inputs 360 | 361 | #First projection with 1x1 kernel 362 | net = slim.conv2d(inputs, reduced_depth, [1,1], scope=scope+'_conv1') 363 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm1') 364 | net = prelu(net, scope=scope+'_prelu1') 365 | 366 | #Second conv block 367 | net = slim.conv2d(net, reduced_depth, [filter_size, filter_size], scope=scope+'_conv2') 368 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm2') 369 | net = prelu(net, scope=scope+'_prelu2') 370 | 371 | #Final projection with 1x1 kernel 372 | net = slim.conv2d(net, output_depth, [1,1], scope=scope+'_conv3') 373 | net = slim.batch_norm(net, is_training=is_training, scope=scope+'_batch_norm3') 374 | net = prelu(net, scope=scope+'_prelu3') 375 | 376 | #Regularizer 377 | net = spatial_dropout(net, p=regularizer_prob, seed=seed, scope=scope+'_spatial_dropout') 378 | net = prelu(net, scope=scope+'_prelu4') 379 | 380 | #Add the main branch 381 | net = tf.add(net_main, net, name=scope+'_add_regular') 382 | net = prelu(net, scope=scope+'_last_prelu') 383 | 384 | return net 385 | 386 | #Now actually start building the network 387 | def ENet(inputs, 388 | num_classes, 389 | batch_size, 390 | num_initial_blocks=1, 391 | stage_two_repeat=2, 392 | skip_connections=True, 393 | reuse=None, 394 | is_training=True, 395 | scope='ENet'): 396 | ''' 397 | The ENet model for real-time semantic segmentation! 398 | 399 | INPUTS: 400 | - inputs(Tensor): a 4D Tensor of shape [batch_size, image_height, image_width, num_channels] that represents one batch of preprocessed images. 401 | - num_classes(int): an integer for the number of classes to predict. This will determine the final output channels as the answer. 402 | - batch_size(int): the batch size to explictly set the shape of the inputs in order for operations to work properly. 
403 | - num_initial_blocks(int): the number of times to repeat the initial block. 404 | - stage_two_repeat(int): the number of times to repeat stage two in order to make the network deeper. 405 | - skip_connections(bool): if True, add the corresponding encoder feature maps to the decoder. They are of exact same shapes. 406 | - reuse(bool): Whether or not to reuse the variables for evaluation. 407 | - is_training(bool): if True, switch on batch_norm and prelu only during training, otherwise they are turned off. 408 | - scope(str): a string that represents the scope name for the variables. 409 | 410 | OUTPUTS: 411 | - net(Tensor): a 4D Tensor output of shape [batch_size, image_height, image_width, num_classes], where each pixel has a one-hot encoded vector 412 | determining the label of the pixel. 413 | ''' 414 | #Set the shape of the inputs first to get the batch_size information 415 | inputs_shape = inputs.get_shape().as_list() 416 | inputs.set_shape(shape=(batch_size, inputs_shape[1], inputs_shape[2], inputs_shape[3])) 417 | 418 | with tf.variable_scope(scope, reuse=reuse): 419 | #Set the primary arg scopes. Fused batch_norm is faster than normal batch norm. 420 | with slim.arg_scope([initial_block, bottleneck], is_training=is_training),\ 421 | slim.arg_scope([slim.batch_norm], fused=True), \ 422 | slim.arg_scope([slim.conv2d, slim.conv2d_transpose], activation_fn=None): 423 | #=================INITIAL BLOCK================= 424 | net = initial_block(inputs, scope='initial_block_1') 425 | for i in range(2, max(num_initial_blocks, 1) + 1): 426 | net = initial_block(net, scope='initial_block_' + str(i)) 427 | 428 | #Save for skip connection later 429 | if skip_connections: 430 | net_one = net 431 | 432 | #===================STAGE ONE======================= 433 | net, pooling_indices_1, inputs_shape_1 = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01, downsampling=True, scope='bottleneck1_0') 434 | net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01, scope='bottleneck1_1') 435 | net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01, scope='bottleneck1_2') 436 | net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01, scope='bottleneck1_3') 437 | net = bottleneck(net, output_depth=64, filter_size=3, regularizer_prob=0.01, scope='bottleneck1_4') 438 | 439 | #Save for skip connection later 440 | if skip_connections: 441 | net_two = net 442 | 443 | #regularization prob is 0.1 from bottleneck 2.0 onwards 444 | with slim.arg_scope([bottleneck], regularizer_prob=0.1): 445 | net, pooling_indices_2, inputs_shape_2 = bottleneck(net, output_depth=128, filter_size=3, downsampling=True, scope='bottleneck2_0') 446 | 447 | #Repeat the stage two at least twice to get stage 2 and 3: 448 | for i in range(2, max(stage_two_repeat, 2) + 2): 449 | net = bottleneck(net, output_depth=128, filter_size=3, scope='bottleneck'+str(i)+'_1') 450 | net = bottleneck(net, output_depth=128, filter_size=3, dilated=True, dilation_rate=2, scope='bottleneck'+str(i)+'_2') 451 | net = bottleneck(net, output_depth=128, filter_size=5, asymmetric=True, scope='bottleneck'+str(i)+'_3') 452 | net = bottleneck(net, output_depth=128, filter_size=3, dilated=True, dilation_rate=4, scope='bottleneck'+str(i)+'_4') 453 | net = bottleneck(net, output_depth=128, filter_size=3, scope='bottleneck'+str(i)+'_5') 454 | net = bottleneck(net, output_depth=128, filter_size=3, dilated=True, dilation_rate=8, scope='bottleneck'+str(i)+'_6') 455 | net = bottleneck(net, 
output_depth=128, filter_size=5, asymmetric=True, scope='bottleneck'+str(i)+'_7') 456 | net = bottleneck(net, output_depth=128, filter_size=3, dilated=True, dilation_rate=16, scope='bottleneck'+str(i)+'_8') 457 | 458 | with slim.arg_scope([bottleneck], regularizer_prob=0.1, decoder=True): 459 | #===================STAGE FOUR======================== 460 | bottleneck_scope_name = "bottleneck" + str(i + 1) 461 | 462 | #The decoder section, so start to upsample. 463 | net = bottleneck(net, output_depth=64, filter_size=3, upsampling=True, 464 | pooling_indices=pooling_indices_2, output_shape=inputs_shape_2, scope=bottleneck_scope_name+'_0') 465 | 466 | #Perform skip connections here 467 | if skip_connections: 468 | net = tf.add(net, net_two, name=bottleneck_scope_name+'_skip_connection') 469 | 470 | net = bottleneck(net, output_depth=64, filter_size=3, scope=bottleneck_scope_name+'_1') 471 | net = bottleneck(net, output_depth=64, filter_size=3, scope=bottleneck_scope_name+'_2') 472 | 473 | #===================STAGE FIVE======================== 474 | bottleneck_scope_name = "bottleneck" + str(i + 2) 475 | 476 | net = bottleneck(net, output_depth=16, filter_size=3, upsampling=True, 477 | pooling_indices=pooling_indices_1, output_shape=inputs_shape_1, scope=bottleneck_scope_name+'_0') 478 | 479 | #perform skip connections here 480 | if skip_connections: 481 | net = tf.add(net, net_one, name=bottleneck_scope_name+'_skip_connection') 482 | 483 | net = bottleneck(net, output_depth=16, filter_size=3, scope=bottleneck_scope_name+'_1') 484 | 485 | #=============FINAL CONVOLUTION============= 486 | logits = slim.conv2d_transpose(net, num_classes, [2,2], stride=2, scope='fullconv') 487 | probabilities = tf.nn.softmax(logits, name='logits_to_softmax') 488 | 489 | return logits, probabilities 490 | 491 | 492 | def ENet_arg_scope(weight_decay=2e-4, 493 | batch_norm_decay=0.1, 494 | batch_norm_epsilon=0.001): 495 | ''' 496 | The arg scope for enet model. The weight decay is 2e-4 as seen in the paper. 497 | Batch_norm decay is 0.1 (momentum 0.1) according to official implementation. 498 | 499 | INPUTS: 500 | - weight_decay(float): the weight decay for weights variables in conv2d and separable conv2d 501 | - batch_norm_decay(float): decay for the moving average of batch_norm momentums. 502 | - batch_norm_epsilon(float): small float added to variance to avoid dividing by zero. 503 | 504 | OUTPUTS: 505 | - scope(arg_scope): a tf-slim arg_scope with the parameters needed for xception. 506 | ''' 507 | # Set weight_decay for weights in conv2d and separable_conv2d layers. 508 | with slim.arg_scope([slim.conv2d], 509 | weights_regularizer=slim.l2_regularizer(weight_decay), 510 | biases_regularizer=slim.l2_regularizer(weight_decay)): 511 | 512 | # Set parameters for batch_norm. 
513 | with slim.arg_scope([slim.batch_norm], 514 | decay=batch_norm_decay, 515 | epsilon=batch_norm_epsilon) as scope: 516 | return scope 517 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | 4 | import argparse 5 | from glob import glob 6 | import numpy as np 7 | import cv2 8 | import tensorflow as tf 9 | 10 | slim = tf.contrib.slim 11 | from enet import ENet, ENet_arg_scope 12 | from clustering import cluster, get_instance_masks, save_instance_masks 13 | import time 14 | 15 | 16 | def rebuild_graph(sess, checkpoint_dir, input_image, batch_size, feature_dim): 17 | checkpoint = tf.train.latest_checkpoint(checkpoint_dir) 18 | 19 | num_initial_blocks = 1 20 | skip_connections = False 21 | stage_two_repeat = 2 22 | 23 | with slim.arg_scope(ENet_arg_scope()): 24 | _, _ = ENet(input_image, 25 | num_classes=12, 26 | batch_size=batch_size, 27 | is_training=True, 28 | reuse=None, 29 | num_initial_blocks=num_initial_blocks, 30 | stage_two_repeat=stage_two_repeat, 31 | skip_connections=skip_connections) 32 | 33 | graph = tf.get_default_graph() 34 | last_prelu = graph.get_tensor_by_name('ENet/bottleneck5_1_last_prelu:0') 35 | logits = slim.conv2d_transpose(last_prelu, feature_dim, [2,2], stride=2, 36 | scope='Instance/transfer_layer/conv2d_transpose') 37 | 38 | variables_to_restore = slim.get_variables_to_restore() 39 | saver = tf.train.Saver(variables_to_restore) 40 | saver.restore(sess, checkpoint) 41 | 42 | return logits 43 | 44 | def save_image_with_features_as_color(pred): 45 | p_min = np.min(pred) 46 | p_max = np.max(pred) 47 | pred = (pred - p_min)*255/(p_max-p_min) 48 | pred = pred.astype(np.uint8) 49 | output_file_name = os.path.join(output_dir, 'color_{}.png'.format(str(i).zfill(4))) 50 | cv2.imwrite(output_file_name, np.squeeze(pred)) 51 | 52 | 53 | if __name__=='__main__': 54 | 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument('-m','--modeldir', default='trained_model', help="Directory of trained model") 57 | parser.add_argument('-i', '--indir', default=os.path.join('inference_test', 'images'), help='Input image directory (jpg format)') 58 | parser.add_argument('-o', '--outdir', default=os.path.join('inference_test', 'results'), help='Output directory for inference images') 59 | args = parser.parse_args() 60 | 61 | data_dir = args.indir 62 | output_dir = args.outdir 63 | checkpoint_dir = args.modeldir 64 | 65 | if not os.path.isdir(output_dir): 66 | os.mkdir(output_dir) 67 | 68 | image_paths = glob(os.path.join(data_dir, '*.jpg')) 69 | image_paths.sort() 70 | 71 | num_images = len(image_paths) 72 | 73 | image_shape = (512, 512) 74 | batch_size = 1 75 | feature_dim = 3 76 | 77 | ### Limit GPU memory usage due to occasional crashes 78 | config = tf.ConfigProto() 79 | config.gpu_options.allow_growth = True 80 | config.gpu_options.per_process_gpu_memory_fraction = 0.5 81 | 82 | with tf.Session(config=config) as sess: 83 | 84 | input_image = tf.placeholder(tf.float32, shape=(None, image_shape[1], image_shape[0], 3)) 85 | logits = rebuild_graph(sess, checkpoint_dir, input_image, batch_size, feature_dim) 86 | 87 | inference_time = 0 88 | cluster_time = 0 89 | for i, path in enumerate(image_paths): 90 | 91 | image = cv2.resize(cv2.imread(path), image_shape, interpolation=cv2.INTER_LINEAR) 92 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 93 | image = np.expand_dims(image, axis=0) 94 | 95 | tic = time.time() 96 | 
prediction = sess.run(logits, feed_dict={input_image: image}) 97 | pred_time = time.time()-tic 98 | print ('Inference time', pred_time) 99 | inference_time += pred_time 100 | 101 | 102 | pred_color = np.squeeze(prediction.copy()) 103 | print ('Save prediction', i) 104 | #save_image_with_features_as_color(pred_color) 105 | 106 | pred_cluster = prediction.copy() 107 | tic = time.time() 108 | instance_mask = get_instance_masks(pred_cluster, bandwidth=1.)[0] 109 | #save_instance_masks(prediction, output_dir, bandwidth=1., count=i) 110 | print (instance_mask.shape) 111 | output_file_name = os.path.join(output_dir, 'cluster_{}.png'.format(str(i).zfill(4))) 112 | colors, counts = np.unique(instance_mask.reshape(image_shape[0]*image_shape[1],3), 113 | return_counts=True, axis=0) 114 | max_count = 0 115 | for color, count in zip(colors, counts): 116 | if count > max_count: 117 | max_count = count 118 | bg_color = color 119 | ind = np.where(instance_mask==bg_color) 120 | instance_mask[ind] = 0. 121 | instance_mask = cv2.addWeighted(np.squeeze(image), 1, instance_mask, 0.3, 0) 122 | instance_mask = cv2.resize(instance_mask, (1280,720)) 123 | clust_time = time.time()-tic 124 | cluster_time += clust_time 125 | cv2.imwrite(output_file_name, cv2.cvtColor(instance_mask, cv2.COLOR_RGB2BGR)) 126 | 127 | print ('Mean inference time:', inference_time/num_images, 'fps:', num_images/inference_time) 128 | print ('Mean cluster time:', cluster_time/num_images, 'fps:', num_images/cluster_time) 129 | print ('Mean total time:', cluster_time/num_images + inference_time/num_images, 'fps:', 1./(cluster_time/num_images + inference_time/num_images)) 130 | -------------------------------------------------------------------------------- /inference_test/images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/images/1.jpg -------------------------------------------------------------------------------- /inference_test/images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/images/2.jpg -------------------------------------------------------------------------------- /inference_test/images/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/images/3.jpg -------------------------------------------------------------------------------- /inference_test/images/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/images/4.jpg -------------------------------------------------------------------------------- /inference_test/images/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/images/5.jpg --------------------------------------------------------------------------------
/inference_test/images/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/images/6.jpg -------------------------------------------------------------------------------- /inference_test/results/cluster_0000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/results/cluster_0000.png -------------------------------------------------------------------------------- /inference_test/results/cluster_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/results/cluster_0001.png -------------------------------------------------------------------------------- /inference_test/results/cluster_0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/results/cluster_0002.png -------------------------------------------------------------------------------- /inference_test/results/cluster_0003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/results/cluster_0003.png -------------------------------------------------------------------------------- /inference_test/results/cluster_0004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/results/cluster_0004.png -------------------------------------------------------------------------------- /inference_test/results/cluster_0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/inference_test/results/cluster_0005.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def discriminative_loss_single(prediction, correct_label, feature_dim, label_shape, 4 | delta_v, delta_d, param_var, param_dist, param_reg): 5 | 6 | ''' Discriminative loss for a single prediction/label pair. 
7 | :param prediction: inference of network 8 | :param correct_label: instance label 9 | :feature_dim: feature dimension of prediction 10 | :param label_shape: shape of label 11 | :param delta_v: cutoff variance distance 12 | :param delta_d: curoff cluster distance 13 | :param param_var: weight for intra cluster variance 14 | :param param_dist: weight for inter cluster distances 15 | :param param_reg: weight regularization 16 | ''' 17 | 18 | ### Reshape so pixels are aligned along a vector 19 | correct_label = tf.reshape(correct_label, [label_shape[1]*label_shape[0]]) 20 | reshaped_pred = tf.reshape(prediction, [label_shape[1]*label_shape[0], feature_dim]) 21 | 22 | ### Count instances 23 | unique_labels, unique_id, counts = tf.unique_with_counts(correct_label) 24 | counts = tf.cast(counts, tf.float32) 25 | num_instances = tf.size(unique_labels) 26 | 27 | segmented_sum = tf.unsorted_segment_sum(reshaped_pred, unique_id, num_instances) 28 | 29 | mu = tf.div(segmented_sum, tf.reshape(counts, (-1, 1))) 30 | mu_expand = tf.gather(mu, unique_id) 31 | 32 | ### Calculate l_var 33 | distance = tf.norm(tf.subtract(mu_expand, reshaped_pred), axis=1) 34 | distance = tf.subtract(distance, delta_v) 35 | distance = tf.clip_by_value(distance, 0., distance) 36 | distance = tf.square(distance) 37 | 38 | l_var = tf.unsorted_segment_sum(distance, unique_id, num_instances) 39 | l_var = tf.div(l_var, counts) 40 | l_var = tf.reduce_sum(l_var) 41 | l_var = tf.divide(l_var, tf.cast(num_instances, tf.float32)) 42 | 43 | ### Calculate l_dist 44 | 45 | # Get distance for each pair of clusters like this: 46 | # mu_1 - mu_1 47 | # mu_2 - mu_1 48 | # mu_3 - mu_1 49 | # mu_1 - mu_2 50 | # mu_2 - mu_2 51 | # mu_3 - mu_2 52 | # mu_1 - mu_3 53 | # mu_2 - mu_3 54 | # mu_3 - mu_3 55 | 56 | mu_interleaved_rep = tf.tile(mu, [num_instances, 1]) 57 | mu_band_rep = tf.tile(mu, [1, num_instances]) 58 | mu_band_rep = tf.reshape(mu_band_rep, (num_instances*num_instances, feature_dim)) 59 | 60 | mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep) 61 | 62 | # Filter out zeros from same cluster subtraction 63 | intermediate_tensor = tf.reduce_sum(tf.abs(mu_diff),axis=1) 64 | zero_vector = tf.zeros(1, dtype=tf.float32) 65 | bool_mask = tf.not_equal(intermediate_tensor, zero_vector) 66 | mu_diff_bool = tf.boolean_mask(mu_diff, bool_mask) 67 | 68 | mu_norm = tf.norm(mu_diff_bool, axis=1) 69 | mu_norm = tf.subtract(2.*delta_d, mu_norm) 70 | mu_norm = tf.clip_by_value(mu_norm, 0., mu_norm) 71 | mu_norm = tf.square(mu_norm) 72 | 73 | l_dist = tf.reduce_mean(mu_norm) 74 | 75 | ### Calculate l_reg 76 | l_reg = tf.reduce_mean(tf.norm(mu, axis=1)) 77 | 78 | param_scale = 1. 
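# The three terms are combined below as in the discriminative loss of De Brabandere et al. (2017):
#   loss = param_scale * (param_var * l_var + param_dist * l_dist + param_reg * l_reg)
# l_var pulls pixel embeddings towards their instance mean (hinged at delta_v),
# l_dist pushes the means of different instances apart (hinged at 2 * delta_d),
# and l_reg keeps the instance means close to the origin.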
79 | l_var = param_var * l_var 80 | l_dist = param_dist * l_dist 81 | l_reg = param_reg * l_reg 82 | 83 | loss = param_scale*(l_var + l_dist + l_reg) 84 | 85 | return loss, l_var, l_dist, l_reg 86 | 87 | 88 | def discriminative_loss(prediction, correct_label, feature_dim, image_shape, 89 | delta_v, delta_d, param_var, param_dist, param_reg): 90 | ''' Iterate over a batch of prediction/label and cumulate loss 91 | :return: discriminative loss and its three components 92 | ''' 93 | def cond(label, batch, out_loss, out_var, out_dist, out_reg, i): 94 | return tf.less(i, tf.shape(batch)[0]) 95 | 96 | def body(label, batch, out_loss, out_var, out_dist, out_reg, i): 97 | disc_loss, l_var, l_dist, l_reg = discriminative_loss_single(prediction[i], correct_label[i], feature_dim, image_shape, 98 | delta_v, delta_d, param_var, param_dist, param_reg) 99 | 100 | out_loss = out_loss.write(i, disc_loss) 101 | out_var = out_var.write(i, l_var) 102 | out_dist = out_dist.write(i, l_dist) 103 | out_reg = out_reg.write(i, l_reg) 104 | 105 | return label, batch, out_loss, out_var, out_dist, out_reg, i + 1 106 | 107 | # TensorArray is a data structure that support dynamic writing 108 | output_ta_loss = tf.TensorArray(dtype=tf.float32, 109 | size=0, 110 | dynamic_size=True) 111 | output_ta_var = tf.TensorArray(dtype=tf.float32, 112 | size=0, 113 | dynamic_size=True) 114 | output_ta_dist = tf.TensorArray(dtype=tf.float32, 115 | size=0, 116 | dynamic_size=True) 117 | output_ta_reg = tf.TensorArray(dtype=tf.float32, 118 | size=0, 119 | dynamic_size=True) 120 | 121 | _, _, out_loss_op, out_var_op, out_dist_op, out_reg_op, _ = tf.while_loop(cond, body, [correct_label, 122 | prediction, 123 | output_ta_loss, 124 | output_ta_var, 125 | output_ta_dist, 126 | output_ta_reg, 127 | 0]) 128 | out_loss_op = out_loss_op.stack() 129 | out_var_op = out_var_op.stack() 130 | out_dist_op = out_dist_op.stack() 131 | out_reg_op = out_reg_op.stack() 132 | 133 | disc_loss = tf.reduce_mean(out_loss_op) 134 | l_var = tf.reduce_mean(out_var_op) 135 | l_dist = tf.reduce_mean(out_dist_op) 136 | l_reg = tf.reduce_mean(out_reg_op) 137 | 138 | return disc_loss, l_var, l_dist, l_reg -------------------------------------------------------------------------------- /pretrained_semantic_model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "saved_model-29999" 2 | all_model_checkpoint_paths: "saved_model-9999" 3 | all_model_checkpoint_paths: "saved_model-14999" 4 | all_model_checkpoint_paths: "saved_model-19999" 5 | all_model_checkpoint_paths: "saved_model-24999" 6 | all_model_checkpoint_paths: "saved_model-29999" 7 | -------------------------------------------------------------------------------- /pretrained_semantic_model/saved_model-24999.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/pretrained_semantic_model/saved_model-24999.meta -------------------------------------------------------------------------------- /pretrained_semantic_model/saved_model-29999.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/pretrained_semantic_model/saved_model-29999.data-00000-of-00001 
-------------------------------------------------------------------------------- /pretrained_semantic_model/saved_model-29999.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/pretrained_semantic_model/saved_model-29999.index -------------------------------------------------------------------------------- /pretrained_semantic_model/saved_model-29999.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/pretrained_semantic_model/saved_model-29999.meta -------------------------------------------------------------------------------- /todo_semantic_segmentation/helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import numpy as np 4 | import random 5 | from sklearn.utils import shuffle 6 | import shutil 7 | import time 8 | import tensorflow as tf 9 | import cv2 10 | 11 | 12 | ### Mean and std deviation for whole training data set (RGB format) 13 | mean = 0.#np.array([92.14031982, 103.20146942, 103.47182465]) 14 | std = 1.#np.array([49.157, 54.9057, 59.4065]) 15 | 16 | 17 | def get_batches_fn(batch_size, image_shape, image_paths, label_paths): 18 | """ 19 | Create batches of training data 20 | :param batch_size: Batch Size 21 | :param image_shape: input image shape 22 | :param image_paths: list of paths for training or validation 23 | :param label_paths: list of paths for training or validation 24 | :return: Batches of training data 25 | """ 26 | 27 | image_paths.sort() 28 | label_paths.sort() 29 | 30 | #image_paths = image_paths[:20] 31 | #label_paths = label_paths[:20] 32 | 33 | background_color = np.array([0, 0, 0]) 34 | 35 | image_paths, label_paths = shuffle(image_paths, label_paths) 36 | 37 | for batch_i in range(0, len(image_paths), batch_size): 38 | images = [] 39 | gt_images = [] 40 | for image_file, gt_image_file in zip(image_paths[batch_i:batch_i+batch_size], label_paths[batch_i:batch_i+batch_size]): 41 | 42 | ### Image preprocessing 43 | image = cv2.resize(cv2.imread(image_file), (image_shape[1], image_shape[0]), cv2.INTER_LINEAR) 44 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 45 | #image = (image.astype(np.float32)-mean)/std 46 | 47 | ### Label preprocessing 48 | gt_image = cv2.resize(cv2.imread(gt_image_file, cv2.IMREAD_COLOR), (image_shape[1], image_shape[0]), cv2.INTER_NEAREST) 49 | gt_bg = np.all(gt_image == background_color, axis=2) 50 | gt_bg = gt_bg.reshape(gt_bg.shape[0], gt_bg.shape[1], 1) 51 | gt_image = np.concatenate((gt_bg, np.invert(gt_bg)), axis=2) 52 | 53 | images.append(image) 54 | gt_images.append(gt_image) 55 | 56 | yield np.array(images), np.array(gt_images) 57 | 58 | 59 | # Source http://www.pyimagesearch.com/2015/10/05/opencv-gamma-correction/ 60 | def adjust_gamma(image, gamma=1.0): 61 | # build a lookup table mapping the pixel values [0, 255] to 62 | # their adjusted gamma values 63 | invGamma = 1.0 / gamma 64 | table = np.array([((i / 255.0) ** invGamma) * 255 65 | for i in np.arange(0, 256)]).astype("float32") 66 | 67 | # apply gamma correction using the lookup table 68 | return cv2.LUT(image, table) 69 | 70 | 71 | def gen_test_output(sess, logits, keep_prob, image_pl, data_folder, image_shape): 72 | """ 73 | Generate 
test output using the test images
74 |     :param sess: TF session
75 |     :param logits: TF Tensor for the logits
76 |     :param keep_prob: TF Placeholder for the dropout keep probability
77 |     :param image_pl: TF Placeholder for the input image
78 |     :param data_folder: Path to the folder that contains the datasets
79 |     :param image_shape: Tuple - Shape of image
80 |     :return: Output for each test image
81 |     """
82 |     for image_file in glob(os.path.join(data_folder, 'test_images', '*.png'))[:40]:
83 |         image = cv2.resize(cv2.imread(image_file), (image_shape[1], image_shape[0]), cv2.INTER_LINEAR)
84 | 
85 |         ### Run inference
86 |         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
87 |         img_origin = image.copy()
88 |         #image = (image.astype(np.float32)-mean)/std
89 | 
90 |         im_softmax = sess.run(
91 |             tf.nn.softmax(logits),
92 |             {keep_prob: 1.0, image_pl: [image]})
93 | 
94 |         ### Thresholding
95 |         im_softmax = im_softmax[:, 1].reshape(image_shape[0], image_shape[1])
96 |         mask_ind = np.where(im_softmax > 0.3)
97 | 
98 |         ### Overlay class mask over original image
99 |         blend = np.zeros_like(img_origin)
100 |         blend[mask_ind] = np.array([0,255,0])
101 |         blended = cv2.addWeighted(img_origin, 1, blend, 0.7, 0)
102 |         blended = cv2.cvtColor(blended, cv2.COLOR_BGR2RGB)
103 | 
104 |         yield os.path.basename(image_file), np.array(blended)
105 | 
106 | 
107 | def save_inference_samples(runs_dir, data_dir, sess, image_shape, logits, keep_prob, input_image):
108 |     # Make folder for current run
109 |     output_dir = os.path.join(runs_dir, str(time.time()))
110 |     if os.path.exists(output_dir):
111 |         shutil.rmtree(output_dir)
112 |     os.makedirs(output_dir)
113 | 
114 |     # Run NN on test images and save them to HD
115 |     print('Training Finished. Saving test images to: {}'.format(output_dir))
116 |     image_outputs = gen_test_output(
117 |         sess, logits, keep_prob, input_image, data_dir, image_shape)
118 |     for name, image in image_outputs:
119 |         cv2.imwrite(os.path.join(output_dir, name), image)
--------------------------------------------------------------------------------
/todo_semantic_segmentation/transfer_semantic.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import tensorflow as tf
3 | import helper
4 | import warnings
5 | from sklearn.utils import shuffle
6 | from sklearn.model_selection import train_test_split
7 | from glob import glob
8 | from enet import ENet, ENet_arg_scope
9 | from tensorflow.contrib.layers.python.layers import initializers
10 | import numpy as np
11 | slim = tf.contrib.slim
12 | 
13 | # Check for a GPU
14 | if not tf.test.gpu_device_name():
15 |     warnings.warn('No GPU found.
Please use a GPU to train your neural network.') 16 | else: 17 | print('Default GPU Device: {}'.format(tf.test.gpu_device_name())) 18 | 19 | 20 | def load_enet(sess, checkpoint_dir, input_image, batch_size, num_classes): 21 | checkpoint = tf.train.latest_checkpoint(checkpoint_dir) 22 | 23 | num_initial_blocks = 1 24 | skip_connections = False 25 | stage_two_repeat = 2 26 | 27 | with slim.arg_scope(ENet_arg_scope()): 28 | logits, _ = ENet(input_image, 29 | num_classes=12, 30 | batch_size=batch_size, 31 | is_training=True, 32 | reuse=None, 33 | num_initial_blocks=num_initial_blocks, 34 | stage_two_repeat=stage_two_repeat, 35 | skip_connections=skip_connections) 36 | 37 | 38 | variables_to_restore = slim.get_variables_to_restore() 39 | saver = tf.train.Saver(variables_to_restore) 40 | saver.restore(sess, checkpoint) 41 | graph = tf.get_default_graph() 42 | 43 | last_prelu = graph.get_tensor_by_name('ENet/bottleneck5_1_last_prelu:0') 44 | output = slim.conv2d_transpose(last_prelu, num_classes, [2,2], stride=2, 45 | weights_initializer=initializers.xavier_initializer(), 46 | scope='Semantic/transfer_layer/conv2d_transpose') 47 | 48 | probabilities = tf.nn.softmax(output, name='Semantic/transfer_layer/logits_to_softmax') 49 | 50 | with tf.variable_scope('', reuse=True): 51 | weight = tf.get_variable('Semantic/transfer_layer/conv2d_transpose/weights') 52 | bias = tf.get_variable('Semantic/transfer_layer/conv2d_transpose/biases') 53 | sess.run([weight.initializer, bias.initializer]) 54 | 55 | return output, probabilities 56 | 57 | 58 | 59 | def optimize(sess, logits, correct_label, learning_rate, num_classes, trainables, global_step): 60 | """ 61 | Build the TensorFLow loss and optimizer operations. 62 | :param nn_last_layer: TF Tensor of the last layer in the neural network 63 | :param correct_label: TF Placeholder for the correct label image 64 | :param learning_rate: TF Placeholder for the learning rate 65 | :param num_classes: Number of classes to classify 66 | :return: Tuple of (logits, train_op, cross_entropy_loss) 67 | """ 68 | # TODO: Implement function 69 | 70 | #correct_label = tf.reshape(correct_label, (-1, num_classes)) 71 | #logits = tf.reshape(nn_last_layer, (-1, num_classes)) 72 | 73 | weights = correct_label * np.array([1., 40.]) 74 | weights = tf.reduce_sum(weights, axis=3) 75 | loss = tf.reduce_mean(tf.losses.softmax_cross_entropy(onehot_labels=correct_label, logits=logits, weights=weights)) 76 | 77 | 78 | #loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=correct_label, logits=logits)) 79 | with tf.name_scope('Semantic/Adam'): 80 | train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss, var_list=trainables, global_step=global_step) 81 | adam_initializers = [var.initializer for var in tf.global_variables() if 'Adam' in var.name] 82 | sess.run(adam_initializers) 83 | return logits, train_op, loss 84 | 85 | 86 | def run(): 87 | 88 | ### Initialization 89 | image_shape = (512, 512) # (width, height) 90 | model_dir = '../checkpoint' 91 | data_dir = '../../tusimple_api/clean_data' 92 | log_dir = './log' 93 | output_dir = './saved_model' 94 | 95 | num_classes = 2 96 | epochs = 20 97 | batch_size = 1 98 | starter_learning_rate = 1e-4 99 | learning_rate_decay_interval = 500 100 | learning_rate_decay_rate = 0.96 101 | ### Load images and labels 102 | image_paths = glob(os.path.join(data_dir, 'images', '*.png')) 103 | label_paths = glob(os.path.join(data_dir, 'labels', '*.png')) 104 | 105 | #image_paths = image_paths[:20] 106 | #label_paths = 
label_paths[:20] 107 | 108 | X_train, X_valid, y_train, y_valid = train_test_split(image_paths, label_paths, test_size=0.20, random_state=42) 109 | 110 | ### Limit GPU memory usage due to ocassional crashes 111 | config = tf.ConfigProto() 112 | config.gpu_options.allow_growth = True 113 | config.gpu_options.per_process_gpu_memory_fraction = 0.7 114 | 115 | 116 | 117 | with tf.Session(config=config) as sess: 118 | 119 | ### Load ENet and replace layers 120 | input_image = tf.placeholder(tf.float32, shape=[batch_size, image_shape[1], image_shape[0], 3]) 121 | correct_label = tf.placeholder(dtype=tf.float32, shape=(None, image_shape[1], image_shape[0], 2), name='Semantic/input_image') 122 | 123 | logits, probabilities = load_enet(sess, model_dir, input_image, batch_size, num_classes) 124 | predictions_val = tf.argmax(probabilities, axis=-1) 125 | predictions_val = tf.cast(predictions_val, dtype=tf.float32) 126 | predictions_val = tf.reshape(predictions_val, shape=[batch_size, image_shape[1], image_shape[0], 1]) 127 | 128 | 129 | ### Set up learning rate decay 130 | global_step = tf.Variable(0, trainable=False) 131 | sess.run(global_step.initializer) 132 | learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 133 | learning_rate_decay_interval, learning_rate_decay_rate, staircase=True) 134 | 135 | for i, var in enumerate(tf.trainable_variables()): 136 | print i, var 137 | tf.summary.histogram(var.name, var) 138 | 139 | trainables = [var for var in tf.trainable_variables() if 'bias' not in var.name and 'ENet/fullconv' not in var.name] 140 | 141 | ### Print variables which are actually trained 142 | for var in trainables: 143 | print var 144 | 145 | logits, train_op, cross_entropy_loss = optimize(sess, logits, correct_label, learning_rate, num_classes, trainables, global_step) 146 | tf.summary.scalar('training_loss', cross_entropy_loss) 147 | tf.summary.image('Images/Validation_original_image', input_image, max_outputs=1) 148 | tf.summary.image('Images/Validation_segmentation_output', predictions_val, max_outputs=1) 149 | summary_train = tf.summary.merge_all() 150 | summary_valid = tf.summary.scalar('validation_loss', cross_entropy_loss) 151 | 152 | train_writer = tf.summary.FileWriter(log_dir) 153 | 154 | saver = tf.train.Saver() 155 | 156 | ### Training pipeline 157 | step_train = 0 158 | step_valid = 0 159 | summary_cycle = 10 160 | for epoch in range(epochs): 161 | print 'epoch', epoch 162 | print 'training ...' 163 | train_loss = 0 164 | for image, label in helper.get_batches_fn(batch_size, image_shape, X_train, y_train): 165 | # Training 166 | lr = sess.run(learning_rate) 167 | if step_train%summary_cycle==0: 168 | _, summary, loss = sess.run([train_op, summary_train, cross_entropy_loss], 169 | feed_dict={input_image: image, correct_label: label}) 170 | train_writer.add_summary(summary, step_train) 171 | print 'epoch', epoch, '\t step_train', step_train, '\t batch loss', loss, '\t current learning rate', lr 172 | else: 173 | _, loss = sess.run([train_op, cross_entropy_loss], 174 | feed_dict={input_image: image, correct_label: label}) 175 | step_train+=1 176 | train_loss += loss 177 | 178 | if (step_train%5000==4999): 179 | saver.save(sess, os.path.join(output_dir, 'model.ckpt'), global_step=global_step) 180 | 181 | 182 | print 'train epoch loss', train_loss 183 | 184 | print 'validating ...' 
185 | valid_loss = 0 186 | for image, label in helper.get_batches_fn(batch_size, image_shape, X_valid, y_valid): 187 | # Validation 188 | if step_valid%summary_cycle==0: 189 | summary, loss = sess.run([summary_valid, cross_entropy_loss], 190 | feed_dict={input_image: image, correct_label: label}) 191 | train_writer.add_summary(summary, step_valid) 192 | print 'batch loss', loss 193 | else: 194 | loss = sess.run(cross_entropy_loss, 195 | feed_dict={input_image: image, correct_label: label}) 196 | 197 | valid_loss += loss 198 | step_valid+=1 199 | 200 | print 'valid epoch loss', valid_loss 201 | 202 | 203 | 204 | 205 | if __name__ == '__main__': 206 | run() 207 | -------------------------------------------------------------------------------- /trained_model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "model.ckpt-104999" 2 | all_model_checkpoint_paths: "model.ckpt-44999" 3 | all_model_checkpoint_paths: "model.ckpt-59999" 4 | all_model_checkpoint_paths: "model.ckpt-74999" 5 | all_model_checkpoint_paths: "model.ckpt-89999" 6 | all_model_checkpoint_paths: "model.ckpt-104999" 7 | -------------------------------------------------------------------------------- /trained_model/model.ckpt-104999.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/trained_model/model.ckpt-104999.data-00000-of-00001 -------------------------------------------------------------------------------- /trained_model/model.ckpt-104999.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/trained_model/model.ckpt-104999.index -------------------------------------------------------------------------------- /trained_model/model.ckpt-104999.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hq-jiang/instance-segmentation-with-discriminative-loss-tensorflow/355430dae9c36c3a1fc006f774eebf407d905d54/trained_model/model.ckpt-104999.meta -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | 4 | import sys 5 | import warnings 6 | import copy 7 | from glob import glob 8 | import argparse 9 | 10 | import numpy as np 11 | import cv2 12 | import tensorflow as tf 13 | from sklearn.model_selection import train_test_split 14 | from tensorflow.contrib.layers.python.layers import initializers 15 | 16 | import matplotlib as mpl 17 | import matplotlib.pyplot as plt 18 | from mpl_toolkits.mplot3d import Axes3D 19 | import matplotlib.cm as cm 20 | 21 | import utils 22 | from loss import discriminative_loss 23 | import datagenerator 24 | import visualization 25 | import clustering 26 | 27 | 28 | 29 | def run(): 30 | parser = argparse.ArgumentParser() 31 | # Directories 32 | parser.add_argument('-s','--srcdir', default='data', help="Source directory of TuSimple dataset") 33 | parser.add_argument('-m', '--modeldir', default='pretrained_semantic_model', help="Output directory of extracted data") 34 | parser.add_argument('-o', '--outdir', default='saved_model', help="Directory for trained 
model") 35 | parser.add_argument('-l', '--logdir', default='log', help="Log directory for tensorboard and evaluation files") 36 | # Hyperparameters 37 | parser.add_argument('--epochs', type=int, default=50, help="Number of epochs") 38 | parser.add_argument('--var', type=float, default=1., help="Weight of variance loss") 39 | parser.add_argument('--dist', type=float, default=1., help="Weight of distance loss") 40 | parser.add_argument('--reg', type=float, default=0.001, help="Weight of regularization loss") 41 | parser.add_argument('--dvar', type=float, default=0.5, help="Cutoff variance") 42 | parser.add_argument('--ddist', type=float, default=1.5, help="Cutoff distance") 43 | 44 | args = parser.parse_args() 45 | 46 | if not os.path.isdir(args.srcdir): 47 | raise IOError('Directory does not exist') 48 | if not os.path.isdir(args.modeldir): 49 | raise IOError('Directory does not exist') 50 | if not os.path.isdir(args.logdir): 51 | os.mkdir(args.logdir) 52 | 53 | image_shape = (512, 512) 54 | data_dir = args.srcdir #os.path.join('.', 'data') 55 | model_dir = args.modeldir 56 | output_dir = args.outdir 57 | log_dir = args.logdir 58 | 59 | image_paths = glob(os.path.join(data_dir, 'images', '*.png')) 60 | label_paths = glob(os.path.join(data_dir, 'labels', '*.png')) 61 | 62 | image_paths.sort() 63 | label_paths.sort() 64 | 65 | #image_paths = image_paths[0:10] 66 | #label_paths = label_paths[0:10] 67 | 68 | X_train, X_valid, y_train, y_valid = train_test_split(image_paths, label_paths, test_size=0.10, random_state=42) 69 | 70 | print ('Number of train samples', len(y_train)) 71 | print ('Number of valid samples', len(y_valid)) 72 | 73 | 74 | ### Debugging 75 | debug_clustering = True 76 | bandwidth = 0.7 77 | cluster_cycle = 5000 78 | eval_cycle=1000 79 | save_cycle=15000 80 | 81 | ### Hyperparameters 82 | epochs = args.epochs 83 | batch_size = 1 84 | starter_learning_rate = 1e-4 85 | learning_rate_decay_rate = 0.96 86 | learning_rate_decay_interval = 5000 87 | 88 | feature_dim = 3 89 | param_var = args.var 90 | param_dist = args.dist 91 | param_reg = args.reg 92 | delta_v = args.dvar 93 | delta_d = args.ddist 94 | 95 | param_string = 'fdim'+str(feature_dim)+'_var'+str(param_var)+'_dist'+str(param_dist)+'_reg'+str(param_reg) \ 96 | +'_dv'+str(delta_v)+'_dd'+str(delta_d) \ 97 | +'_lr'+str(starter_learning_rate)+'_btch'+str(batch_size) 98 | 99 | if not os.path.exists(os.path.join(log_dir, param_string)): 100 | os.makedirs(os.path.join(log_dir, param_string)) 101 | 102 | 103 | ### Limit GPU memory usage due to ocassional crashes 104 | config = tf.ConfigProto() 105 | #config.gpu_options.allow_growth = True 106 | #config.gpu_options.per_process_gpu_memory_fraction = 0.5 107 | 108 | 109 | with tf.Session(config=config) as sess: 110 | 111 | ### Build network 112 | input_image = tf.placeholder(tf.float32, shape=(None, image_shape[1], image_shape[0], 3)) 113 | correct_label = tf.placeholder(dtype=tf.float32, shape=(None, image_shape[1], image_shape[0])) 114 | 115 | last_prelu = utils.load_enet(sess, model_dir, input_image, batch_size) 116 | prediction = utils.add_transfer_layers_and_initialize(sess, last_prelu, feature_dim) 117 | 118 | print ('Number of parameters in the model', utils.count_parameters()) 119 | ### Set up learning rate decay 120 | global_step = tf.Variable(0, trainable=False) 121 | sess.run(global_step.initializer) 122 | learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 123 | learning_rate_decay_interval, learning_rate_decay_rate, staircase=True) 124 | 125 | 
### Set variables to train 126 | trainables = utils.get_trainable_variables_and_initialize(sess, debug=False) 127 | 128 | ### Optimization operations 129 | disc_loss, l_var, l_dist, l_reg = discriminative_loss(prediction, correct_label, feature_dim, image_shape, 130 | delta_v, delta_d, param_var, param_dist, param_reg) 131 | with tf.name_scope('Instance/Adam'): 132 | train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(disc_loss, var_list=trainables, global_step=global_step) 133 | adam_initializers = [var.initializer for var in tf.global_variables() if 'Adam' in var.name] 134 | sess.run(adam_initializers) 135 | 136 | 137 | ### Collect summaries 138 | summary_op_train, summary_op_valid = utils.collect_summaries(disc_loss, l_var, l_dist, l_reg, input_image, prediction, correct_label) 139 | 140 | train_writer = tf.summary.FileWriter(log_dir) 141 | 142 | 143 | ### Check if image and labels match 144 | valid_image_chosen, valid_label_chosen = datagenerator.get_validation_batch(data_dir, image_shape) 145 | print (valid_image_chosen.shape) 146 | #visualization.save_image_overlay(valid_image_chosen.copy(), valid_label_chosen.copy()) 147 | 148 | 149 | ### Training pipeline 150 | saver = tf.train.Saver() 151 | step_train=0 152 | step_valid=0 153 | for epoch in range(epochs): 154 | print ('epoch', epoch) 155 | 156 | train_loss = 0 157 | for image, label in datagenerator.get_batches_fn(batch_size, image_shape, X_train, y_train): 158 | 159 | lr = sess.run(learning_rate) 160 | 161 | if (step_train%eval_cycle!=0): 162 | ### Training 163 | _, step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run([ 164 | train_op, 165 | prediction, 166 | disc_loss, 167 | l_var, 168 | l_dist, 169 | l_reg], 170 | feed_dict={input_image: image, correct_label: label}) 171 | else: 172 | # First run normal training step and record summaries 173 | print ('Evaluating on chosen images ...') 174 | _, summary, step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run([ 175 | train_op, 176 | summary_op_train, 177 | prediction, 178 | disc_loss, 179 | l_var, 180 | l_dist, 181 | l_reg], 182 | feed_dict={input_image: image, correct_label: label}) 183 | train_writer.add_summary(summary, step_train) 184 | 185 | # Then run model on some chosen images and save feature space visualization 186 | valid_pred = sess.run(prediction, feed_dict={input_image: np.expand_dims(valid_image_chosen[0], axis=0), 187 | correct_label: np.expand_dims(valid_label_chosen[0], axis=0)}) 188 | visualization.evaluate_scatter_plot(log_dir, valid_pred, valid_label_chosen, feature_dim, param_string, step_train) 189 | 190 | # Perform mean-shift clustering on prediction 191 | if (step_train%cluster_cycle==0): 192 | if debug_clustering: 193 | instance_masks = clustering.get_instance_masks(valid_pred, bandwidth) 194 | for img_id, mask in enumerate(instance_masks): 195 | cv2.imwrite(os.path.join(log_dir, param_string, 'cluster_{}_{}.png'.format(str(step_train).zfill(6), str(img_id)) ), mask) 196 | 197 | step_train += 1 198 | 199 | ### Save intermediate model 200 | if (step_train%save_cycle==(save_cycle-1)): 201 | try: 202 | print ('Saving model ...') 203 | saver.save(sess, os.path.join(output_dir, 'model.ckpt'), global_step=step_train) 204 | except: 205 | print ('FAILED saving model') 206 | #print 'gradient', step_gradient 207 | print ('step', step_train, '\tloss', step_loss, '\tl_var', step_l_var, '\tl_dist', step_l_dist, '\tl_reg', step_l_reg, '\tcurrent lr', lr) 208 | 209 | 210 | ### Regular validation 211 | print 
('Evaluating current model ...') 212 | for image, label in datagenerator.get_batches_fn(batch_size, image_shape, X_valid, y_valid): 213 | if step_valid%100==0: 214 | summary, step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run([ 215 | summary_op_valid, 216 | prediction, 217 | disc_loss, 218 | l_var, 219 | l_dist, 220 | l_reg], 221 | feed_dict={input_image: image, correct_label: label}) 222 | train_writer.add_summary(summary, step_valid) 223 | else: 224 | step_prediction, step_loss, step_l_var, step_l_dist, step_l_reg = sess.run([ 225 | prediction, 226 | disc_loss, 227 | l_var, 228 | l_dist, 229 | l_reg], 230 | feed_dict={input_image: image, correct_label: label}) 231 | step_valid += 1 232 | 233 | 234 | print ('step_valid', step_valid, 'valid loss', step_loss, '\tvalid l_var', step_l_var, '\tvalid l_dist', step_l_dist, '\tvalid l_reg', step_l_reg) 235 | 236 | saver = tf.train.Saver() 237 | saver.save(sess, os.path.join(output_dir, 'model.ckpt'), global_step=step_train) 238 | 239 | 240 | if __name__ == '__main__': 241 | run() 242 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | sys.path.append('../base') 4 | import tensorflow as tf 5 | from enet import ENet, ENet_arg_scope 6 | slim = tf.contrib.slim 7 | 8 | def load_enet(sess, checkpoint_dir, input_image, batch_size): 9 | checkpoint = tf.train.latest_checkpoint(checkpoint_dir) 10 | 11 | num_initial_blocks = 1 12 | skip_connections = False 13 | stage_two_repeat = 2 14 | 15 | with slim.arg_scope(ENet_arg_scope()): 16 | _, _ = ENet(input_image, 17 | num_classes=12, 18 | batch_size=batch_size, 19 | is_training=True, 20 | reuse=None, 21 | num_initial_blocks=num_initial_blocks, 22 | stage_two_repeat=stage_two_repeat, 23 | skip_connections=skip_connections) 24 | 25 | variables_to_restore = slim.get_variables_to_restore() 26 | saver = tf.train.Saver(variables_to_restore) 27 | saver.restore(sess, checkpoint) 28 | 29 | graph = tf.get_default_graph() 30 | last_prelu = graph.get_tensor_by_name('ENet/bottleneck5_1_last_prelu:0') 31 | return last_prelu 32 | 33 | def add_transfer_layers_and_initialize(sess, last_prelu, feature_dim): 34 | 35 | logits = slim.conv2d_transpose(last_prelu, feature_dim, [2,2], stride=2, 36 | biases_initializer=tf.constant_initializer(10.0), 37 | weights_initializer=tf.contrib.layers.xavier_initializer(), 38 | scope='Instance/transfer_layer/conv2d_transpose') 39 | 40 | with tf.variable_scope('', reuse=True): 41 | weight = tf.get_variable('Instance/transfer_layer/conv2d_transpose/weights') 42 | bias = tf.get_variable('Instance/transfer_layer/conv2d_transpose/biases') 43 | sess.run([weight.initializer, bias.initializer]) 44 | 45 | return logits 46 | 47 | def get_trainable_variables_and_initialize(sess, debug=False): 48 | ''' Determine which variables to train and reset 49 | We accumulate all variables we want to train in a list to pass it to the optimizer. 50 | As mentioned in the 'Fast Scene Understanding' paper we want to freeze stage 1 and 2 51 | from the ENet and train stage 3-5. The variables from the later stages are reseted. 52 | Additionally all biases are not trained. 
53 | 54 | :return: trainables: List of variables we want to train 55 | 56 | ''' 57 | ### Freeze shared encode 58 | trainables = [var for var in tf.trainable_variables() if 'bias' not in var.name]# and \ 59 | #'ENet/fullconv' not in var.name and \ 60 | #'ENet/initial_block_1' not in var.name and \ 61 | #'ENet/bottleneck1' not in var.name and \ 62 | #'ENet/bottleneck2' not in var.name 63 | #] 64 | if debug: 65 | print ('All trainable variables') 66 | for i, var in enumerate(tf.trainable_variables()): 67 | print (i, var) 68 | print ('variables which are actually trained') 69 | for var in trainables: 70 | print (var) 71 | 72 | ### Design choice: reset decoder network to default initialize weights 73 | # Reset all trainable variables 74 | #sess.run(tf.variables_initializer(trainables)) 75 | # Additionally reset all biases in the decoder network 76 | # Encoder retains pretrained biases 77 | sess.run(tf.variables_initializer([var for var in tf.trainable_variables() if 'bias' in var.name and \ 78 | 'ENet/initial_block_1' not in var.name and \ 79 | 'ENet/bottleneck1' not in var.name and \ 80 | 'ENet/bottleneck2' not in var.name]) 81 | ) 82 | return trainables 83 | 84 | def collect_summaries(disc_loss, l_var, l_dist, l_reg, input_image, prediction, correct_label): 85 | 86 | summaries = [] 87 | # Collect all variables 88 | for var in tf.trainable_variables(): 89 | summaries.append(tf.summary.histogram(var.name, var)) 90 | # Collect losses 91 | summaries.append(tf.summary.scalar('Train/disc_loss', disc_loss)) 92 | summaries.append(tf.summary.scalar('Train/l_var', l_var)) 93 | summaries.append(tf.summary.scalar('Train/l_dist', l_dist)) 94 | summaries.append(tf.summary.scalar('Train/l_reg', l_reg)) 95 | # Collect images 96 | summaries.append(tf.summary.image('Train/Images/Input', input_image, max_outputs=1)) 97 | summaries.append(tf.summary.image('Train/Images/Prediction', tf.expand_dims(prediction[:,:,:,0], axis=3), max_outputs=1)) 98 | summaries.append(tf.summary.image('Train/Images/Label', tf.expand_dims(correct_label, axis=3), max_outputs=1)) 99 | 100 | for summ in summaries: 101 | tf.add_to_collection('CUSTOM_SUMMARIES', summ) 102 | 103 | summary_op_train = tf.summary.merge_all('CUSTOM_SUMMARIES') 104 | 105 | summaries_valid = [] 106 | summaries_valid.append(tf.summary.image('Valid/Images/Input', input_image, max_outputs=1)) 107 | summaries_valid.append(tf.summary.image('Valid/Images/Prediction', tf.expand_dims(prediction[:,:,:,0], axis=3), max_outputs=1)) 108 | summaries_valid.append(tf.summary.image('Valid/Images/Label', tf.expand_dims(correct_label, axis=3), max_outputs=1)) 109 | summaries_valid.append(tf.summary.scalar('Valid/disc_loss', disc_loss)) 110 | summary_op_valid = tf.summary.merge(summaries_valid) 111 | return summary_op_train, summary_op_valid 112 | 113 | def count_parameters(): 114 | total_parameters = 0 115 | for var in tf.trainable_variables(): 116 | shape = var.get_shape() 117 | variable_parameters = 1 118 | for dim in shape: 119 | variable_parameters *= dim.value 120 | total_parameters += variable_parameters 121 | return total_parameters -------------------------------------------------------------------------------- /visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | from clustering import cluster 6 | 7 | def save_image_overlay(valid_image, valid_label): 8 | 9 | assert len(valid_image.shape)==3 and len(valid_label.shape)==2, \ 10 | 'input 
dimensions should be [h,w,c]' 11 | 12 | num_unique = np.unique(valid_label) 13 | blended = valid_image 14 | for color_id, unique in enumerate(list(num_unique[1:])): 15 | instance_ind = np.where(valid_label==unique) 16 | alpha = np.zeros_like(valid_image) 17 | alpha[instance_ind] = np.array([color_id*70, color_id*70, 255-color_id*50]) 18 | 19 | blended = cv2.addWeighted(blended, 1, alpha, 1, 0) 20 | blended = cv2.cvtColor(blended, cv2.COLOR_RGB2BGR) 21 | cv2.imwrite('overlayed_image.png', blended) 22 | 23 | 24 | def evaluate_scatter_plot(log_dir, valid_pred, valid_label, feature_dim, param_string, step): 25 | 26 | assert len(valid_pred.shape)==4 and len(valid_label.shape)==3, \ 27 | 'input dimensions should be [b,h,w,c] and [b,h,w]' 28 | 29 | assert valid_pred.shape[3]==feature_dim, 'feature dimension and prediction do not match' 30 | 31 | 32 | fig = plt.figure() #plt.figure(figsize=(10,8)) 33 | if feature_dim==2: 34 | 35 | #for i in range(valid_pred.shape[0]): 36 | # plt.subplot(2,2,i+1) 37 | # #valid_label = valid_label[0] 38 | # #print 'valid_pred', valid_pred.shape 39 | # #print 'valid_label', valid_label.shape 40 | # num_unique = np.unique(valid_label[i]) 41 | num_unique = np.unique(valid_label[0]) 42 | for unique in list(num_unique): 43 | instance_ind = np.where(valid_label[0]==unique) 44 | #print 'instance id', instance_ind 45 | #print valid_pr[instance_ind].shape 46 | x = valid_pred[0,:,:,0][instance_ind] 47 | y = valid_pred[0,:,:,1][instance_ind] 48 | plt.plot(x, y, 'o') 49 | #plt.imshow(valid_label[i]) 50 | 51 | elif feature_dim==3: 52 | #for i in range(valid_pred.shape[0]): 53 | # ax = fig.add_subplot(2,2,i+1, projection='3d') 54 | # #valid_pred = valid_pred[0] 55 | # #valid_label = valid_label[0] 56 | ax = fig.add_subplot(1,1,1, projection='3d') 57 | num_unique = np.unique(valid_label[0]) 58 | colors = [(0., 0., 1., 0.05), 'g', 'r', 'c', 'm', 'y'] 59 | for color_id, unique in enumerate(list(num_unique)): 60 | instance_ind = np.where(valid_label[0]==unique) 61 | #print 'instance id', instance_ind 62 | #print valid_pr[instance_ind].shape 63 | x = valid_pred[0,:,:,0][instance_ind] 64 | y = valid_pred[0,:,:,1][instance_ind] 65 | z = valid_pred[0,:,:,2][instance_ind] 66 | 67 | ax.scatter(x, y, z, c=colors[color_id]) 68 | elif feature_dim > 3: 69 | plt.close(fig) 70 | return None 71 | 72 | plt.savefig(os.path.join(log_dir, param_string, 'cluster_{}.png'.format(str(step).zfill(6))), bbox_inches='tight') 73 | plt.close(fig) 74 | --------------------------------------------------------------------------------
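The discriminative loss defined in loss.py can be exercised on its own before wiring it into the full training pipeline. The snippet below is a minimal, illustrative sketch and not part of the repository: it assumes a TensorFlow 1.x environment and feeds a random 2-D embedding together with a toy two-instance label through discriminative_loss; the shapes and hyperparameter values are placeholders chosen only for the example.

import numpy as np
import tensorflow as tf
from loss import discriminative_loss

# Toy setup: batch of one 4x4 "image", 2-D embedding per pixel, two instances in the label.
feature_dim = 2
image_shape = (4, 4)
prediction = tf.placeholder(tf.float32, shape=(None, image_shape[0], image_shape[1], feature_dim))
correct_label = tf.placeholder(tf.float32, shape=(None, image_shape[0], image_shape[1]))

loss, l_var, l_dist, l_reg = discriminative_loss(
    prediction, correct_label, feature_dim, image_shape,
    delta_v=0.5, delta_d=1.5, param_var=1.0, param_dist=1.0, param_reg=0.001)

with tf.Session() as sess:
    embedding = np.random.rand(1, 4, 4, feature_dim).astype(np.float32)
    label = np.zeros((1, 4, 4), dtype=np.float32)
    label[:, :, 2:] = 1.0  # right half of the image belongs to a second instance
    print(sess.run([loss, l_var, l_dist, l_reg],
                   feed_dict={prediction: embedding, correct_label: label}))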