├── data └── .gitignore ├── set_git.sh ├── LICENSE ├── .gitignore ├── README.md ├── main.py ├── helper.py └── project_tests.py /data/.gitignore: -------------------------------------------------------------------------------- 1 | data_road/ 2 | vgg/ 3 | gtFine_trainvaltest/ 4 | 5 | vgg16.npy 6 | -------------------------------------------------------------------------------- /set_git.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Make sure you have the latest version of the repo 4 | echo 5 | git pull 6 | echo 7 | 8 | # Ask the user for login details 9 | read -p 'Git repository url: ' upstreamVar 10 | read -p 'Git Username: ' userVar 11 | read -p 'Git email: ' emailVar 12 | 13 | echo 14 | echo Thank you $userVar!, we now have your credentials 15 | echo for upstream $upstreamVar. You must supply your password for each push. 16 | echo 17 | 18 | echo setting up git 19 | 20 | git config --global user.name $userVar 21 | git config --global user.email $emailVar 22 | git remote set-url origin $upstreamVar 23 | echo 24 | 25 | echo Please verify remote: 26 | git remote -v 27 | echo 28 | 29 | echo Please verify your credentials: 30 | echo username: `git config user.name` 31 | echo email: `git config user.email` 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017-2018 Udacity, Inc. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | RUNS/ 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Segmentation 2 | ### Introduction 3 | In this project, you'll label the pixels of a road in images using a Fully Convolutional Network (FCN). 4 | 5 | ### Setup 6 | ##### GPU 7 | `main.py` will check to make sure you are using GPU - if you don't have a GPU on your system, you can use AWS or another cloud computing platform. 
8 | ##### Frameworks and Packages 9 | Make sure you have the following installed: 10 | - [Python 3](https://www.python.org/) 11 | - [TensorFlow](https://www.tensorflow.org/) 12 | - [NumPy](http://www.numpy.org/) 13 | - [SciPy](https://www.scipy.org/) 14 | ##### Dataset 15 | Download the [Kitti Road dataset](http://www.cvlibs.net/datasets/kitti/eval_road.php) from [here](http://www.cvlibs.net/download.php?file=data_road.zip). Extract the dataset into the `data` folder. This will create the folder `data_road` with all the training and test images. 16 | 17 | ### Start 18 | ##### Implement 19 | Implement the code in the `main.py` module indicated by the "TODO" comments. 20 | The parts marked with an "OPTIONAL" comment are not required to complete the project. 21 | ##### Run 22 | Run the following command to run the project: 23 | ``` 24 | python main.py 25 | ``` 26 | **Note:** If you run this in a Jupyter Notebook, system messages, such as those regarding test status, may appear in the terminal rather than the notebook. 27 | 28 | ### Submission 29 | 1. Ensure you've passed all the unit tests. 30 | 2. Ensure you pass all points on [the rubric](https://review.udacity.com/#!/rubrics/989/view). 31 | 3. Submit the following in a zip file: 32 | - `helper.py` 33 | - `main.py` 34 | - `project_tests.py` 35 | - Newest inference images from the `runs` folder (**all images from the most recent run**) 36 | 37 | ### Tips 38 | - The link for the frozen `VGG16` model is hardcoded into `helper.py`. The model can be found [here](https://s3-us-west-1.amazonaws.com/udacity-selfdrivingcar/vgg.zip). 39 | - The model is not vanilla `VGG16`, but a fully convolutional version, which already contains the 1x1 convolutions to replace the fully connected layers. Please see this [post](https://s3-us-west-1.amazonaws.com/udacity-selfdrivingcar/forum_archive/Semantic_Segmentation_advice.pdf) for more information. A summary of additional points follows. 40 | - The original FCN-8s was trained in stages.
The authors later uploaded a version that was trained all at once to their GitHub repo. The version in the GitHub repo has one important difference: The outputs of pooling layers 3 and 4 are scaled before they are fed into the 1x1 convolutions. As a result, some students have found that the model learns much better with the scaling layers included. The model may not converge substantially faster, but may reach a higher IoU and accuracy. 41 | - When adding l2-regularization, setting a regularizer in the arguments of the `tf.layers` is not enough. Regularization loss terms must be manually added to your loss function. otherwise regularization is not implemented. 42 | 43 | ### Using GitHub and Creating Effective READMEs 44 | If you are unfamiliar with GitHub , Udacity has a brief [GitHub tutorial](http://blog.udacity.com/2015/06/a-beginners-git-github-tutorial.html) to get you started. Udacity also provides a more detailed free [course on git and GitHub](https://www.udacity.com/course/how-to-use-git-and-github--ud775). 45 | 46 | To learn about REAMDE files and Markdown, Udacity provides a free [course on READMEs](https://www.udacity.com/courses/ud777), as well. 47 | 48 | GitHub also provides a [tutorial](https://guides.github.com/features/mastering-markdown/) about creating Markdown files. 49 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os.path 3 | import tensorflow as tf 4 | import helper 5 | import warnings 6 | from distutils.version import LooseVersion 7 | import project_tests as tests 8 | 9 | 10 | # Check TensorFlow Version 11 | assert LooseVersion(tf.__version__) >= LooseVersion('1.0'), 'Please use TensorFlow version 1.0 or newer. 
You are using {}'.format(tf.__version__) 12 | print('TensorFlow Version: {}'.format(tf.__version__)) 13 | 14 | # Check for a GPU 15 | if not tf.test.gpu_device_name(): 16 | warnings.warn('No GPU found. Please use a GPU to train your neural network.') 17 | else: 18 | print('Default GPU Device: {}'.format(tf.test.gpu_device_name())) 19 | 20 | 21 | def load_vgg(sess, vgg_path): 22 | """ 23 | Load Pretrained VGG Model into TensorFlow. 24 | :param sess: TensorFlow Session 25 | :param vgg_path: Path to vgg folder, containing "variables/" and "saved_model.pb" 26 | :return: Tuple of Tensors from VGG model (image_input, keep_prob, layer3_out, layer4_out, layer7_out) 27 | """ 28 | # TODO: Implement function 29 | # Use tf.saved_model.loader.load to load the model and weights 30 | vgg_tag = 'vgg16' 31 | vgg_input_tensor_name = 'image_input:0' 32 | vgg_keep_prob_tensor_name = 'keep_prob:0' 33 | vgg_layer3_out_tensor_name = 'layer3_out:0' 34 | vgg_layer4_out_tensor_name = 'layer4_out:0' 35 | vgg_layer7_out_tensor_name = 'layer7_out:0' 36 | 37 | return None, None, None, None, None 38 | tests.test_load_vgg(load_vgg, tf) 39 | 40 | 41 | def layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, num_classes): 42 | """ 43 | Create the layers for a fully convolutional network. Build skip-layers using the vgg layers. 44 | :param vgg_layer3_out: TF Tensor for VGG Layer 3 output 45 | :param vgg_layer4_out: TF Tensor for VGG Layer 4 output 46 | :param vgg_layer7_out: TF Tensor for VGG Layer 7 output 47 | :param num_classes: Number of classes to classify 48 | :return: The Tensor for the last layer of output 49 | """ 50 | # TODO: Implement function 51 | return None 52 | tests.test_layers(layers) 53 | 54 | 55 | def optimize(nn_last_layer, correct_label, learning_rate, num_classes): 56 | """ 57 | Build the TensorFLow loss and optimizer operations. 
58 | :param nn_last_layer: TF Tensor of the last layer in the neural network 59 | :param correct_label: TF Placeholder for the correct label image 60 | :param learning_rate: TF Placeholder for the learning rate 61 | :param num_classes: Number of classes to classify 62 | :return: Tuple of (logits, train_op, cross_entropy_loss) 63 | """ 64 | # TODO: Implement function 65 | return None, None, None 66 | tests.test_optimize(optimize) 67 | 68 | 69 | def train_nn(sess, epochs, batch_size, get_batches_fn, train_op, cross_entropy_loss, input_image, 70 | correct_label, keep_prob, learning_rate): 71 | """ 72 | Train neural network and print out the loss during training. 73 | :param sess: TF Session 74 | :param epochs: Number of epochs 75 | :param batch_size: Batch size 76 | :param get_batches_fn: Function to get batches of training data. Call using get_batches_fn(batch_size) 77 | :param train_op: TF Operation to train the neural network 78 | :param cross_entropy_loss: TF Tensor for the amount of loss 79 | :param input_image: TF Placeholder for input images 80 | :param correct_label: TF Placeholder for label images 81 | :param keep_prob: TF Placeholder for dropout keep probability 82 | :param learning_rate: TF Placeholder for learning rate 83 | """ 84 | # TODO: Implement function 85 | pass 86 | tests.test_train_nn(train_nn) 87 | 88 | 89 | def run(): 90 | num_classes = 2 91 | image_shape = (160, 576) 92 | data_dir = './data' 93 | runs_dir = './runs' 94 | tests.test_for_kitti_dataset(data_dir) 95 | 96 | # Download pretrained vgg model 97 | helper.maybe_download_pretrained_vgg(data_dir) 98 | 99 | # OPTIONAL: Train and Inference on the cityscapes dataset instead of the Kitti dataset. 100 | # You'll need a GPU with at least 10 teraFLOPS to train on. 
101 | # https://www.cityscapes-dataset.com/ 102 | 103 | with tf.Session() as sess: 104 | # Path to vgg model 105 | vgg_path = os.path.join(data_dir, 'vgg') 106 | # Create function to get batches 107 | get_batches_fn = helper.gen_batch_function(os.path.join(data_dir, 'data_road/training'), image_shape) 108 | 109 | # OPTIONAL: Augment Images for better results 110 | # https://datascience.stackexchange.com/questions/5224/how-to-prepare-augment-images-for-neural-network 111 | 112 | # TODO: Build NN using load_vgg, layers, and optimize function 113 | 114 | # TODO: Train NN using the train_nn function 115 | 116 | # TODO: Save inference data using helper.save_inference_samples 117 | # helper.save_inference_samples(runs_dir, data_dir, sess, image_shape, logits, keep_prob, input_image) 118 | 119 | # OPTIONAL: Apply the trained model to a video 120 | 121 | 122 | if __name__ == '__main__': 123 | run() 124 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import re 2 | import random 3 | import numpy as np 4 | import os.path 5 | import scipy.misc 6 | import shutil 7 | import zipfile 8 | import time 9 | import tensorflow as tf 10 | from glob import glob 11 | from urllib.request import urlretrieve 12 | from tqdm import tqdm 13 | 14 | 15 | class DLProgress(tqdm): 16 | last_block = 0 17 | 18 | def hook(self, block_num=1, block_size=1, total_size=None): 19 | self.total = total_size 20 | self.update((block_num - self.last_block) * block_size) 21 | self.last_block = block_num 22 | 23 | 24 | def maybe_download_pretrained_vgg(data_dir): 25 | """ 26 | Download and extract pretrained vgg model if it doesn't exist 27 | :param data_dir: Directory to download the model to 28 | """ 29 | vgg_filename = 'vgg.zip' 30 | vgg_path = os.path.join(data_dir, 'vgg') 31 | vgg_files = [ 32 | os.path.join(vgg_path, 'variables/variables.data-00000-of-00001'), 33 | 
os.path.join(vgg_path, 'variables/variables.index'), 34 | os.path.join(vgg_path, 'saved_model.pb')] 35 | 36 | missing_vgg_files = [vgg_file for vgg_file in vgg_files if not os.path.exists(vgg_file)] 37 | if missing_vgg_files: 38 | # Clean vgg dir 39 | if os.path.exists(vgg_path): 40 | shutil.rmtree(vgg_path) 41 | os.makedirs(vgg_path) 42 | 43 | # Download vgg 44 | print('Downloading pre-trained vgg model...') 45 | with DLProgress(unit='B', unit_scale=True, miniters=1) as pbar: 46 | urlretrieve( 47 | 'https://s3-us-west-1.amazonaws.com/udacity-selfdrivingcar/vgg.zip', 48 | os.path.join(vgg_path, vgg_filename), 49 | pbar.hook) 50 | 51 | # Extract vgg 52 | print('Extracting model...') 53 | zip_ref = zipfile.ZipFile(os.path.join(vgg_path, vgg_filename), 'r') 54 | zip_ref.extractall(data_dir) 55 | zip_ref.close() 56 | 57 | # Remove zip file to save space 58 | os.remove(os.path.join(vgg_path, vgg_filename)) 59 | 60 | 61 | def gen_batch_function(data_folder, image_shape): 62 | """ 63 | Generate function to create batches of training data 64 | :param data_folder: Path to folder that contains all the datasets 65 | :param image_shape: Tuple - Shape of image 66 | :return: 67 | """ 68 | def get_batches_fn(batch_size): 69 | """ 70 | Create batches of training data 71 | :param batch_size: Batch Size 72 | :return: Batches of training data 73 | """ 74 | image_paths = glob(os.path.join(data_folder, 'image_2', '*.png')) 75 | label_paths = { 76 | re.sub(r'_(lane|road)_', '_', os.path.basename(path)): path 77 | for path in glob(os.path.join(data_folder, 'gt_image_2', '*_road_*.png'))} 78 | background_color = np.array([255, 0, 0]) 79 | 80 | random.shuffle(image_paths) 81 | for batch_i in range(0, len(image_paths), batch_size): 82 | images = [] 83 | gt_images = [] 84 | for image_file in image_paths[batch_i:batch_i+batch_size]: 85 | gt_image_file = label_paths[os.path.basename(image_file)] 86 | 87 | image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape) 88 | gt_image = 
scipy.misc.imresize(scipy.misc.imread(gt_image_file), image_shape) 89 | 90 | gt_bg = np.all(gt_image == background_color, axis=2) 91 | gt_bg = gt_bg.reshape(*gt_bg.shape, 1) 92 | gt_image = np.concatenate((gt_bg, np.invert(gt_bg)), axis=2) 93 | 94 | images.append(image) 95 | gt_images.append(gt_image) 96 | 97 | yield np.array(images), np.array(gt_images) 98 | return get_batches_fn 99 | 100 | 101 | def gen_test_output(sess, logits, keep_prob, image_pl, data_folder, image_shape): 102 | """ 103 | Generate test output using the test images 104 | :param sess: TF session 105 | :param logits: TF Tensor for the logits 106 | :param keep_prob: TF Placeholder for the dropout keep robability 107 | :param image_pl: TF Placeholder for the image placeholder 108 | :param data_folder: Path to the folder that contains the datasets 109 | :param image_shape: Tuple - Shape of image 110 | :return: Output for for each test image 111 | """ 112 | for image_file in glob(os.path.join(data_folder, 'image_2', '*.png')): 113 | image = scipy.misc.imresize(scipy.misc.imread(image_file), image_shape) 114 | 115 | im_softmax = sess.run( 116 | [tf.nn.softmax(logits)], 117 | {keep_prob: 1.0, image_pl: [image]}) 118 | im_softmax = im_softmax[0][:, 1].reshape(image_shape[0], image_shape[1]) 119 | segmentation = (im_softmax > 0.5).reshape(image_shape[0], image_shape[1], 1) 120 | mask = np.dot(segmentation, np.array([[0, 255, 0, 127]])) 121 | mask = scipy.misc.toimage(mask, mode="RGBA") 122 | street_im = scipy.misc.toimage(image) 123 | street_im.paste(mask, box=None, mask=mask) 124 | 125 | yield os.path.basename(image_file), np.array(street_im) 126 | 127 | 128 | def save_inference_samples(runs_dir, data_dir, sess, image_shape, logits, keep_prob, input_image): 129 | # Make folder for current run 130 | output_dir = os.path.join(runs_dir, str(time.time())) 131 | if os.path.exists(output_dir): 132 | shutil.rmtree(output_dir) 133 | os.makedirs(output_dir) 134 | 135 | # Run NN on test images and save them to HD 
136 | print('Training Finished. Saving test images to: {}'.format(output_dir)) 137 | image_outputs = gen_test_output( 138 | sess, logits, keep_prob, input_image, os.path.join(data_dir, 'data_road/testing'), image_shape) 139 | for name, image in image_outputs: 140 | scipy.misc.imsave(os.path.join(output_dir, name), image) 141 | -------------------------------------------------------------------------------- /project_tests.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from copy import deepcopy 4 | from glob import glob 5 | from unittest import mock 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | 11 | def test_safe(func): 12 | """ 13 | Isolate tests 14 | """ 15 | def func_wrapper(*args): 16 | with tf.Graph().as_default(): 17 | result = func(*args) 18 | print('Tests Passed') 19 | return result 20 | 21 | return func_wrapper 22 | 23 | 24 | def _prevent_print(function, params): 25 | sys.stdout = open(os.devnull, "w") 26 | function(**params) 27 | sys.stdout = sys.__stdout__ 28 | 29 | 30 | def _assert_tensor_shape(tensor, shape, display_name): 31 | assert tf.assert_rank(tensor, len(shape), message='{} has wrong rank'.format(display_name)) 32 | 33 | tensor_shape = tensor.get_shape().as_list() if len(shape) else [] 34 | 35 | wrong_dimension = [ten_dim for ten_dim, cor_dim in zip(tensor_shape, shape) 36 | if cor_dim is not None and ten_dim != cor_dim] 37 | assert not wrong_dimension, \ 38 | '{} has wrong shape. Found {}'.format(display_name, tensor_shape) 39 | 40 | 41 | class TmpMock(object): 42 | """ 43 | Mock a attribute. Restore attribute when exiting scope. 
44 | """ 45 | def __init__(self, module, attrib_name): 46 | self.original_attrib = deepcopy(getattr(module, attrib_name)) 47 | setattr(module, attrib_name, mock.MagicMock()) 48 | self.module = module 49 | self.attrib_name = attrib_name 50 | 51 | def __enter__(self): 52 | return getattr(self.module, self.attrib_name) 53 | 54 | def __exit__(self, type, value, traceback): 55 | setattr(self.module, self.attrib_name, self.original_attrib) 56 | 57 | 58 | @test_safe 59 | def test_load_vgg(load_vgg, tf_module): 60 | with TmpMock(tf_module.saved_model.loader, 'load') as mock_load_model: 61 | vgg_path = '' 62 | sess = tf.Session() 63 | test_input_image = tf.placeholder(tf.float32, name='image_input') 64 | test_keep_prob = tf.placeholder(tf.float32, name='keep_prob') 65 | test_vgg_layer3_out = tf.placeholder(tf.float32, name='layer3_out') 66 | test_vgg_layer4_out = tf.placeholder(tf.float32, name='layer4_out') 67 | test_vgg_layer7_out = tf.placeholder(tf.float32, name='layer7_out') 68 | 69 | input_image, keep_prob, vgg_layer3_out, vgg_layer4_out, vgg_layer7_out = load_vgg(sess, vgg_path) 70 | 71 | assert mock_load_model.called, \ 72 | 'tf.saved_model.loader.load() not called' 73 | assert mock_load_model.call_args == mock.call(sess, ['vgg16'], vgg_path), \ 74 | 'tf.saved_model.loader.load() called with wrong arguments.' 
75 | 76 | assert input_image == test_input_image, 'input_image is the wrong object' 77 | assert keep_prob == test_keep_prob, 'keep_prob is the wrong object' 78 | assert vgg_layer3_out == test_vgg_layer3_out, 'layer3_out is the wrong object' 79 | assert vgg_layer4_out == test_vgg_layer4_out, 'layer4_out is the wrong object' 80 | assert vgg_layer7_out == test_vgg_layer7_out, 'layer7_out is the wrong object' 81 | 82 | 83 | @test_safe 84 | def test_layers(layers): 85 | num_classes = 2 86 | vgg_layer3_out = tf.placeholder(tf.float32, [None, None, None, 256]) 87 | vgg_layer4_out = tf.placeholder(tf.float32, [None, None, None, 512]) 88 | vgg_layer7_out = tf.placeholder(tf.float32, [None, None, None, 4096]) 89 | layers_output = layers(vgg_layer3_out, vgg_layer4_out, vgg_layer7_out, num_classes) 90 | 91 | _assert_tensor_shape(layers_output, [None, None, None, num_classes], 'Layers Output') 92 | 93 | 94 | @test_safe 95 | def test_optimize(optimize): 96 | num_classes = 2 97 | shape = [2, 3, 4, num_classes] 98 | layers_output = tf.Variable(tf.zeros(shape)) 99 | correct_label = tf.placeholder(tf.float32, [None, None, None, num_classes]) 100 | learning_rate = tf.placeholder(tf.float32) 101 | logits, train_op, cross_entropy_loss = optimize(layers_output, correct_label, learning_rate, num_classes) 102 | 103 | _assert_tensor_shape(logits, [2*3*4, num_classes], 'Logits') 104 | 105 | with tf.Session() as sess: 106 | sess.run(tf.global_variables_initializer()) 107 | sess.run([train_op], {correct_label: np.arange(np.prod(shape)).reshape(shape), learning_rate: 10}) 108 | test, loss = sess.run([layers_output, cross_entropy_loss], {correct_label: np.arange(np.prod(shape)).reshape(shape)}) 109 | 110 | assert test.min() != 0 or test.max() != 0, 'Training operation not changing weights.' 
111 | 112 | 113 | @test_safe 114 | def test_train_nn(train_nn): 115 | epochs = 1 116 | batch_size = 2 117 | 118 | def get_batches_fn(batach_size_parm): 119 | shape = [batach_size_parm, 2, 3, 3] 120 | return np.arange(np.prod(shape)).reshape(shape) 121 | 122 | train_op = tf.constant(0) 123 | cross_entropy_loss = tf.constant(10.11) 124 | input_image = tf.placeholder(tf.float32, name='input_image') 125 | correct_label = tf.placeholder(tf.float32, name='correct_label') 126 | keep_prob = tf.placeholder(tf.float32, name='keep_prob') 127 | learning_rate = tf.placeholder(tf.float32, name='learning_rate') 128 | with tf.Session() as sess: 129 | parameters = { 130 | 'sess': sess, 131 | 'epochs': epochs, 132 | 'batch_size': batch_size, 133 | 'get_batches_fn': get_batches_fn, 134 | 'train_op': train_op, 135 | 'cross_entropy_loss': cross_entropy_loss, 136 | 'input_image': input_image, 137 | 'correct_label': correct_label, 138 | 'keep_prob': keep_prob, 139 | 'learning_rate': learning_rate} 140 | _prevent_print(train_nn, parameters) 141 | 142 | 143 | @test_safe 144 | def test_for_kitti_dataset(data_dir): 145 | kitti_dataset_path = os.path.join(data_dir, 'data_road') 146 | training_labels_count = len(glob(os.path.join(kitti_dataset_path, 'training/gt_image_2/*_road_*.png'))) 147 | training_images_count = len(glob(os.path.join(kitti_dataset_path, 'training/image_2/*.png'))) 148 | testing_images_count = len(glob(os.path.join(kitti_dataset_path, 'testing/image_2/*.png'))) 149 | 150 | assert not (training_images_count == training_labels_count == testing_images_count == 0),\ 151 | 'Kitti dataset not found. 
Extract Kitti dataset in {}'.format(kitti_dataset_path) 152 | assert training_images_count == 289, 'Expected 289 training images, found {} images.'.format(training_images_count) 153 | assert training_labels_count == 289, 'Expected 289 training labels, found {} labels.'.format(training_labels_count) 154 | assert testing_images_count == 290, 'Expected 290 testing images, found {} images.'.format(testing_images_count) 155 | --------------------------------------------------------------------------------