├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── kitti │ └── README.md ├── dataset.py ├── fcn.py ├── graphs └── kitti │ └── README.md ├── helper.py ├── images ├── fcn_graph.png └── loss_graph.png ├── loss.py ├── model_utils.py ├── pretrained_weights └── README.md └── sample_output ├── um_000014.png ├── um_000032.png ├── umm_000003.png ├── umm_000015.png ├── umm_000021.png ├── umm_000034.png ├── umm_000091.png ├── uu_000009.png ├── uu_000017.png ├── uu_000022.png ├── uu_000081.png └── uu_000099.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Upul Bandara 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Semantic Segmentation using a Fully Convolutional Neural Network 2 | 3 | ### Introduction 4 | This repository contains a set of Python scripts to train and test semantic segmentation using a fully convolutional neural network. The network is based on the [paper](https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf) by Jonathan Long et al. 5 | 6 | #### How to Train the Model 7 | 1. Since the network uses VGG-16 weights, you first have to download the VGG-16 pre-trained weights from [https://www.cs.toronto.edu/~frossard/vgg16/vgg16_weights.npz](https://www.cs.toronto.edu/~frossard/vgg16/vgg16_weights.npz) and save them in the `pretrained_weights` folder. 8 | 2. Download the [KITTI dataset](http://www.cvlibs.net/datasets/kitti/eval_road.php) and save it in the `data/data_road` folder. 9 | 3. Next, open a terminal, type `python fcn.py`, and press Enter. 10 | 11 | Please note that training checkpoints will be saved to the `checkpoints/kitti` folder and logs to the `graphs/kitti` folder. Hence, you can start TensorBoard with the `tensorboard --logdir=graphs/kitti` command to inspect the training process. 12 | 13 | The following images show sample output we obtained with the trained model. 14 | 15 | ![img_1](./sample_output/um_000014.png) 16 | ![img_1](./sample_output/um_000032.png) 17 | ![img_1](./sample_output/uu_000022.png) 18 | ![img_1](./sample_output/uu_000081.png) 19 | 20 | ### Network Architecture 21 | 22 | We implement the `FCN-8s` model described in the [paper](https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf) by Jonathan Long et al. The following figure shows the architecture of the network; we generated it using TensorBoard. 23 | 24 | ![architecture](./images/fcn_graph.png) 25 | 26 | Additionally, the following table describes the main functionality of the `python` scripts in this repository; a short usage sketch follows the table. 27 | 28 | |Script |Description| 29 | |:------|:----------| 30 | |`fcn.py`|This is the main script of the repository. The key methods of this script are `build`, `optimize`, and `inference`. The `build` method loads the pre-trained weights and builds the network. The `optimize` method performs the training, and `inference` is used for testing on new images.| 31 | |`loss.py`|This script contains the loss function we optimize during training.| 32 | |`helper.py`|This script contains some useful utility functions for generating training and testing batches.| 33 | |`model_utils.py`|This script contains some useful utility functions for building the fully convolutional network using VGG-16 pre-trained weights.| 34 |
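For reference, the sketch below shows how these pieces fit together. It simply mirrors the `__main__` block at the bottom of `fcn.py` and uses the repository's default paths, batch size, and input size.

```python
import os

import helper
from fcn import FCN

# Repository defaults (taken from the __main__ block in fcn.py)
num_classes = 2
images_per_batch = 8
data_dir = './data'
runs_dir = './runs'
viz_dir = './data/data_road/validating'
input_size = (160, 576)        # images are resized to 160x576 pixels
num_train_examples = 289

# Generator that yields batches of KITTI training images and ground-truth masks
get_batches_fn = helper.gen_batch_function(os.path.join(data_dir, 'data_road/training'), input_size)

# Build the FCN-8s network, train it, and run inference on the KITTI test images
fc_network = FCN(input_size, num_train_examples, viz_dir, images_per_batch, num_classes)
fc_network.optimize(get_batches_fn, num_epochs=25)   # Adam with learning_rate=1e-5, keep_prob=0.75 by default
fc_network.inference(runs_dir, data_dir)
fc_network.close_session()
```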
35 | ### The KITTI dataset 36 | 37 | For training the semantic segmentation network, we used the [KITTI dataset](http://www.cvlibs.net/datasets/kitti/eval_road.php). The dataset consists of 289 training and 290 test images. It contains three different categories of road scenes: 38 | 39 | * uu - urban unmarked (98/100) 40 | * um - urban marked (95/96) 41 | * umm - urban multiple marked lanes (96/94) 42 | 43 | ### Training the Model 44 | 45 | When it comes to training any deep learning algorithm, selecting suitable hyper-parameters plays a big role. For this project, we carefully selected the following hyper-parameters: 46 | 47 | |Parameter |Value |Description| 48 | |:---------|:------|:----------| 49 | |Learning Rate|1e-5|We used the `Adam` optimizer, for which 1e-3 or 1e-4 is normally the suggested learning rate. However, when experimenting with different learning rates, we found that 1e-5 works better than the values above.| 50 | |Number of epochs|25|The training dataset is not very big; it has only 289 training examples. Hence, we used a moderate number of epochs.| 51 | |Batch Size|8|Based on the size of the training dataset, we selected a batch size of 8 images.| 52 | 53 | The following image shows how the training loss changes as we train the model. 54 | ![loss_graph](./images/loss_graph.png) 55 | 56 | ### Conclusion 57 | 58 | In this project, we investigated how to use a fully convolutional neural network for semantic segmentation. We tested our model against the KITTI dataset. The results indicate that our model is quite capable of separating road pixels from the rest. However, we would like to work on the following additional tasks to increase the accuracy of our model. 59 | 1. Data Augmentation: During our testing, we found that our model failed to label the road surface when lighting in the environment was inadequate. We think data augmentation can be used to generate more training examples under different lighting conditions, and that this additional data will help us overcome the above-mentioned issue. 60 | 61 | 62 | 63 | 64 |
-------------------------------------------------------------------------------- /checkpoints/kitti/README.md: --------------------------------------------------------------------------------
1 | Checkpoints will be saved to this folder.
2 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/dataset.py -------------------------------------------------------------------------------- /fcn.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import shutil 4 | import tensorflow as tf 5 | import time 6 | import numpy as np 7 | 8 | import helper 9 | from loss import logistic_loss 10 | from model_utils import conv_layer 11 | from model_utils import max_pool_layer 12 | from model_utils import fully_collected_layer 13 | from model_utils import upsample_layer 14 | from model_utils import skip_layer_connection 15 | from model_utils import preprocess 16 | 17 | 18 | class FCN: 19 | def __init__(self, input_shape, num_train_examples, viz_dir, batch_size=8, num_classes=2): 20 | self.images_batch, self.labels_batch, self.images_viz, self.dropout, self.global_step = self._build_placeholders( 21 | input_shape) 22 | self.input_shape = input_shape 23 | self.viz_dir = viz_dir 24 | self.num_train_examples = num_train_examples 25 | self.batch_size = batch_size 26 | self.n_classes = num_classes 27 | self.sess = tf.Session() 28 | self.logits = self.build() 29 | 30 | def build(self): 31 | images = preprocess(self.images_batch) 32 | 33 | conv1_1 = conv_layer(images, 'conv1_1_W', 'conv1_1_b', name='conv1_1') 34 | conv1_2 = conv_layer(conv1_1, 'conv1_2_W', 'conv1_2_b', name='conv1_2') 35 | pool1 = max_pool_layer(conv1_2, [1, 2, 2, 1], [1, 2, 2, 1], name='pool1') 36 | 37 | conv2_1 = conv_layer(pool1, 'conv2_1_W', 'conv2_1_b', name='conv2_1') 38 | conv2_2 = conv_layer(conv2_1, 'conv2_2_W', 'conv2_2_b', name='conv2_2') 39 | pool2 = max_pool_layer(conv2_2, [1, 2, 2, 1], [1, 2, 2, 1], name='pool2') 40 | 41 | conv3_1 = conv_layer(pool2, 'conv3_1_W', 'conv3_1_b', name='conv3_1') 42 | conv3_2 = conv_layer(conv3_1, 'conv3_2_W', 'conv3_2_b', name='conv3_2') 43 | conv3_3 = conv_layer(conv3_2, 'conv3_3_W', 'conv3_3_b', name='conv3_3') 44 | pool3 = max_pool_layer(conv3_3, [1, 2, 2, 1], [1, 2, 2, 1], name='pool3') 45 | 46 | conv4_1 = conv_layer(pool3, 'conv4_1_W', 'conv4_1_b', name='conv4_1') 47 | conv4_2 = conv_layer(conv4_1, 'conv4_2_W', 'conv4_2_b', name='conv4_2') 48 | conv4_3 = conv_layer(conv4_2, 'conv4_3_W', 'conv4_3_b', name='conv4_3') 49 | pool4 = max_pool_layer(conv4_3, [1, 2, 2, 1], [1, 2, 2, 1], name='pool4') 50 | 51 | conv5_1 = conv_layer(pool4, 'conv5_1_W', 'conv5_1_b', name='conv5_1') 52 | conv5_2 = conv_layer(conv5_1, 'conv5_2_W', 'conv5_2_b', name='conv5_2') 53 | conv5_3 = conv_layer(conv5_2, 'conv5_3_W', 'conv5_3_b', name='conv5_3') 54 | pool5 = max_pool_layer(conv5_3, [1, 2, 2, 1], [1, 2, 2, 1], name='pool5') 55 | 56 | fc_1 = fully_collected_layer(pool5, 'fc_1', self.dropout) 57 | fc_2 = fully_collected_layer(fc_1, 'fc_2', self.dropout) 58 | fc_3 = fully_collected_layer(fc_2, 'fc_3', self.dropout) 59 | 60 | # New we start upsampling and skip layer connections. 
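# Descriptive note on the decoder below: the class predictions derived from fc_3, pool4 and pool3
# sit at 1/32, 1/16 and 1/8 of the input resolution respectively. pool4 and pool3 are first reduced
# to num_classes channels by a 1x1 convolution (skip_layer_connection), each branch is then upsampled
# back to the full input resolution with a bilinearly-initialised transposed convolution
# (upsample_layer), and the three maps are fused below as upsample_3 + 2*upsample_2 + 4*upsample_1
# to form the final logits (an FCN-8s-style combination of coarse and finer predictions).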
61 | img_shape = tf.shape(self.images_batch) 62 | dconv3_shape = tf.stack([img_shape[0], img_shape[1], img_shape[2], self.n_classes]) 63 | upsample_1 = upsample_layer(fc_3, dconv3_shape, self.n_classes, 'upsample_1', 32) 64 | 65 | skip_1 = skip_layer_connection(pool4, 'skip_1', 512, stddev=0.00001) 66 | upsample_2 = upsample_layer(skip_1, dconv3_shape, self.n_classes, 'upsample_2', 16) 67 | 68 | skip_2 = skip_layer_connection(pool3, 'skip_2', 256, stddev=0.0001) 69 | upsample_3 = upsample_layer(skip_2, dconv3_shape, self.n_classes, 'upsample_3', 8) 70 | 71 | logit = tf.add(upsample_3, tf.add(2 * upsample_2, 4 * upsample_1)) 72 | return logit 73 | 74 | def optimize(self, batch_generator, learning_rate=1e-5, keep_prob=0.75, num_epochs=1): 75 | loss = logistic_loss(logits=self.logits, labels=self.labels_batch, n_classes=self.n_classes) 76 | summary_op = self._build_summary(loss=loss) 77 | 78 | optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=self.global_step) 79 | 80 | validation_img_summary_op = tf.summary.image('validation_img', self.images_viz) 81 | 82 | self.sess.run(tf.global_variables_initializer()) 83 | saver = tf.train.Saver() 84 | # to visualize using TensorBoard 85 | writer = tf.summary.FileWriter('./graphs/kitti/', self.sess.graph) 86 | ckpt = tf.train.get_checkpoint_state(os.path.dirname('./checkpoints/kitti/checkpoint')) 87 | if ckpt and ckpt.model_checkpoint_path: 88 | print('Graph is available in hard disk. Hence, loading it.') 89 | saver.restore(self.sess, ckpt.model_checkpoint_path) 90 | 91 | initial_step = self.sess.run(self.global_step) 92 | num_batches = int(self.num_train_examples / self.batch_size) 93 | for itr in range(initial_step, num_batches * num_epochs): 94 | image, gt_image = next(batch_generator(self.batch_size)) 95 | _, loss_val, summary = self.sess.run([optimizer, loss, summary_op], 96 | feed_dict={self.images_batch: image, 97 | self.labels_batch: gt_image, 98 | self.dropout: keep_prob}) 99 | writer.add_summary(summary, global_step=itr) 100 | 101 | if (itr < 10) or (itr < 100 and itr % 10 == 0) or \ 102 | (itr < 1000 and itr % 100 == 0) or (itr >= 1000 and itr % 200 == 0): 103 | epoch_no = int(itr / num_batches) 104 | print('epoch: {0:>3d} iter: {1:>4d} loss: {2:>8.4e}'.format(epoch_no, itr, loss_val)) 105 | 106 | if itr % 10 == 0: 107 | viz_images = self.training_visulize() 108 | tt = self.sess.run([validation_img_summary_op], feed_dict={self.images_viz: viz_images}) 109 | writer.add_summary(tt[0], itr) 110 | 111 | if ((itr + 1) % (num_batches * 20) == 0) or (itr == num_batches * num_epochs): 112 | print('At iteration: {} save a checkpoint'.format(itr)) 113 | saver.save(self.sess, './checkpoints/kitti/state', itr) 114 | 115 | def inference(self, runs_dirs, data_dirs): 116 | reshape_logits = tf.reshape(self.logits, (-1, self.n_classes)) 117 | helper.save_inference_samples(runs_dirs, data_dirs, self.sess, self.input_shape, reshape_logits, self.dropout, 118 | self.images_batch) 119 | 120 | def close_session(self): 121 | self.sess.close() 122 | 123 | def training_visulize(self): 124 | reshape_logits = tf.reshape(self.logits, (-1, self.n_classes)) 125 | viz_images = [] 126 | img_output = helper.gen_test_output(self.sess, reshape_logits, self.dropout, self.images_batch, self.viz_dir, 127 | self.input_shape) 128 | for _, ouput in img_output: 129 | viz_images.append(ouput) 130 | return np.array(viz_images) 131 | 132 | @staticmethod 133 | def _build_placeholders(shape): 134 | with tf.name_scope('data'): 135 | X = tf.placeholder(tf.float32, 
[None, shape[0], shape[1], 3], name='X_placeholder') 136 | Y = tf.placeholder(tf.float32, [None, shape[0], shape[1], 2], name='Y_placeholder') 137 | X_viz = tf.placeholder(tf.float32, [None, shape[0], shape[1], 3], name='X_valid_placeholder') 138 | 139 | dropout = tf.placeholder(tf.float32, name='dropout') 140 | global_step = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step') 141 | 142 | return X, Y, X_viz, dropout, global_step 143 | 144 | @staticmethod 145 | def _build_summary(loss): 146 | with tf.name_scope('summaries'): 147 | tf.summary.scalar('loss', loss) 148 | tf.summary.histogram('histogram loss', loss) 149 | return tf.summary.merge_all() 150 | 151 | 152 | if __name__ == '__main__': 153 | num_classes = 2 154 | images_per_batch = 8 155 | data_dir = './data' 156 | runs_dir = './runs' 157 | viz_dir = './data/data_road/validating' 158 | input_size = (160, 576) 159 | num_train_examples = 289 160 | 161 | get_batches_fn = helper.gen_batch_function(os.path.join(data_dir, 'data_road/training'), input_size) 162 | fc_network = FCN(input_size, num_train_examples, viz_dir, images_per_batch, num_classes) 163 | fc_network.optimize(get_batches_fn, num_epochs=25) 164 | fc_network.inference(runs_dir, data_dir) 165 | -------------------------------------------------------------------------------- /graphs/kitti/README.md: -------------------------------------------------------------------------------- 1 | Logs will be saved to this folder. 2 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | import numpy as np 4 | import os.path 5 | import scipy.misc 6 | import shutil 7 | import zipfile 8 | import time 9 | import tensorflow as tf 10 | from glob import glob 11 | from urllib.request import urlretrieve 12 | from tqdm import tqdm 13 | 14 | 15 | class DLProgress(tqdm): 16 | last_block = 0 17 | 18 | def hook(self, block_num=1, block_size=1, total_size=None): 19 | self.total = total_size 20 | self.update((block_num - self.last_block) * block_size) 21 | self.last_block = block_num 22 | 23 | 24 | def maybe_download_pretrained_vgg(data_dir): 25 | """ 26 | Download and extract pretrained vgg model if it doesn't exist 27 | :param data_dir: Directory to download the model to 28 | """ 29 | vgg_filename = 'vgg.zip' 30 | vgg_path = os.path.join(data_dir, 'vgg') 31 | vgg_files = [ 32 | os.path.join(vgg_path, 'variables/variables.data-00000-of-00001'), 33 | os.path.join(vgg_path, 'variables/variables.index'), 34 | os.path.join(vgg_path, 'saved_model.pb')] 35 | 36 | missing_vgg_files = [vgg_file for vgg_file in vgg_files if not os.path.exists(vgg_file)] 37 | if missing_vgg_files: 38 | # Clean vgg dir 39 | if os.path.exists(vgg_path): 40 | shutil.rmtree(vgg_path) 41 | os.makedirs(vgg_path) 42 | 43 | # Download vgg 44 | print('Downloading pre-trained vgg model...') 45 | with DLProgress(unit='B', unit_scale=True, miniters=1) as pbar: 46 | urlretrieve( 47 | 'https://s3-us-west-1.amazonaws.com/udacity-selfdrivingcar/vgg.zip', 48 | os.path.join(vgg_path, vgg_filename), 49 | pbar.hook) 50 | 51 | # Extract vgg 52 | print('Extracting model...') 53 | zip_ref = zipfile.ZipFile(os.path.join(vgg_path, vgg_filename), 'r') 54 | zip_ref.extractall(data_dir) 55 | zip_ref.close() 56 | 57 | # Remove zip file to save space 58 | os.remove(os.path.join(vgg_path, vgg_filename)) 59 | 60 | 61 | def gen_batch_function(data_folder, image_shape): 62 | """ 63 | Generate 
function to create batches of training data 64 | :param data_folder: Path to folder that contains all the datasets 65 | :param image_shape: Tuple - Shape of image 66 | :return: 67 | """ 68 | 69 | def get_batches_fn(batch_size): 70 | """ 71 | Create batches of training data 72 | :param batch_size: Batch Size 73 | :return: Batches of training data 74 | """ 75 | image_paths = glob(os.path.join(data_folder, 'image_2', '*.png')) 76 | label_paths = { 77 | re.sub(r'_(lane|road)_', '_', os.path.basename(path)): path 78 | for path in glob(os.path.join(data_folder, 'gt_image_2', '*_road_*.png'))} 79 | background_color = np.array([255, 0, 0]) 80 | 81 | random.shuffle(image_paths) 82 | for batch_i in range(0, len(image_paths), batch_size): 83 | images = [] 84 | gt_images = [] 85 | for image_file in image_paths[batch_i:batch_i + batch_size]: 86 | gt_image_file = label_paths[os.path.basename(image_file)] 87 | 88 | image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape) 89 | gt_image = scipy.misc.imresize(scipy.misc.imread(gt_image_file), image_shape) 90 | 91 | gt_bg = np.all(gt_image == background_color, axis=2) 92 | gt_bg = gt_bg.reshape(*gt_bg.shape, 1) 93 | gt_image = np.concatenate((gt_bg, np.invert(gt_bg)), axis=2) 94 | 95 | images.append(image) 96 | gt_images.append(gt_image) 97 | 98 | yield np.array(images), np.array(gt_images) 99 | 100 | return get_batches_fn 101 | 102 | 103 | def gen_test_output(sess, logits, keep_prob, image_pl, data_folder, image_shape): 104 | """ 105 | Generate test output using the test images 106 | :param sess: TF session 107 | :param logits: TF Tensor for the logits 108 | :param keep_prob: TF Placeholder for the dropout keep probability 109 | :param image_pl: TF Placeholder for the image placeholder 110 | :param data_folder: Path to the folder that contains the datasets 111 | :param image_shape: Tuple - Shape of image 112 | :return: Output for for each test image 113 | """ 114 | for image_file in glob(os.path.join(data_folder, 'image_2', '*.png')): 115 | image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape) 116 | 117 | im_softmax = sess.run( 118 | [tf.nn.softmax(logits)], 119 | {keep_prob: 1.0, image_pl: [image]}) 120 | im_softmax = im_softmax[0][:, 1].reshape(image_shape[0], image_shape[1]) 121 | segmentation = (im_softmax > 0.5).reshape(image_shape[0], image_shape[1], 1) 122 | mask = np.dot(segmentation, np.array([[221, 28, 199, 127]])) 123 | mask = scipy.misc.toimage(mask, mode="RGBA") 124 | street_im = scipy.misc.toimage(image) 125 | street_im.paste(mask, box=None, mask=mask) 126 | 127 | yield os.path.basename(image_file), np.array(street_im) 128 | 129 | 130 | def save_inference_samples(runs_dir, data_dir, sess, image_shape, logits, keep_prob, input_image): 131 | # Make folder for current run 132 | output_dir = os.path.join(runs_dir, str(time.time())) 133 | if os.path.exists(output_dir): 134 | shutil.rmtree(output_dir) 135 | os.makedirs(output_dir) 136 | 137 | # Run NN on test images and save them to HD 138 | print('Training Finished. 
Saving test images to: {}'.format(output_dir)) 139 | image_outputs = gen_test_output( 140 | sess, logits, keep_prob, input_image, os.path.join(data_dir, 'data_road/testing'), image_shape) 141 | for name, image in image_outputs: 142 | scipy.misc.imsave(os.path.join(output_dir, name), image) 143 | -------------------------------------------------------------------------------- /images/fcn_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/images/fcn_graph.png -------------------------------------------------------------------------------- /images/loss_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/images/loss_graph.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def logistic_loss(logits, labels, n_classes): 5 | with tf.variable_scope('logistic_loss'): 6 | reshaped_logits = tf.reshape(logits, (-1, n_classes)) 7 | reshaped_labels = tf.reshape(labels, (-1, n_classes)) 8 | entropy = tf.nn.softmax_cross_entropy_with_logits(logits=reshaped_logits, 9 | labels=reshaped_labels) 10 | loss = tf.reduce_mean(entropy, name='logistic_loss') 11 | return loss 12 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | vgg_weights = np.load('./pretrained_weights/vgg16_weights.npz') 5 | 6 | 7 | def conv_layer(parent, kernel_name, bias_name, name): 8 | """ 9 | This simple utility function create a convolution layer 10 | and applied relu activation. 11 | 12 | :param parent: 13 | :param kernel_name: Kernel weight tensor 14 | :param bias: Bias tensor 15 | :param name: Name of this layer 16 | :return: Convolution layer created according to the given parameters. 
17 | """ 18 | with tf.variable_scope(name) as scope: 19 | kernel_weights = _get_kernel(kernel_name) 20 | init = tf.constant_initializer(value=kernel_weights, dtype=tf.float32) 21 | kernel = tf.get_variable(name="weights", initializer=init, shape=kernel_weights.shape) 22 | conv = tf.nn.conv2d(parent, kernel, [1, 1, 1, 1], padding='SAME') 23 | 24 | bias = _get_bias(bias_name) 25 | init = tf.constant_initializer(value=bias, dtype=tf.float32) 26 | biases = tf.get_variable(name="biases", initializer=init, shape=bias.shape) 27 | 28 | conv_with_bias = tf.nn.bias_add(conv, biases) 29 | conv_with_relu = tf.nn.relu(conv_with_bias, name=scope.name) 30 | return conv_with_relu 31 | 32 | 33 | def max_pool_layer(parent, kernel, stride, name, padding='SAME'): 34 | max_pool = tf.nn.max_pool(parent, ksize=kernel, strides=stride, padding=padding, name=name) 35 | return max_pool 36 | 37 | 38 | def fully_collected_layer(parent, name, dropout, num_classes=2): 39 | with tf.variable_scope(name) as scope: 40 | if name == 'fc_1': 41 | kernel = _reshape_fc_weights('fc6_W', [7, 7, 512, 4096]) 42 | conv = tf.nn.conv2d(parent, kernel, [1, 1, 1, 1], padding='SAME') 43 | bias = _get_bias('fc6_b') 44 | output = tf.nn.bias_add(conv, bias) 45 | output = tf.nn.relu(output, name=scope.name) 46 | return tf.nn.dropout(output, dropout) 47 | 48 | if name == 'fc_2': 49 | kernel = _reshape_fc_weights('fc7_W', [1, 1, 4096, 4096]) 50 | conv = tf.nn.conv2d(parent, kernel, [1, 1, 1, 1], padding='SAME') 51 | bias = _get_bias('fc7_b') 52 | output = tf.nn.bias_add(conv, bias) 53 | output = tf.nn.relu(output, name=scope.name) 54 | return tf.nn.dropout(output, dropout) 55 | 56 | if name == 'fc_3': 57 | initial = tf.truncated_normal([1, 1, 4096, num_classes], stddev=0.0001) 58 | kernel = tf.get_variable('kernel', initializer=initial) 59 | conv = tf.nn.conv2d(parent, kernel, [1, 1, 1, 1], padding='SAME') 60 | initial = tf.constant(0.0, shape=[num_classes]) 61 | bias = tf.get_variable('bias', initializer=initial) 62 | return tf.nn.bias_add(conv, bias) 63 | 64 | raise RuntimeError('{} is not supported as a fully connected name'.format(name)) 65 | 66 | 67 | def upsample_layer(bottom, shape, n_channels, name, upscale_factor, num_classes=2): 68 | kernel_size = 2 * upscale_factor - upscale_factor % 2 69 | stride = upscale_factor 70 | strides = [1, stride, stride, 1] 71 | with tf.variable_scope(name): 72 | output_shape = [shape[0], shape[1], shape[2], num_classes] 73 | filter_shape = [kernel_size, kernel_size, n_channels, n_channels] 74 | weights = _get_bilinear_filter(filter_shape, upscale_factor) 75 | deconv = tf.nn.conv2d_transpose(bottom, weights, output_shape, 76 | strides=strides, padding='SAME') 77 | 78 | bias_init = tf.constant(0.0, shape=[num_classes]) 79 | bias = tf.get_variable('bias', initializer=bias_init) 80 | dconv_with_bias = tf.nn.bias_add(deconv, bias) 81 | 82 | return dconv_with_bias 83 | 84 | 85 | def preprocess(images): 86 | mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='mean') 87 | return images - mean 88 | 89 | 90 | def skip_layer_connection(parent, name, num_input_layers, num_classes=2, stddev=0.0005): 91 | with tf.variable_scope(name) as scope: 92 | initial = tf.truncated_normal([1, 1, num_input_layers, num_classes], stddev=stddev) 93 | kernel = tf.get_variable('kernel', initializer=initial) 94 | conv = tf.nn.conv2d(parent, kernel, [1, 1, 1, 1], padding='SAME') 95 | 96 | bias_init = tf.constant(0.0, shape=[num_classes]) 97 | bias = tf.get_variable('bias', initializer=bias_init) 98 
| skip_layer = tf.nn.bias_add(conv, bias) 99 | 100 | return skip_layer 101 | 102 | 103 | def _get_kernel(kernel_name): 104 | kernel = vgg_weights[kernel_name] 105 | return kernel 106 | 107 | 108 | def _reshape_fc_weights(name, new_shape): 109 | w = vgg_weights[name] 110 | w = w.reshape(new_shape) 111 | init = tf.constant_initializer(value=w, 112 | dtype=tf.float32) 113 | var = tf.get_variable(name="weights", initializer=init, shape=new_shape) 114 | return var 115 | 116 | 117 | def _get_bias(name): 118 | bias_weights = vgg_weights[name] 119 | return bias_weights 120 | 121 | 122 | def _get_bilinear_filter(filter_shape, upscale_factor): 123 | kernel_size = filter_shape[1] 124 | if kernel_size % 2 == 1: 125 | centre_location = upscale_factor - 1 126 | else: 127 | centre_location = upscale_factor - 0.5 128 | 129 | bilinear = np.zeros([filter_shape[0], filter_shape[1]]) 130 | for x in range(filter_shape[0]): 131 | for y in range(filter_shape[1]): 132 | value = (1 - abs((x - centre_location) / upscale_factor)) * ( 133 | 1 - abs((y - centre_location) / upscale_factor)) 134 | bilinear[x, y] = value 135 | weights = np.zeros(filter_shape) 136 | for i in range(filter_shape[2]): 137 | weights[:, :, i, i] = bilinear 138 | init = tf.constant_initializer(value=weights, 139 | dtype=tf.float32) 140 | 141 | bilinear_weights = tf.get_variable(name="decon_bilinear_filter", initializer=init, 142 | shape=weights.shape) 143 | return bilinear_weights 144 | -------------------------------------------------------------------------------- /pretrained_weights/README.md: -------------------------------------------------------------------------------- 1 | Please copy VGG-16 pre-trained weights to this folder. 2 | -------------------------------------------------------------------------------- /sample_output/um_000014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/um_000014.png -------------------------------------------------------------------------------- /sample_output/um_000032.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/um_000032.png -------------------------------------------------------------------------------- /sample_output/umm_000003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/umm_000003.png -------------------------------------------------------------------------------- /sample_output/umm_000015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/umm_000015.png -------------------------------------------------------------------------------- /sample_output/umm_000021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/umm_000021.png -------------------------------------------------------------------------------- /sample_output/umm_000034.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/umm_000034.png -------------------------------------------------------------------------------- /sample_output/umm_000091.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/umm_000091.png -------------------------------------------------------------------------------- /sample_output/uu_000009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/uu_000009.png -------------------------------------------------------------------------------- /sample_output/uu_000017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/uu_000017.png -------------------------------------------------------------------------------- /sample_output/uu_000022.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/uu_000022.png -------------------------------------------------------------------------------- /sample_output/uu_000081.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/uu_000081.png -------------------------------------------------------------------------------- /sample_output/uu_000099.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/upul/Semantic_Segmentation/70db59061b57413ae7de8df3f62f789c1ac250ef/sample_output/uu_000099.png --------------------------------------------------------------------------------