├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── camvid.png
│   ├── camvid_out.png
│   ├── cityscapes.png
│   └── cityscapes_out.png
├── datasets.py
├── main_tf.py
├── model.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
  1 | .idea/
  2 | 
  3 | # Byte-compiled / optimized / DLL files
  4 | __pycache__/
  5 | *.py[cod]
  6 | *$py.class
  7 | 
  8 | # C extensions
  9 | *.so
 10 | 
 11 | # Distribution / packaging
 12 | .Python
 13 | env/
 14 | build/
 15 | develop-eggs/
 16 | dist/
 17 | downloads/
 18 | eggs/
 19 | .eggs/
 20 | lib/
 21 | lib64/
 22 | parts/
 23 | sdist/
 24 | var/
 25 | wheels/
 26 | *.egg-info/
 27 | .installed.cfg
 28 | *.egg
 29 | 
 30 | # PyInstaller
 31 | # Usually these files are written by a python script from a template
 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
 33 | *.manifest
 34 | *.spec
 35 | 
 36 | # Installer logs
 37 | pip-log.txt
 38 | pip-delete-this-directory.txt
 39 | 
 40 | # Unit test / coverage reports
 41 | htmlcov/
 42 | .tox/
 43 | .coverage
 44 | .coverage.*
 45 | .cache
 46 | nosetests.xml
 47 | coverage.xml
 48 | *.cover
 49 | .hypothesis/
 50 | 
 51 | # Translations
 52 | *.mo
 53 | *.pot
 54 | 
 55 | # Django stuff:
 56 | *.log
 57 | local_settings.py
 58 | 
 59 | # Flask stuff:
 60 | instance/
 61 | .webassets-cache
 62 | 
 63 | # Scrapy stuff:
 64 | .scrapy
 65 | 
 66 | # Sphinx documentation
 67 | docs/_build/
 68 | 
 69 | # PyBuilder
 70 | target/
 71 | 
 72 | # Jupyter Notebook
 73 | .ipynb_checkpoints
 74 | 
 75 | # pyenv
 76 | .python-version
 77 | 
 78 | # celery beat schedule file
 79 | celerybeat-schedule
 80 | 
 81 | # SageMath parsed files
 82 | *.sage.py
 83 | 
 84 | # dotenv
 85 | .env
 86 | 
 87 | # virtualenv
 88 | .venv
 89 | venv/
 90 | ENV/
 91 | 
 92 | # Spyder project settings
 93 | .spyderproject
 94 | .spyproject
 95 | 
 96 | # Rope project settings
 97 | .ropeproject
 98 | 
 99 | # mkdocs documentation
100 | /site
101 | 
102 | # mypy
103 | .mypy_cache/
104 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2017 Andrea Palazzi
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # dilation-tensorflow
 2 | A native *TensorFlow* implementation of semantic segmentation, following [Multi-Scale Context Aggregation by Dilated Convolutions](https://arxiv.org/abs/1511.07122) by Yu and Koltun.
 3 | 
 4 | Pretrained weights have been converted to TensorFlow from the [original Caffe implementation](https://github.com/fyu/dilation).
 5 | 
 6 | Pretrained models are available for both the CityScapes and the CamVid datasets.
 7 | 
 8 | If you're looking instead for a *Keras+Theano* implementation of this very same network, you can find it [here](https://github.com/DavideA/dilation-keras).
 9 | 
10 | ## Examples
11 | 
12 | ### Cityscapes
13 | 
14 | ![input](data/cityscapes.png)
15 | 
16 | *Test image (input)*
17 | 
18 | ![segmentation](data/cityscapes_out.png)
19 | 
20 | *Test image (prediction)*
21 | 
22 | ### CamVid
23 | 
24 | ![input](data/camvid.png)
25 | 
26 | *Test image (input)*
27 | 
28 | ![segmentation](data/camvid_out.png)
29 | 
30 | *Test image (prediction)*
31 | 
32 | ## How-to
33 | 1. Download the pretrained weights from here:
34 | 
35 |    [CityScapes weights](https://drive.google.com/open?id=0Bx9YaGcDPu3XR0d4cXVSWmtVdEE)
36 | 
37 |    [CamVid weights](https://drive.google.com/open?id=0Bx9YaGcDPu3Xd0JrcXZpTEpkb0U)
38 | 
39 | 2. Move the weights file into the [`data`](data) directory.
40 | 
41 | 3. Run the model on the test image by executing [`main_tf.py`](main_tf.py) (e.g. `python main_tf.py`).
42 | 
43 | ## Configuration
44 | 
45 | This model has been tested with the following configuration:
46 | - Ubuntu 16.04
47 | - python 3.5.2
48 | - tensorflow 1.1.0
49 | - cv2 3.2.0
50 | 
51 | ## Acknowledgements
52 | 
53 | Thanks to [DavideA](https://github.com/DavideA), who converted all the weights from Caffe to Keras+Theano ([here](https://github.com/DavideA/dilation-keras)), making my conversion to TensorFlow much less painful than it could have been :-)
54 | 
--------------------------------------------------------------------------------
/data/camvid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ndrplz/dilation-tensorflow/306a158a1defbd56d2db8529d95300936672c421/data/camvid.png
--------------------------------------------------------------------------------
/data/camvid_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ndrplz/dilation-tensorflow/306a158a1defbd56d2db8529d95300936672c421/data/camvid_out.png
--------------------------------------------------------------------------------
/data/cityscapes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ndrplz/dilation-tensorflow/306a158a1defbd56d2db8529d95300936672c421/data/cityscapes.png
--------------------------------------------------------------------------------
/data/cityscapes_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ndrplz/dilation-tensorflow/306a158a1defbd56d2db8529d95300936672c421/data/cityscapes_out.png
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | 
 3 | # Configurations for different datasets
 4 | CONFIG = {
 5 |     'cityscapes': {
 6 |         'classes': 19,
 7 |         'weights_file': 'data/pretrained_dilation_cityscapes.pickle',
 8 |         'input_shape': (1396, 1396, 3),
 9 |         'output_shape': (1024, 1024, 19),
10 |         'mean_pixel': (72.39, 82.91, 73.16),
11 |         'palette': np.array([[128, 64, 128],
12 |                              [244, 35, 232],
13 |                              [70, 70, 70],
14 |                              [102, 102, 156],
15 |                              [190, 153, 153],
16 |                              [153, 153, 153],
17 |                              [250, 170, 30],
18 |                              [220, 220, 0],
19 |                              [107, 142, 35],
20 |                              [152, 251, 152],
21 |                              [70, 130, 180],
22 |                              [220, 20, 60],
23 |                              [255, 0, 0],
24 |                              [0, 0, 142],
25 |                              [0, 0, 70],
26 |                              [0, 60, 100],
27 |                              [0, 80, 100],
28 |                              [0, 0, 230],
29 |                              [119, 11, 32]], dtype='uint8'),
30 |         'zoom': 1,
31 |         'conv_margin': 186
32 |     },
33 |     'camvid': {
34 |         'classes': 11,
35 |         'weights_file': 'data/pretrained_dilation_camvid.pickle',
36 |         'input_shape': (900, 1100, 3),
37 |         'output_shape': (66, 91, 11),
38 |         'mean_pixel': (110.70, 108.77, 105.41),
39 |         'palette': np.array([[128, 0, 0],
40 |                              [128, 128, 0],
41 |                              [128, 128, 128],
42 |                              [64, 0, 128],
43 |                              [192, 128, 128],
44 |                              [128, 64, 128],
45 |                              [64, 64, 0],
46 |                              [64, 64, 128],
47 |                              [192, 192, 128],
48 |                              [0, 0, 192],
49 |                              [0, 128, 192]], dtype='uint8'),
50 |         'zoom': 8,
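        # Note (added comment): 'zoom' is the factor by which the low-resolution
        # network output is upsampled back to the input resolution (see
        # interp_map in utils.py), while 'conv_margin' is the context margin,
        # in pixels, consumed by the VALID convolutions on each image side.
        # For CamVid, (900 - 2*186) / 8 = 66 and (1100 - 2*186) / 8 = 91,
        # which matches 'output_shape' above.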
51 |         'conv_margin': 186
52 |     }
53 | }
54 | 
--------------------------------------------------------------------------------
/main_tf.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | import pickle
 3 | import cv2
 4 | import os
 5 | import os.path as path
 6 | from utils import predict
 7 | from model import dilation_model_pretrained
 8 | from datasets import CONFIG
 9 | 
10 | 
11 | if __name__ == '__main__':
12 | 
13 |     # Choose between 'cityscapes' and 'camvid'
14 |     dataset = 'cityscapes'
15 | 
16 |     # Load dict of pretrained weights
17 |     print('Loading pre-trained weights...')
18 |     with open(CONFIG[dataset]['weights_file'], 'rb') as f:
19 |         w_pretrained = pickle.load(f)
20 |     print('Done.')
21 | 
22 |     # Create checkpoint directory
23 |     checkpoint_dir = path.join('data/checkpoint', 'dilation_' + dataset)
24 |     if not path.exists(checkpoint_dir):
25 |         os.makedirs(checkpoint_dir)
26 | 
27 |     # Image in / out parameters
28 |     input_image_path = path.join('data', dataset + '.png')
29 |     output_image_path = path.join('data', dataset + '_out.png')
30 | 
31 |     # Build pretrained model and save it as TF checkpoint
32 |     with tf.Session() as sess:
33 | 
34 |         # Choose input shape according to dataset characteristics
35 |         input_h, input_w, input_c = CONFIG[dataset]['input_shape']
36 |         input_tensor = tf.placeholder(tf.float32, shape=(None, input_h, input_w, input_c), name='input_placeholder')
37 | 
38 |         # Create pretrained model
39 |         model = dilation_model_pretrained(dataset, input_tensor, w_pretrained, trainable=False)
40 | 
41 |         sess.run(tf.global_variables_initializer())
42 | 
43 |         # Save both graph and weights
44 |         saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
45 |         saver.save(sess, path.join(checkpoint_dir, 'dilation'))
46 | 
47 |     # Restore both graph and weights from TF checkpoint
48 |     with tf.Session() as sess:
49 | 
50 |         saver = tf.train.import_meta_graph(path.join(checkpoint_dir, 'dilation.meta'))
51 |         saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
52 | 
53 |         graph = tf.get_default_graph()
54 |         model = graph.get_tensor_by_name('softmax:0')
55 |         model = tf.reshape(model, shape=(1,) + CONFIG[dataset]['output_shape'])
56 | 
57 |         # Read and predict on a test image
58 |         input_image = cv2.imread(input_image_path)
59 |         input_tensor = graph.get_tensor_by_name('input_placeholder:0')
60 |         predicted_image = predict(input_image, input_tensor, model, dataset, sess)
61 | 
62 |         # Convert from RGB (the palette colorspace) to BGR for cv2.imwrite, then save
63 |         predicted_image = cv2.cvtColor(predicted_image, cv2.COLOR_RGB2BGR)
64 |         cv2.imwrite(output_image_path, predicted_image)
65 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
 1 | import tensorflow as tf
 2 | 
 3 | 
 4 | def dilation_model_pretrained(dataset, input_tensor, w_pretrained, trainable):
 5 | 
 6 |     def conv(name, input, strides, padding, add_bias, apply_relu, atrous_rate=None):
 7 |         """
 8 |         Helper function for loading convolution weights from the weight dictionary.
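        Loads kernel (and optionally bias) values from w_pretrained[name + '/kernel:0']
        (and '.../bias:0') and applies a standard convolution, or a dilated one
        (tf.nn.atrous_conv2d, which ignores `strides`) when `atrous_rate` is given.
        ReLU is applied on the output when `apply_relu` is True.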
 9 |         """
10 |         with tf.variable_scope(name):
11 | 
12 |             # Load kernel weights and apply convolution
13 |             w_kernel = w_pretrained[name + '/kernel:0']
14 |             w_kernel = tf.Variable(initial_value=w_kernel, trainable=trainable)
15 | 
16 |             if not atrous_rate:
17 |                 conv_out = tf.nn.conv2d(input, w_kernel, strides, padding)
18 |             else:
19 |                 conv_out = tf.nn.atrous_conv2d(input, w_kernel, atrous_rate, padding)
20 |             if add_bias:
21 |                 # Load bias values and add them to conv output
22 |                 w_bias = w_pretrained[name + '/bias:0']
23 |                 w_bias = tf.Variable(initial_value=w_bias, trainable=trainable)
24 |                 conv_out = tf.nn.bias_add(conv_out, w_bias)
25 | 
26 |             if apply_relu:
27 |                 # Apply ReLU nonlinearity
28 |                 conv_out = tf.nn.relu(conv_out)
29 | 
30 |         return conv_out
31 | 
32 |     # Sanity check on dataset name
33 |     if dataset not in ['cityscapes', 'camvid']:
34 |         raise ValueError('Dataset "{}" not supported.'.format(dataset))
35 | 
36 |     else:
37 |         # Front-end prediction module (VGG-style backbone; conv5 and fc6 use dilated convolutions)
38 | 
39 |         h = conv('conv1_1', input_tensor, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
40 |         h = conv('conv1_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
41 |         h = tf.layers.max_pooling2d(h, pool_size=(2, 2), strides=(2, 2), padding='valid', name='pool1')
42 | 
43 |         h = conv('conv2_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
44 |         h = conv('conv2_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
45 |         h = tf.layers.max_pooling2d(h, pool_size=(2, 2), strides=(2, 2), padding='valid', name='pool2')
46 | 
47 |         h = conv('conv3_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
48 |         h = conv('conv3_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
49 |         h = conv('conv3_3', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
50 |         h = tf.layers.max_pooling2d(h, pool_size=(2, 2), strides=(2, 2), padding='valid', name='pool3')
51 | 
52 |         h = conv('conv4_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
53 |         h = conv('conv4_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
54 |         h = conv('conv4_3', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
55 | 
56 |         h = conv('conv5_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=2)
57 |         h = conv('conv5_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=2)
58 |         h = conv('conv5_3', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=2)
59 |         h = conv('fc6', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=4)
60 | 
61 |         h = tf.layers.dropout(h, rate=0.5, name='drop6')
62 |         h = conv('fc7', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
63 |         h = tf.layers.dropout(h, rate=0.5, name='drop7')
64 |         h = conv('final', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
65 | 
66 |         # Context aggregation module: pad, then convolve with exponentially increasing dilation rate
67 |         h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT', name='ctx_pad1_1')
68 |         h = conv('ctx_conv1_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
69 |         h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT', name='ctx_pad1_2')
70 |         h = conv('ctx_conv1_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
71 |         h = tf.pad(h, [[0, 0], [2, 2], [2, 2], [0, 0]], mode='CONSTANT', name='ctx_pad2_1')
72 |         h = conv('ctx_conv2_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=2)
73 | 
74 |         h = tf.pad(h, [[0, 0], [4, 4], [4, 4], [0, 0]], mode='CONSTANT', name='ctx_pad3_1')
75 |         h = conv('ctx_conv3_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=4)
76 | 
77 |         h = tf.pad(h, [[0, 0], [8, 8], [8, 8], [0, 0]], mode='CONSTANT', name='ctx_pad4_1')
78 |         h = conv('ctx_conv4_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=8)
79 | 
80 |         h = tf.pad(h, [[0, 0], [16, 16], [16, 16], [0, 0]], mode='CONSTANT', name='ctx_pad5_1')
81 |         h = conv('ctx_conv5_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=16)
82 | 
83 |         if dataset == 'cityscapes':
84 |             h = tf.pad(h, [[0, 0], [32, 32], [32, 32], [0, 0]], mode='CONSTANT', name='ctx_pad6_1')
85 |             h = conv('ctx_conv6_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=32)
86 | 
87 |             h = tf.pad(h, [[0, 0], [64, 64], [64, 64], [0, 0]], mode='CONSTANT', name='ctx_pad7_1')
88 |             h = conv('ctx_conv7_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=64)
89 | 
90 |         h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT', name='ctx_pad_fc1')
91 |         h = conv('ctx_fc1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
92 |         h = conv('ctx_final', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=False)
93 | 
94 |         if dataset == 'cityscapes':
95 |             h = tf.image.resize_bilinear(h, size=(1024, 1024))
96 |             logits = conv('ctx_upsample', h, strides=[1, 1, 1, 1], padding='SAME', add_bias=False, apply_relu=True)
97 |         else:
98 |             logits = h
99 | 
100 |         softmax = tf.nn.softmax(logits, dim=3, name='softmax')
101 | 
102 |     return softmax
103 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
 1 | import cv2
 2 | import numpy as np
 3 | from datasets import CONFIG
 4 | # import numba
 5 | 
 6 | 
 7 | # This function is the same as the one in the original repository:
 8 | # it performs bilinear upsampling for datasets having zoom > 1
 9 | # @numba.jit(nopython=True)
10 | def interp_map(prob, zoom, width, height):
11 |     channels = prob.shape[2]
12 |     zoom_prob = np.zeros((height, width, channels), dtype=np.float32)
13 |     for c in range(channels):
14 |         for h in range(height):
15 |             for w in range(width):
16 |                 r0 = h // zoom
17 |                 r1 = r0 + 1
18 |                 c0 = w // zoom
19 |                 c1 = c0 + 1
20 |                 rt = float(h) / zoom - r0
21 |                 ct = float(w) / zoom - c0
22 |                 v0 = rt * prob[r1, c0, c] + (1 - rt) * prob[r0, c0, c]
23 |                 v1 = rt * prob[r1, c1, c] + (1 - rt) * prob[r0, c1, c]
24 |                 zoom_prob[h, w, c] = (1 - ct) * v0 + ct * v1
25 |     return zoom_prob
26 | 
27 | 
28 | # Predict function, mostly kept as it was in the original repository
29 | def predict(image, input_tensor, model, ds, sess):
30 | 
31 |     image = image.astype(np.float32) - CONFIG[ds]['mean_pixel']
32 |     conv_margin = CONFIG[ds]['conv_margin']
33 | 
34 |     input_dims = (1,) + CONFIG[ds]['input_shape']
35 |     batch_size, input_height, input_width, num_channels = input_dims
36 |     model_in = np.zeros(input_dims, dtype=np.float32)
37 | 
38 |     image_size = image.shape
39 |     output_height = input_height - 2 * conv_margin
40 |     output_width = input_width - 2 * conv_margin
41 |     image = cv2.copyMakeBorder(image, conv_margin, conv_margin,
42 |                                conv_margin, conv_margin,
43 |                                cv2.BORDER_REFLECT_101)
44 | 
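    # Process the padded image tile by tile: each tile has the network's fixed
    # input size, and ceil-division ensures the tiles cover the whole image.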
45 |     num_tiles_h = image_size[0] // output_height + (1 if image_size[0] % output_height else 0)
46 |     num_tiles_w = image_size[1] // output_width + (1 if image_size[1] % output_width else 0)
47 | 
48 |     row_prediction = []
49 |     for h in range(num_tiles_h):
50 |         col_prediction = []
51 |         for w in range(num_tiles_w):
52 |             offset = [output_height * h,
53 |                       output_width * w]
54 |             tile = image[offset[0]:offset[0] + input_height,
55 |                          offset[1]:offset[1] + input_width, :]
56 |             margin = [0, input_height - tile.shape[0],
57 |                       0, input_width - tile.shape[1]]
58 |             tile = cv2.copyMakeBorder(tile, margin[0], margin[1],
59 |                                       margin[2], margin[3],
60 |                                       cv2.BORDER_REFLECT_101)
61 | 
62 |             model_in[0] = tile
63 | 
64 |             prob = sess.run(model, feed_dict={input_tensor: model_in})[0]
65 | 
66 |             col_prediction.append(prob)
67 | 
68 |         col_prediction = np.concatenate(col_prediction, axis=1)  # concatenate tiles along image width
69 |         row_prediction.append(col_prediction)
70 |     prob = np.concatenate(row_prediction, axis=0)
71 |     if CONFIG[ds]['zoom'] > 1:
72 |         prob = interp_map(prob, CONFIG[ds]['zoom'], image_size[1], image_size[0])
73 | 
74 |     prediction = np.argmax(prob, axis=2)
75 |     color_image = CONFIG[ds]['palette'][prediction.ravel()].reshape(image_size)
76 | 
77 |     return color_image
--------------------------------------------------------------------------------