├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── camvid.png
│   ├── camvid_out.png
│   ├── cityscapes.png
│   └── cityscapes_out.png
├── datasets.py
├── main_tf.py
├── model.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
.idea/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Andrea Palazzi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# dilation-tensorflow
A native *TensorFlow* implementation of semantic segmentation according to [Multi-Scale Context Aggregation by Dilated Convolutions](https://arxiv.org/abs/1511.07122) by Yu and Koltun.

Pretrained weights have been converted to TensorFlow from the [original Caffe implementation](https://github.com/fyu/dilation).

Pretrained models are available for both the Cityscapes and CamVid datasets.

If you're looking instead for a *Keras+Theano* implementation of this very same network, you can find it [here](https://github.com/DavideA/dilation-keras).

## Examples

### Cityscapes

| Test image (input) | Test image (prediction) |
|:---:|:---:|
| [`data/cityscapes.png`](data/cityscapes.png) | [`data/cityscapes_out.png`](data/cityscapes_out.png) |

### CamVid

| Test image (input) | Test image (prediction) |
|:---:|:---:|
| [`data/camvid.png`](data/camvid.png) | [`data/camvid_out.png`](data/camvid_out.png) |

## How-to
1. Download pretrained weights from here:

    [CityScapes weights](https://drive.google.com/open?id=0Bx9YaGcDPu3XR0d4cXVSWmtVdEE)

    [CamVid weights](https://drive.google.com/open?id=0Bx9YaGcDPu3Xd0JrcXZpTEpkb0U)

2. Move the weights file into the [`data`](data) directory.

3. Run the model on the test image by executing [`main_tf.py`](main_tf.py).

## Configuration

This model has been tested with the following configuration:
- Ubuntu 16.04
- Python 3.5.2
- TensorFlow 1.1.0
- OpenCV (`cv2`) 3.2.0

## Acknowledgements

Thanks to [DavideA](https://github.com/DavideA), who converted all weights from Caffe to Keras+Theano ([here](https://github.com/DavideA/dilation-keras)), making my effort of conversion towards TensorFlow much less painful than it could have been :-)
--------------------------------------------------------------------------------
/data/camvid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ndrplz/dilation-tensorflow/306a158a1defbd56d2db8529d95300936672c421/data/camvid.png
--------------------------------------------------------------------------------
/data/camvid_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ndrplz/dilation-tensorflow/306a158a1defbd56d2db8529d95300936672c421/data/camvid_out.png
--------------------------------------------------------------------------------
/data/cityscapes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ndrplz/dilation-tensorflow/306a158a1defbd56d2db8529d95300936672c421/data/cityscapes.png
--------------------------------------------------------------------------------
/data/cityscapes_out.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ndrplz/dilation-tensorflow/306a158a1defbd56d2db8529d95300936672c421/data/cityscapes_out.png
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
import numpy as np

# Configurations for different datasets
CONFIG = {
    'cityscapes': {
        'classes': 19,
        'weights_file': 'data/pretrained_dilation_cityscapes.pickle',
        'input_shape': (1396, 1396, 3),
        'output_shape': (1024, 1024, 19),
        'mean_pixel': (72.39, 82.91, 73.16),
        'palette': np.array([[128, 64, 128],
                             [244, 35, 232],
                             [70, 70, 70],
                             [102, 102, 156],
                             [190, 153, 153],
                             [153, 153, 153],
                             [250, 170, 30],
                             [220, 220, 0],
                             [107, 142, 35],
                             [152, 251, 152],
                             [70, 130, 180],
                             [220, 20, 60],
                             [255, 0, 0],
                             [0, 0, 142],
                             [0, 0, 70],
                             [0, 60, 100],
                             [0, 80, 100],
                             [0, 0, 230],
                             [119, 11, 32]], dtype='uint8'),
        'zoom': 1,
        'conv_margin': 186
    },
    'camvid': {
        'classes': 11,
        'weights_file': 'data/pretrained_dilation_camvid.pickle',
        'input_shape': (900, 1100, 3),
        'output_shape': (66, 91, 11),
        'mean_pixel': (110.70, 108.77, 105.41),
        'palette': np.array([[128, 0, 0],
                             [128, 128, 0],
                             [128, 128, 128],
                             [64, 0, 128],
                             [192, 128, 128],
                             [128, 64, 128],
                             [64, 64, 0],
                             [64, 64, 128],
                             [192, 192, 128],
                             [0, 0, 192],
                             [0, 128, 192]], dtype='uint8'),
        'zoom': 8,
        'conv_margin': 186
    }
}
--------------------------------------------------------------------------------
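The shape entries in `CONFIG` appear to be tied together by `conv_margin` and `zoom`: removing `conv_margin` pixels from each side of `input_shape` and dividing by `zoom` gives the spatial part of `output_shape`. The sketch below checks that relation with values copied from `datasets.py`. The helper name `expected_output_hw` and the `SHAPES` dictionary are hypothetical and not part of the repository.

```python
# Spatial values copied from CONFIG in datasets.py (channels omitted).
SHAPES = {
    'cityscapes': {'input_hw': (1396, 1396), 'output_hw': (1024, 1024),
                   'zoom': 1, 'conv_margin': 186},
    'camvid': {'input_hw': (900, 1100), 'output_hw': (66, 91),
               'zoom': 8, 'conv_margin': 186},
}


def expected_output_hw(input_hw, conv_margin, zoom):
    """Hypothetical helper: crop the margin on both sides, then divide by zoom."""
    return tuple((side - 2 * conv_margin) // zoom for side in input_hw)


for name, cfg in SHAPES.items():
    computed = expected_output_hw(cfg['input_hw'], cfg['conv_margin'], cfg['zoom'])
    # Prints the dataset name, the computed size, and whether it matches CONFIG.
    print(name, computed, computed == cfg['output_hw'])
```

For Cityscapes the factor-of-8 upsampling happens inside the network (see the `resize_bilinear` call in `model.py`), which is consistent with `zoom` being 1 there, while CamVid is upsampled afterwards by `interp_map` in `utils.py`.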
/main_tf.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import pickle
import cv2
import os
import os.path as path
from utils import predict
from model import dilation_model_pretrained
from datasets import CONFIG


if __name__ == '__main__':

    # Choose between 'cityscapes' and 'camvid'
    dataset = 'cityscapes'

    # Load dict of pretrained weights
    print('Loading pre-trained weights...')
    with open(CONFIG[dataset]['weights_file'], 'rb') as f:
        w_pretrained = pickle.load(f)
    print('Done.')

    # Create checkpoint directory
    checkpoint_dir = path.join('data/checkpoint', 'dilation_' + dataset)
    if not path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # Image in / out parameters
    input_image_path = path.join('data', dataset + '.png')
    output_image_path = path.join('data', dataset + '_out.png')

    # Build pretrained model and save it as TF checkpoint
    with tf.Session() as sess:

        # Choose input shape according to dataset characteristics
        input_h, input_w, input_c = CONFIG[dataset]['input_shape']
        input_tensor = tf.placeholder(tf.float32, shape=(None, input_h, input_w, input_c), name='input_placeholder')

        # Create pretrained model
        model = dilation_model_pretrained(dataset, input_tensor, w_pretrained, trainable=False)

        sess.run(tf.global_variables_initializer())

        # Save both graph and weights
        saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        saver.save(sess, path.join(checkpoint_dir, 'dilation'))

    # Restore both graph and weights from TF checkpoint
    with tf.Session() as sess:

        saver = tf.train.import_meta_graph(path.join(checkpoint_dir, 'dilation.meta'))
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))

        graph = tf.get_default_graph()
        model = graph.get_tensor_by_name('softmax:0')
        model = tf.reshape(model, shape=(1,)+CONFIG[dataset]['output_shape'])

        # Read and predict on a test image
        input_image = cv2.imread(input_image_path)
        input_tensor = graph.get_tensor_by_name('input_placeholder:0')
        predicted_image = predict(input_image, input_tensor, model, dataset, sess)

        # Convert colorspace and save prediction result: the palette is in RGB,
        # while cv2.imwrite expects BGR (the channel swap is the same in both directions)
        predicted_image = cv2.cvtColor(predicted_image, cv2.COLOR_BGR2RGB)
        cv2.imwrite(output_image_path, predicted_image)
--------------------------------------------------------------------------------
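`main_tf.py` expects the weights pickle to hold a dictionary, and `model.py` looks up entries named `<layer>/kernel:0` and `<layer>/bias:0`. Before building the graph it can help to confirm that those keys are present. The following is a minimal sketch using only the standard library. The helper `missing_weight_keys`, the toy dictionary, and the temporary file name are illustrative assumptions, not part of the repository.

```python
import os
import pickle
import tempfile


def missing_weight_keys(weights, layer_names, biasless=()):
    """Hypothetical helper: list the keys model.py would look up but the dict lacks."""
    missing = []
    for layer in layer_names:
        if layer + '/kernel:0' not in weights:
            missing.append(layer + '/kernel:0')
        if layer not in biasless and layer + '/bias:0' not in weights:
            missing.append(layer + '/bias:0')
    return missing


# Toy stand-in for the real pickle; the values are placeholders, not real weights.
toy_weights = {
    'conv1_1/kernel:0': [0.0],
    'conv1_1/bias:0': [0.0],
    'ctx_upsample/kernel:0': [0.0],
}

# Round-trip through a pickle file, the same way main_tf.py reads the weights.
with tempfile.TemporaryDirectory() as tmp_dir:
    weights_path = os.path.join(tmp_dir, 'toy_weights.pickle')
    with open(weights_path, 'wb') as f:
        pickle.dump(toy_weights, f)
    with open(weights_path, 'rb') as f:
        loaded = pickle.load(f)

# 'ctx_upsample' is built with add_bias=False in model.py, so no bias key is expected.
report = missing_weight_keys(loaded, ['conv1_1', 'conv1_2', 'ctx_upsample'],
                             biasless=('ctx_upsample',))
print(report)
```

With the real weights file, `loaded` would come from `CONFIG[dataset]['weights_file']` and `layer_names` would list the layer names used in `model.py`.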
/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def dilation_model_pretrained(dataset, input_tensor, w_pretrained, trainable):

    def conv(name, input, strides, padding, add_bias, apply_relu, atrous_rate=None):
        """
        Helper function for loading convolution weights from weight dictionary.
        """
        with tf.variable_scope(name):

            # Load kernel weights and apply convolution
            w_kernel = w_pretrained[name + '/kernel:0']
            w_kernel = tf.Variable(initial_value=w_kernel, trainable=trainable)

            if not atrous_rate:
                conv_out = tf.nn.conv2d(input, w_kernel, strides, padding)
            else:
                conv_out = tf.nn.atrous_conv2d(input, w_kernel, atrous_rate, padding)
            if add_bias:
                # Load bias values and add them to conv output
                w_bias = w_pretrained[name + '/bias:0']
                w_bias = tf.Variable(initial_value=w_bias, trainable=trainable)
                conv_out = tf.nn.bias_add(conv_out, w_bias)

            if apply_relu:
                # Apply ReLu nonlinearity
                conv_out = tf.nn.relu(conv_out)

        return conv_out

    # Sanity check on dataset name
    if dataset not in ['cityscapes', 'camvid']:
        raise ValueError('Dataset "{}" not supported.'.format(dataset))

    # Start building the model
    else:

        h = conv('conv1_1', input_tensor, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = conv('conv1_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = tf.layers.max_pooling2d(h, pool_size=(2, 2), strides=(2, 2), padding='valid', name='pool1')

        h = conv('conv2_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = conv('conv2_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = tf.layers.max_pooling2d(h, pool_size=(2, 2), strides=(2, 2), padding='valid', name='pool2')

        h = conv('conv3_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = conv('conv3_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = conv('conv3_3', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = tf.layers.max_pooling2d(h, pool_size=(2, 2), strides=(2, 2), padding='valid', name='pool3')

        h = conv('conv4_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = conv('conv4_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = conv('conv4_3', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)

        h = conv('conv5_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=2)
        h = conv('conv5_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=2)
        h = conv('conv5_3', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=2)
        h = conv('fc6', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=4)

        h = tf.layers.dropout(h, rate=0.5, name='drop6')
        h = conv('fc7', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = tf.layers.dropout(h, rate=0.5, name='drop7')
        h = conv('final', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)

        h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT', name='ctx_pad1_1')
        h = conv('ctx_conv1_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT', name='ctx_pad1_2')
        h = conv('ctx_conv1_2', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)

        h = tf.pad(h, [[0, 0], [2, 2], [2, 2], [0, 0]], mode='CONSTANT', name='ctx_pad2_1')
        h = conv('ctx_conv2_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=2)

        h = tf.pad(h, [[0, 0], [4, 4], [4, 4], [0, 0]], mode='CONSTANT', name='ctx_pad3_1')
        h = conv('ctx_conv3_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=4)

        h = tf.pad(h, [[0, 0], [8, 8], [8, 8], [0, 0]], mode='CONSTANT', name='ctx_pad4_1')
        h = conv('ctx_conv4_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=8)

        h = tf.pad(h, [[0, 0], [16, 16], [16, 16], [0, 0]], mode='CONSTANT', name='ctx_pad5_1')
        h = conv('ctx_conv5_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=16)

        if dataset == 'cityscapes':
            h = tf.pad(h, [[0, 0], [32, 32], [32, 32], [0, 0]], mode='CONSTANT', name='ctx_pad6_1')
            h = conv('ctx_conv6_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=32)

            h = tf.pad(h, [[0, 0], [64, 64], [64, 64], [0, 0]], mode='CONSTANT', name='ctx_pad7_1')
            h = conv('ctx_conv7_1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True, atrous_rate=64)

        h = tf.pad(h, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT', name='ctx_pad_fc1')
        h = conv('ctx_fc1', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=True)
        h = conv('ctx_final', h, strides=[1, 1, 1, 1], padding='VALID', add_bias=True, apply_relu=False)

        if dataset == 'cityscapes':
            h = tf.image.resize_bilinear(h, size=(1024, 1024))
            logits = conv('ctx_upsample', h, strides=[1, 1, 1, 1], padding='SAME', add_bias=False, apply_relu=True)
        else:
            logits = h

        softmax = tf.nn.softmax(logits, dim=3, name='softmax')

        return softmax
--------------------------------------------------------------------------------
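Because every front-end convolution in `model.py` uses `'VALID'` padding, the feature map shrinks at each layer, while the context module pads by the dilation rate before each dilated convolution and so keeps the size fixed. The sketch below traces one spatial dimension through the network. The kernel sizes are not visible in `model.py` (they come from the weights file), so 3x3 kernels for `conv1_1` to `conv5_3` and the `ctx_conv*` layers, a 7x7 kernel for `fc6`, and 1x1 kernels for `fc7` and `final` are assumptions here, and the function names are hypothetical.

```python
def valid_conv(size, kernel, rate=1):
    """Output length of a 'VALID' convolution with the given kernel size and dilation rate."""
    effective = kernel + (kernel - 1) * (rate - 1)
    return size - effective + 1


def front_end_size(size):
    """Trace one spatial dimension through the front end up to fc6."""
    # Assumed kernel sizes: 3x3 for conv1_1..conv5_3, 7x7 for fc6.
    for convs_in_block in (2, 2, 3):      # conv1_*, conv2_*, conv3_*, each followed by a pool
        for _ in range(convs_in_block):
            size = valid_conv(size, 3)
        size //= 2                        # 2x2 max pooling, stride 2, 'valid'
    for _ in range(3):                    # conv4_*: no pooling afterwards
        size = valid_conv(size, 3)
    for _ in range(3):                    # conv5_*: dilation rate 2
        size = valid_conv(size, 3, rate=2)
    return valid_conv(size, 7, rate=4)    # fc6: dilation rate 4


def padded_ctx_conv(size, rate):
    """A context-module layer: pad by `rate` on both sides, then a 3x3 'VALID' conv."""
    return valid_conv(size + 2 * rate, 3, rate=rate)


print(front_end_size(1396))                        # Cityscapes input side
print(front_end_size(900), front_end_size(1100))   # CamVid input height and width
print(padded_ctx_conv(128, 64))                    # size is unchanged by the padded layer
```

Under these assumptions the CamVid sizes agree with the `(66, 91)` entry of `output_shape` in `datasets.py`, and the Cityscapes front end yields a 128-pixel side, which `model.py` then resizes to 1024.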
/utils.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
from datasets import CONFIG
#import numba


# this function is the same as the one in the original repository
# basically it performs upsampling for datasets having zoom > 1
# @numba.jit(nopython=True)
def interp_map(prob, zoom, width, height):
    channels = prob.shape[2]
    zoom_prob = np.zeros((height, width, channels), dtype=np.float32)
    for c in range(channels):
        for h in range(height):
            for w in range(width):
                # Bilinear interpolation between the four neighbouring cells of `prob`
                r0 = h // zoom
                r1 = r0 + 1
                c0 = w // zoom
                c1 = c0 + 1
                rt = float(h) / zoom - r0
                ct = float(w) / zoom - c0
                v0 = rt * prob[r1, c0, c] + (1 - rt) * prob[r0, c0, c]
                v1 = rt * prob[r1, c1, c] + (1 - rt) * prob[r0, c1, c]
                zoom_prob[h, w, c] = (1 - ct) * v0 + ct * v1
    return zoom_prob


# predict function, mostly reported as it was in the original repo
def predict(image, input_tensor, model, ds, sess):

    image = image.astype(np.float32) - CONFIG[ds]['mean_pixel']
    conv_margin = CONFIG[ds]['conv_margin']

    input_dims = (1,) + CONFIG[ds]['input_shape']
    batch_size, input_height, input_width, num_channels = input_dims
    model_in = np.zeros(input_dims, dtype=np.float32)

    # Each tile covers (input - 2 * conv_margin) pixels of the original image
    image_size = image.shape
    output_height = input_height - 2 * conv_margin
    output_width = input_width - 2 * conv_margin
    image = cv2.copyMakeBorder(image, conv_margin, conv_margin,
                               conv_margin, conv_margin,
                               cv2.BORDER_REFLECT_101)

    # Number of tiles needed to cover the image (ceiling division)
    num_tiles_h = image_size[0] // output_height + (1 if image_size[0] % output_height else 0)
    num_tiles_w = image_size[1] // output_width + (1 if image_size[1] % output_width else 0)

    row_prediction = []
    for h in range(num_tiles_h):
        col_prediction = []
        for w in range(num_tiles_w):
            offset = [output_height * h,
                      output_width * w]
            tile = image[offset[0]:offset[0] + input_height,
                         offset[1]:offset[1] + input_width, :]
            # Pad tiles on the bottom / right edge up to the full input size
            margin = [0, input_height - tile.shape[0],
                      0, input_width - tile.shape[1]]
            tile = cv2.copyMakeBorder(tile, margin[0], margin[1],
                                      margin[2], margin[3],
                                      cv2.BORDER_REFLECT_101)

            model_in[0] = tile

            prob = sess.run(model, feed_dict={input_tensor: tile[None, ...]})[0]

            col_prediction.append(prob)

        col_prediction = np.concatenate(col_prediction, axis=1)  # previously axis=2
        row_prediction.append(col_prediction)
    prob = np.concatenate(row_prediction, axis=0)
    if CONFIG[ds]['zoom'] > 1:
        prob = interp_map(prob, CONFIG[ds]['zoom'], image_size[1], image_size[0])

    # Pick the most likely class per pixel and map it to its palette colour
    prediction = np.argmax(prob, axis=2)
    color_image = CONFIG[ds]['palette'][prediction.ravel()].reshape(image_size)

    return color_image
--------------------------------------------------------------------------------
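`predict` in `utils.py` covers the image with tiles whose stride is the input size minus twice `conv_margin`, adding one extra tile whenever the image is not an exact multiple of that stride. The sketch below reproduces only that bookkeeping with the standard library. The function name `tile_offsets` is hypothetical, and the image sizes in the example (1024x2048 and 720x960) are illustrative assumptions rather than values taken from the repository.

```python
def tile_offsets(image_h, image_w, tile_h, tile_w):
    """Top-left offsets of the tiles, in the order predict() visits them."""
    # Ceiling division: one extra tile whenever the image is not an exact multiple.
    num_tiles_h = -(-image_h // tile_h)
    num_tiles_w = -(-image_w // tile_w)
    return [(tile_h * h, tile_w * w)
            for h in range(num_tiles_h)
            for w in range(num_tiles_w)]


# Strides derived from datasets.py: input_shape minus 2 * conv_margin.
cityscapes_stride = (1396 - 2 * 186, 1396 - 2 * 186)   # (1024, 1024)
camvid_stride = (900 - 2 * 186, 1100 - 2 * 186)        # (528, 728)

# Assumed example image sizes (height, width).
print(tile_offsets(1024, 2048, *cityscapes_stride))
print(tile_offsets(720, 960, *camvid_stride))
```

Tiles that run past the bottom or right edge are the ones `predict` pads with `cv2.copyMakeBorder` before feeding them to the network.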