├── .gitignore ├── FCN.png ├── README.md ├── helper.py ├── main.py ├── output.png ├── project_tests.py └── tensorboard_out.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | main2.py 6 | output.py 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | runs/ 14 | data/ 15 | logs/ 16 | env/ 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # SageMath parsed files 85 | *.sage.py 86 | 87 | # dotenv 88 | .env 89 | 90 | # virtualenv 91 | .venv 92 | venv/ 93 | ENV/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | RUNS/ 109 | 
-------------------------------------------------------------------------------- /FCN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asterixds/SemanticSegmentation/ad7e84d8e8d41cc6ebcad1238afc93d1e52c0183/FCN.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Semantic Segmentation Project 2 | The project uses a fully convolutional neural network (FCN) to classify each pixel in an image. The deep learning model uses a pre-trained VGG-16 model as its foundation (see the FCN paper by Jonathan Long et al.). In this implementation of FCN, we reuse pre-trained layers 3, 4 and 7 of the VGG model and then apply 1x1 convolutions to these layers. This phase acts as an encoder that extracts features. The encoding is followed by a decoding process using a series of upsampling layers. Upsampling is performed using transposed convolution, more accurately called fractionally strided convolution. This operation goes in the opposite direction to a convolution: it scales the activations back up to the size of the input image, so that they can be interpreted pixel by pixel. The encoding and decoding process is illustrated below. 3 | 4 | ![alt text](FCN.png "FCN") 5 | 6 | Image Credit: http://cvlab.postech.ac.kr/research/deconvnet/ 7 | 8 | In the process, we lose some resolution because the activations were downscaled. To recover some of it, we add activations from earlier layers to the upsampled output; these are called skip connections. 
9 | 10 | #### Training and optimisation 11 | The network was trained on a g3x2XLarge AWS instance on the [Kitti Road data set](http://www.cvlibs.net/datasets/kitti/eval_road.php) 12 | using the following hyperparameters: 13 | 'lr': 0.0001, 14 | 'keep_prob': 0.25, 15 | 'epochs': 25, 16 | 'batch_size': 16, 17 | 'std_init': 0.01, 18 | 'num_classes': 2, 19 | 'image_shape': (160, 576) 20 | 21 | Optimisation was done by minimising the cross-entropy loss with the Adam optimiser. The cross-entropy loss is computed between the flattened logits and the correct ground-truth image (also flattened). 22 | 23 | #### Retrospective 24 | The results were good for the most part, as can be seen in the following output samples. 25 | 26 | ![alt text](output.png "Test outputs") 27 | 28 | It even seems to distinguish road from crossing rail tracks (see the third row, third column). It does have some failings, particularly small patches around cars. The model could benefit from data augmentation, but I have left that for a future run, along with trying the model on the Cityscapes dataset. 
-------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | import numpy as np 4 | import os.path 5 | import scipy.misc 6 | import shutil 7 | import zipfile 8 | import time 9 | import tensorflow as tf 10 | from glob import glob 11 | from urllib.request import urlretrieve 12 | from tqdm import tqdm 13 | import cv2 14 | 15 | 16 | class DLProgress(tqdm): 17 | last_block = 0 18 | 19 | def hook(self, block_num=1, block_size=1, total_size=None): 20 | self.total = total_size 21 | self.update((block_num - self.last_block) * block_size) 22 | self.last_block = block_num 23 | 24 | 25 | def maybe_download_pretrained_vgg(data_dir): 26 | """ 27 | Download and extract pretrained vgg model if it doesn't exist 28 | :param data_dir: Directory to download the model to 29 | """ 30 | vgg_filename = 'vgg.zip' 31 | vgg_path = os.path.join(data_dir, 'vgg') 32 | vgg_files = [ 33 | os.path.join(vgg_path, 'variables/variables.data-00000-of-00001'), 34 | os.path.join(vgg_path, 'variables/variables.index'), 35 | os.path.join(vgg_path, 'saved_model.pb')] 36 | 37 | missing_vgg_files = [vgg_file for vgg_file in vgg_files if not os.path.exists(vgg_file)] 38 | if missing_vgg_files: 39 | # Clean vgg dir 40 | if os.path.exists(vgg_path): 41 | shutil.rmtree(vgg_path) 42 | os.makedirs(vgg_path) 43 | 44 | # Download vgg 45 | print('Downloading pre-trained vgg model...') 46 | with DLProgress(unit='B', unit_scale=True, miniters=1) as pbar: 47 | urlretrieve( 48 | 'https://s3-us-west-1.amazonaws.com/udacity-selfdrivingcar/vgg.zip', 49 | os.path.join(vgg_path, vgg_filename), 50 | pbar.hook) 51 | 52 | # Extract vgg 53 | print('Extracting model...') 54 | zip_ref = zipfile.ZipFile(os.path.join(vgg_path, vgg_filename), 'r') 55 | zip_ref.extractall(data_dir) 56 | zip_ref.close() 57 | 58 | # Remove zip file to save space 59 | os.remove(os.path.join(vgg_path, 
vgg_filename)) 60 | 61 | """flip the image around vertical axis""" 62 | 63 | 64 | def augment(image, gt_image, image_shape): 65 | #augment input 66 | def rotate(img, angle): 67 | return scipy.misc.imrotate(img, angle) 68 | 69 | """"flip the image around vertical axis and change sign of steering angle""" 70 | def flip(img): 71 | return cv2.flip(img,1) 72 | 73 | """ adjust the brightness using a factor for all pixels""" 74 | def brightness_jitter(img, factor): 75 | img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2HLS) 76 | img[:,:,1] = img[:,:,1] * (.2 + np.random.uniform(0.2, 0.8)) 77 | return cv2.cvtColor(img, cv2.COLOR_HLS2RGB) 78 | 79 | """warp""" 80 | def warp(img, image_shape, x,y): 81 | return cv2.warpAffine(img, np.float32([[1,0,x],[0,y,0]]), image_shape) 82 | 83 | factor = 0.8 + .4 * np.random.uniform(-1.0, 1.0) 84 | x = np.random.uniform(-45, 45) 85 | y = np.random.uniform(-45, 45) 86 | angle = np.random.uniform(-25, 25) 87 | image = flip(image) 88 | image = brightness_jitter(image, factor) 89 | image = warp(image, image_shape,x,y) 90 | 91 | #do the same for labels 92 | gt_image = flip(gt_image) 93 | #gt_image = brightness_jitter(gt_image,factor) 94 | gt_image = warp(gt_image, image_shape,x,y) 95 | 96 | return image, gt_image 97 | 98 | 99 | def gen_batch_function(data_folder, image_shape): 100 | """ 101 | Generate function to create batches of training data 102 | :param data_folder: Path to folder that contains all the datasets 103 | :param image_shape: Tuple - Shape of image 104 | :return: 105 | """ 106 | def get_batches_fn(batch_size): 107 | """ 108 | Create batches of training data 109 | :param batch_size: Batch Size 110 | :return: Batches of training data 111 | """ 112 | image_paths = glob(os.path.join(data_folder, 'image_2', '*.png')) 113 | label_paths = { 114 | re.sub(r'_(lane|road)_', '_', os.path.basename(path)): path 115 | for path in glob(os.path.join(data_folder, 'gt_image_2', '*_road_*.png'))} 116 | background_color = np.array([255, 0, 0]) 
117 | 118 | random.shuffle(image_paths) 119 | for batch_i in range(0, len(image_paths), batch_size): 120 | images = [] 121 | gt_images = [] 122 | for image_file in image_paths[batch_i:batch_i+batch_size]: 123 | gt_image_file = label_paths[os.path.basename(image_file)] 124 | 125 | image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape) 126 | gt_image = scipy.misc.imresize(scipy.misc.imread(gt_image_file), image_shape) 127 | 128 | #image, gt_image = augment(image, gt_image, image_shape) 129 | 130 | gt_bg = np.all(gt_image == background_color, axis=2) 131 | gt_bg = gt_bg.reshape(*gt_bg.shape, 1) 132 | gt_image = np.concatenate((gt_bg, np.invert(gt_bg)), axis=2) 133 | 134 | images.append(image) 135 | gt_images.append(gt_image) 136 | 137 | yield np.array(images), np.array(gt_images) 138 | return get_batches_fn 139 | 140 | 141 | def gen_test_output(sess, logits, keep_prob, image_pl, data_folder, image_shape): 142 | """ 143 | Generate test output using the test images 144 | :param sess: TF session 145 | :param logits: TF Tensor for the logits 146 | :param keep_prob: TF Placeholder for the dropout keep robability 147 | :param image_pl: TF Placeholder for the image placeholder 148 | :param data_folder: Path to the folder that contains the datasets 149 | :param image_shape: Tuple - Shape of image 150 | :return: Output for for each test image 151 | """ 152 | for image_file in glob(os.path.join(data_folder, 'image_2', '*.png')): 153 | image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape) 154 | 155 | im_softmax = sess.run( 156 | [tf.nn.softmax(logits)], 157 | {keep_prob: 1.0, image_pl: [image]}) 158 | im_softmax = im_softmax[0][:, 1].reshape(image_shape[0], image_shape[1]) 159 | segmentation = (im_softmax > 0.5).reshape(image_shape[0], image_shape[1], 1) 160 | mask = np.dot(segmentation, np.array([[0, 255, 0, 127]])) 161 | mask = scipy.misc.toimage(mask, mode="RGBA") 162 | street_im = scipy.misc.toimage(image) 163 | street_im.paste(mask, 
box=None, mask=mask) 164 | 165 | yield os.path.basename(image_file), np.array(street_im) 166 | 167 | 168 | def save_inference_samples(runs_dir, data_dir, sess, image_shape, logits, keep_prob, input_image): 169 | # Make folder for current run 170 | output_dir = os.path.join(runs_dir, str(time.time())) 171 | if os.path.exists(output_dir): 172 | shutil.rmtree(output_dir) 173 | os.makedirs(output_dir) 174 | 175 | # Run NN on test images and save them to HD 176 | print('Training Finished. Saving test images to: {}'.format(output_dir)) 177 | image_outputs = gen_test_output( 178 | sess, logits, keep_prob, input_image, os.path.join(data_dir, 'data_road/testing'), image_shape) 179 | for name, image in image_outputs: 180 | scipy.misc.imsave(os.path.join(output_dir, name), image) 181 | 182 | 183 | 184 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import tensorflow as tf 3 | import helper 4 | import warnings 5 | from distutils.version import LooseVersion 6 | import project_tests as tests 7 | 8 | class FCNSegementer(object): 9 | 10 | ''' 11 | Constructor for setting params 12 | ''' 13 | def __init__(self, params): 14 | for p in params: 15 | setattr(self, p, params[p]) 16 | 17 | """ 18 | Load Pretrained VGG Model into TensorFlow. 
19 | :param sess: TensorFlow Session 20 | :param vgg_path: Path to vgg folder, containing "variables/" and "saved_model.pb" 21 | :return: Tuple of Tensors from VGG model (image_input, keep_prob, layer3_out, layer4_out, layer7_out) 22 | """ 23 | def load_vgg(self, sess, vgg_path): 24 | vgg_tag = 'vgg16' 25 | vgg_input_tensor_name = 'image_input:0' 26 | vgg_keep_prob_tensor_name = 'keep_prob:0' 27 | vgg_layer3_out_tensor_name = 'layer3_out:0' 28 | vgg_layer4_out_tensor_name = 'layer4_out:0' 29 | vgg_layer7_out_tensor_name = 'layer7_out:0' 30 | 31 | # Use tf.saved_model.loader.load to load the model and weights 32 | tf.saved_model.loader.load(sess, [vgg_tag], vgg_path) 33 | 34 | default_graph = tf.get_default_graph() 35 | vgg_image_input = default_graph.get_tensor_by_name(vgg_input_tensor_name) 36 | vgg_keep = default_graph.get_tensor_by_name(vgg_keep_prob_tensor_name) 37 | vgg_layer3 = default_graph.get_tensor_by_name(vgg_layer3_out_tensor_name) 38 | vgg_layer4 = default_graph.get_tensor_by_name(vgg_layer4_out_tensor_name) 39 | vgg_layer7 = default_graph.get_tensor_by_name(vgg_layer7_out_tensor_name) 40 | 41 | return vgg_image_input, vgg_keep, vgg_layer3, vgg_layer4, vgg_layer7 42 | 43 | def save_model(self, sess): 44 | model_file = os.path.join(self.logs_location, "model") 45 | saver = tf.train.Saver() 46 | saver.save(sess, model_file) 47 | tf.train.write_graph(sess.graph_def, self.logs_location, "model.pb", False) 48 | print("Model saved") 49 | 50 | def layers(self, vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, num_classes): 51 | def conv_1_by_1(x, num_classes, 52 | kernel_regularizer = tf.contrib.layers.l2_regularizer(1e-3), 53 | init = tf.truncated_normal_initializer(stddev = 0.01)): 54 | return tf.layers.conv2d(x, num_classes, 1,1, padding = 'same', kernel_regularizer = kernel_regularizer, kernel_initializer = init) 55 | 56 | def upsample(x, num_classes, kernel_size, strides, 57 | kernel_regularizer = tf.contrib.layers.l2_regularizer(1e-3), 58 | init = 
tf.truncated_normal_initializer(stddev = 0.01)): 59 | return tf.layers.conv2d_transpose(x, num_classes, kernel_size, strides, padding = 'same', kernel_regularizer = kernel_regularizer, kernel_initializer = init) 60 | 61 | l7_1x1 = conv_1_by_1(vgg_layer7_out, num_classes) 62 | l4_1x1 = conv_1_by_1(vgg_layer4_out, num_classes) 63 | l3_1x1 = conv_1_by_1(vgg_layer3_out, num_classes) 64 | 65 | #upsample l7 by 2 66 | l7_upsample = upsample(l7_1x1, num_classes, 4, 2) 67 | #l7_upsample = tf.layers.batch_normalization(l7_upsample) 68 | 69 | #add skip connection from l4_1x1 70 | l7l4_skip = tf.add(l7_upsample, l4_1x1) 71 | 72 | #implement the another transposed convolution layer 73 | l7l4_upsample = upsample(l7l4_skip, num_classes, 4, 2) 74 | #l7l4_upsample = tf.layers.batch_normalization(l7l4_upsample) 75 | 76 | #add second skip connection from l3_1x1 77 | l7l4l3_skip = tf.add(l7l4_upsample, l3_1x1) 78 | 79 | return upsample(l7l4l3_skip, num_classes, 16, 8) 80 | 81 | """ 82 | Build the TensorFLow loss and optimizer operations. 83 | :param nn_last_layer: TF Tensor of the last layer in the neural network 84 | :param correct_label: TF Placeholder for the correct label image 85 | :param learning_rate: TF Placeholder for the learning rate 86 | :param num_classes: Number of classes to classify 87 | :return: Tuple of (logits, train_op, cross_entropy_loss) 88 | """ 89 | def optimize(self, nn_last_layer, correct_label, learning_rate, num_classes): 90 | logits = tf.reshape(nn_last_layer, (-1, num_classes)) 91 | correct_label = tf.reshape(correct_label, (-1, num_classes)) 92 | 93 | # define a loss function and a trainer/optimizer 94 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = logits, labels = correct_label)) 95 | optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss) 96 | 97 | return logits, optimizer, loss 98 | 99 | 100 | """ 101 | Train neural network and print out the loss during training. 
102 | :param sess: TF Session 103 | :param epochs: Number of epochs 104 | :param batch_size: Batch size 105 | :param get_batches_fn: Function to get batches of training data. Call using get_batches_fn(batch_size) 106 | :param train_op: TF Operation to train the neural network 107 | :param cross_entropy_loss: TF Tensor for the amount of loss 108 | :param input_image: TF Placeholder for input images 109 | :param correct_label: TF Placeholder for label images 110 | :param keep_prob: TF Placeholder for dropout keep probability 111 | :param learning_rate: TF Placeholder for learning rate 112 | """ 113 | def train_nn(self, sess, epochs, batch_size, get_batches_fn, train_op, cross_entropy_loss, input_image, 114 | correct_label, keep_prob, learning_rate): 115 | for epoch in range(epochs): 116 | # train on batches 117 | for images, labels in get_batches_fn(batch_size): 118 | _, loss = sess.run([train_op, cross_entropy_loss], 119 | feed_dict={input_image: images, 120 | correct_label: labels, 121 | keep_prob:self.keep_prob, 122 | learning_rate:self.lr}) 123 | 124 | print("Epoch {} of {}...".format(epoch+1, epochs), "Training Loss: {:.5f}...".format(loss)) 125 | 126 | ''' 127 | Run tests 128 | ''' 129 | def run(self): 130 | config = tf.ConfigProto(log_device_placement=True) 131 | config.gpu_options.allow_growth = True 132 | config.gpu_options.allocator_type = 'BFC' 133 | # Download pretrained vgg model 134 | helper.maybe_download_pretrained_vgg(self.data_dir) 135 | 136 | # Path to vgg model and training data 137 | vgg_path = os.path.join(self.data_dir, 'vgg') 138 | train_path = os.path.join(self.data_dir, self.training_dir) 139 | 140 | # Generate batches 141 | get_batches_fn = helper.gen_batch_function(train_path, self.image_shape) 142 | 143 | with tf.Session() as sess: 144 | correct_label = tf.placeholder(tf.float32, [None, None, None, self.num_classes]) 145 | learning_rate = tf.placeholder(tf.float32) 146 | 147 | 148 | # Build FCN using load_vgg, layers 149 | 
vgg_image_input, keep_prob, vgg_layer3, vgg_layer4, vgg_layer7 = self.load_vgg(sess, vgg_path) 150 | nn_last_layer = self.layers(vgg_layer3, vgg_layer4, vgg_layer7, self.num_classes) 151 | 152 | # Optimise cross entropy loss 153 | logits, train_op, cross_entropy_loss = self.optimize(nn_last_layer, correct_label, learning_rate, self.num_classes) 154 | 155 | # Train NN 156 | sess.run(tf.global_variables_initializer()) 157 | self.train_nn(sess, self.epochs, self.batch_size, get_batches_fn, train_op, cross_entropy_loss, vgg_image_input, 158 | correct_label, keep_prob, learning_rate) 159 | 160 | #save the model 161 | self.save_model(sess) 162 | 163 | # Save inference data u 164 | helper.save_inference_samples(self.runs_dir, self.data_dir, sess, self.image_shape, logits, keep_prob, vgg_image_input) 165 | 166 | ''' 167 | Run tests 168 | ''' 169 | def run_tests(self): 170 | tests.test_load_vgg(self.load_vgg, tf) 171 | tests.test_layers(self.layers) 172 | tests.test_optimize(self.optimize_cross_entropy) 173 | tests.test_train_nn(self.train_nn) 174 | 175 | if __name__ == '__main__': 176 | 177 | # training hyper parameters 178 | params = { 179 | 'data_dir': 'data', 180 | 'runs_dir': 'runs', 181 | 'training_dir': 'data_road/training', 182 | 'logs_location': 'logs', 183 | 'lr': 0.0001, 184 | 'keep_prob': 0.25, 185 | 'epochs': 25, 186 | 'batch_size': 16, 187 | 'std_init': 0.01, 188 | 'num_classes': 2, 189 | 'image_shape': (160, 576) 190 | } 191 | fcn = FCNSegementer(params) 192 | fcn.run() 193 | -------------------------------------------------------------------------------- /output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asterixds/SemanticSegmentation/ad7e84d8e8d41cc6ebcad1238afc93d1e52c0183/output.png -------------------------------------------------------------------------------- /project_tests.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 
import os 3 | from copy import deepcopy 4 | from glob import glob 5 | from unittest import mock 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | 11 | def test_safe(func): 12 | """ 13 | Isolate tests 14 | """ 15 | def func_wrapper(*args): 16 | with tf.Graph().as_default(): 17 | result = func(*args) 18 | print('Tests Passed') 19 | return result 20 | 21 | return func_wrapper 22 | 23 | 24 | def _prevent_print(function, params): 25 | sys.stdout = open(os.devnull, "w") 26 | function(**params) 27 | sys.stdout = sys.__stdout__ 28 | 29 | 30 | def _assert_tensor_shape(tensor, shape, display_name): 31 | assert tf.assert_rank(tensor, len(shape), message='{} has wrong rank'.format(display_name)) 32 | 33 | tensor_shape = tensor.get_shape().as_list() if len(shape) else [] 34 | 35 | wrong_dimension = [ten_dim for ten_dim, cor_dim in zip(tensor_shape, shape) 36 | if cor_dim is not None and ten_dim != cor_dim] 37 | assert not wrong_dimension, \ 38 | '{} has wrong shape. Found {}'.format(display_name, tensor_shape) 39 | 40 | 41 | class TmpMock(object): 42 | """ 43 | Mock a attribute. Restore attribute when exiting scope. 
44 | """ 45 | def __init__(self, module, attrib_name): 46 | self.original_attrib = deepcopy(getattr(module, attrib_name)) 47 | setattr(module, attrib_name, mock.MagicMock()) 48 | self.module = module 49 | self.attrib_name = attrib_name 50 | 51 | def __enter__(self): 52 | return getattr(self.module, self.attrib_name) 53 | 54 | def __exit__(self, type, value, traceback): 55 | setattr(self.module, self.attrib_name, self.original_attrib) 56 | 57 | 58 | @test_safe 59 | def test_load_vgg(load_vgg, tf_module): 60 | with TmpMock(tf_module.saved_model.loader, 'load') as mock_load_model: 61 | vgg_path = '' 62 | sess = tf.Session() 63 | test_input_image = tf.placeholder(tf.float32, name='image_input') 64 | test_keep_prob = tf.placeholder(tf.float32, name='keep_prob') 65 | test_vgg_layer3_out = tf.placeholder(tf.float32, name='layer3_out') 66 | test_vgg_layer4_out = tf.placeholder(tf.float32, name='layer4_out') 67 | test_vgg_layer7_out = tf.placeholder(tf.float32, name='layer7_out') 68 | 69 | input_image, keep_prob, vgg_layer3_out, vgg_layer4_out, vgg_layer7_out = load_vgg(sess, vgg_path) 70 | 71 | assert mock_load_model.called, \ 72 | 'tf.saved_model.loader.load() not called' 73 | assert mock_load_model.call_args == mock.call(sess, ['vgg16'], vgg_path), \ 74 | 'tf.saved_model.loader.load() called with wrong arguments.' 
75 | 76 | assert input_image == test_input_image, 'input_image is the wrong object' 77 | assert keep_prob == test_keep_prob, 'keep_prob is the wrong object' 78 | assert vgg_layer3_out == test_vgg_layer3_out, 'layer3_out is the wrong object' 79 | assert vgg_layer4_out == test_vgg_layer4_out, 'layer4_out is the wrong object' 80 | assert vgg_layer7_out == test_vgg_layer7_out, 'layer7_out is the wrong object' 81 | 82 | 83 | @test_safe 84 | def test_layers(layers): 85 | num_classes = 2 86 | vgg_layer3_out = tf.placeholder(tf.float32, [None, None, None, 256]) 87 | vgg_layer4_out = tf.placeholder(tf.float32, [None, None, None, 512]) 88 | vgg_layer7_out = tf.placeholder(tf.float32, [None, None, None, 4096]) 89 | layers_output = layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, num_classes) 90 | 91 | _assert_tensor_shape(layers_output, [None, None, None, num_classes], 'Layers Output') 92 | 93 | 94 | @test_safe 95 | def test_optimize(optimize): 96 | num_classes = 2 97 | shape = [2, 3, 4, num_classes] 98 | layers_output = tf.Variable(tf.zeros(shape)) 99 | correct_label = tf.placeholder(tf.float32, [None, None, None, num_classes]) 100 | learning_rate = tf.placeholder(tf.float32) 101 | logits, train_op, cross_entropy_loss = optimize(layers_output, correct_label, learning_rate, num_classes) 102 | 103 | _assert_tensor_shape(logits, [2*3*4, num_classes], 'Logits') 104 | 105 | with tf.Session() as sess: 106 | sess.run(tf.global_variables_initializer()) 107 | sess.run([train_op], {correct_label: np.arange(np.prod(shape)).reshape(shape), learning_rate: 10}) 108 | test, loss = sess.run([layers_output, cross_entropy_loss], {correct_label: np.arange(np.prod(shape)).reshape(shape)}) 109 | 110 | assert test.min() != 0 or test.max() != 0, 'Training operation not changing weights.' 
111 | 112 | 113 | @test_safe 114 | def test_train_nn(train_nn): 115 | epochs = 1 116 | batch_size = 2 117 | 118 | def get_batches_fn(batach_size_parm): 119 | shape = [batach_size_parm, 2, 3, 3] 120 | return np.arange(np.prod(shape)).reshape(shape) 121 | 122 | train_op = tf.constant(0) 123 | cross_entropy_loss = tf.constant(10.11) 124 | input_image = tf.placeholder(tf.float32, name='input_image') 125 | correct_label = tf.placeholder(tf.float32, name='correct_label') 126 | keep_prob = tf.placeholder(tf.float32, name='keep_prob') 127 | learning_rate = tf.placeholder(tf.float32, name='learning_rate') 128 | with tf.Session() as sess: 129 | parameters = { 130 | 'sess': sess, 131 | 'epochs': epochs, 132 | 'batch_size': batch_size, 133 | 'get_batches_fn': get_batches_fn, 134 | 'train_op': train_op, 135 | 'cross_entropy_loss': cross_entropy_loss, 136 | 'input_image': input_image, 137 | 'correct_label': correct_label, 138 | 'keep_prob': keep_prob, 139 | 'learning_rate': learning_rate} 140 | #_prevent_print(train_nn, parameters) 141 | 142 | 143 | @test_safe 144 | def test_for_kitti_dataset(data_dir): 145 | kitti_dataset_path = os.path.join(data_dir, 'data_road') 146 | training_labels_count = len(glob(os.path.join(kitti_dataset_path, 'training/gt_image_2/*_road_*.png'))) 147 | training_images_count = len(glob(os.path.join(kitti_dataset_path, 'training/image_2/*.png'))) 148 | testing_images_count = len(glob(os.path.join(kitti_dataset_path, 'testing/image_2/*.png'))) 149 | 150 | assert not (training_images_count == training_labels_count == testing_images_count == 0),\ 151 | 'Kitti dataset not found. 
Extract Kitti dataset in {}'.format(kitti_dataset_path) 152 | assert training_images_count == 289, 'Expected 289 training images, found {} images.'.format(training_images_count) 153 | assert training_labels_count == 289, 'Expected 289 training labels, found {} labels.'.format(training_labels_count) 154 | assert testing_images_count == 290, 'Expected 290 testing images, found {} images.'.format(testing_images_count) 155 | -------------------------------------------------------------------------------- /tensorboard_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/asterixds/SemanticSegmentation/ad7e84d8e8d41cc6ebcad1238afc93d1e52c0183/tensorboard_out.png --------------------------------------------------------------------------------