├── .gitignore ├── Layers.py ├── README.md ├── VNet.py └── VNetDiagram.png /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | .static_storage/ 56 | .media/ 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | .idea/* 107 | 108 | debug.py -------------------------------------------------------------------------------- /Layers.py: 
-------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Miguel Monteiro 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import tensorflow as tf 24 | import numpy as np 25 | 26 | 27 | def xavier_initializer_convolution(shape, dist='uniform', lambda_initializer=True): 28 | """ 29 | Xavier initializer for N-D convolution patches. input_activations = patch_volume * in_channels; 30 | output_activations = patch_volume * out_channels; Uniform: lim = sqrt(3/(input_activations + output_activations)) 31 | Normal: stddev = sqrt(6/(input_activations + output_activations)) 32 | :param shape: The shape of the convolution patch i.e. spatial_shape + [input_channels, output_channels]. The order of 33 | input_channels and output_channels is irrelevant, hence this can be used to initialize deconvolution parameters. 
34 | :param dist: A string either 'uniform' or 'normal' determining the type of distribution 35 | :param lambda_initializer: Whether to return the initial actual values of the parameters (True) or placeholders that 36 | are initialized when the session is initiated 37 | :return: A numpy araray with the initial values for the parameters in the patch 38 | """ 39 | s = len(shape) - 2 40 | num_activations = np.prod(shape[:s]) * np.sum(shape[s:]) # input_activations + output_activations 41 | if dist == 'uniform': 42 | lim = np.sqrt(6. / num_activations) 43 | if lambda_initializer: 44 | return np.random.uniform(-lim, lim, shape).astype(np.float32) 45 | else: 46 | return tf.random_uniform(shape, minval=-lim, maxval=lim) 47 | if dist == 'normal': 48 | stddev = np.sqrt(3. / num_activations) 49 | if lambda_initializer: 50 | return np.random.normal(0, stddev, shape).astype(np.float32) 51 | else: 52 | tf.truncated_normal(shape, mean=0, stddev=stddev) 53 | raise ValueError('Distribution must be either "uniform" or "normal".') 54 | 55 | 56 | def constant_initializer(value, shape, lambda_initializer=True): 57 | if lambda_initializer: 58 | return np.full(shape, value).astype(np.float32) 59 | else: 60 | return tf.constant(value, tf.float32, shape) 61 | 62 | 63 | def get_spatial_rank(x): 64 | """ 65 | 66 | :param x: an input tensor with shape [batch_size, ..., num_channels] 67 | :return: the spatial rank of the tensor i.e. the number of spatial dimensions between batch_size and num_channels 68 | """ 69 | return len(x.get_shape()) - 2 70 | 71 | 72 | def get_num_channels(x): 73 | """ 74 | 75 | :param x: an input tensor with shape [batch_size, ..., num_channels] 76 | :return: the number of channels of x 77 | """ 78 | return int(x.get_shape()[-1]) 79 | 80 | 81 | def get_spatial_size(x): 82 | """ 83 | 84 | :param x: an input tensor with shape [batch_size, ..., num_channels] 85 | :return: The spatial shape of x, excluding batch_size and num_channels. 
86 | """ 87 | return x.get_shape()[1:-1] 88 | 89 | 90 | # parametric leaky relu 91 | def prelu(x): 92 | alpha = tf.get_variable('alpha', shape=x.get_shape()[-1], dtype=x.dtype, initializer=tf.constant_initializer(0.1)) 93 | return tf.maximum(0.0, x) + alpha * tf.minimum(0.0, x) 94 | 95 | 96 | def convolution(x, filter, padding='SAME', strides=None, dilation_rate=None): 97 | w = tf.get_variable(name='weights', initializer=xavier_initializer_convolution(shape=filter)) 98 | b = tf.get_variable(name='biases', initializer=constant_initializer(0, shape=filter[-1])) 99 | 100 | return tf.nn.convolution(x, w, padding, strides, dilation_rate) + b 101 | 102 | 103 | def deconvolution(x, filter, output_shape, strides, padding='SAME'): 104 | w = tf.get_variable(name='weights', initializer=xavier_initializer_convolution(shape=filter)) 105 | b = tf.get_variable(name='biases', initializer=constant_initializer(0, shape=filter[-2])) 106 | 107 | spatial_rank = get_spatial_rank(x) 108 | if spatial_rank == 2: 109 | return tf.nn.conv2d_transpose(x, filter, output_shape, strides, padding) + b 110 | if spatial_rank == 3: 111 | return tf.nn.conv3d_transpose(x, w, output_shape, strides, padding) + b 112 | raise ValueError('Only 2D and 3D images supported.') 113 | 114 | 115 | # More complex blocks 116 | 117 | # down convolution 118 | def down_convolution(x, factor, kernel_size): 119 | num_channels = get_num_channels(x) 120 | spatial_rank = get_spatial_rank(x) 121 | strides = spatial_rank * [factor] 122 | filter = kernel_size + [num_channels, num_channels * factor] 123 | x = convolution(x, filter, strides=strides) 124 | return x 125 | 126 | 127 | # up convolution 128 | def up_convolution(x, output_shape, factor, kernel_size): 129 | num_channels = get_num_channels(x) 130 | spatial_rank = get_spatial_rank(x) 131 | strides = [1] + spatial_rank * [factor] + [1] 132 | filter = kernel_size + [num_channels // factor, num_channels] 133 | x = deconvolution(x, filter, output_shape, strides=strides) 134 
| return x 135 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Tensorflow implementation of V-Net 2 | 3 | This is a Tensorflow implementation of the ["V-Net"](https://arxiv.org/abs/1606.04797) architecture used for 3D medical 4 | imaging segmentation. 5 | This code only implements the Tensorflow graph, it must be used within a training program. 6 | 7 | ### Visual Representation of the Network 8 | 9 | This is an example of a network this code implements. 10 | 11 | ![VNetDiagram](VNetDiagram.png) 12 | 13 | ### Example Usage 14 | 15 | ``` 16 | from VNet import VNet 17 | 18 | input_channels = 6 19 | num_classes = 1 20 | 21 | tf_input = tf.placeholder(dtype=tf.float32, shape=(10, 190, 190, 20, input_channels)) 22 | 23 | model = VNet(num_classes=num_classes, keep_prob=.7) 24 | 25 | logits = model.network_fn(tf_input, is_training=True) 26 | 27 | ``` 28 | 29 | `logits` will have shape `[10, 190, 190, 20, 1]`, it can the be flattened and used in the sigmoid cross entropy function. 30 | 31 | 32 | ### How to use 33 | 34 | 1. Instantiate a `VNet` class. The only mandatory argument is the number of output classes/channels. 35 | The default arguments of the class implement the network as in the paper. However, the implementation is flexible and by 36 | looking at the `VNet` class docstring you can change the network architecture. 37 | 38 | 2. Call the method `network_fn` to get the output of the network. The input of `network_fn` is a tensor with shape 39 | `[batch_size, x, y, z, ..., input_channels]` which can have as many spatial dimensions as wanted. The output of 40 | `network_fn` will have shape `[batch_size, x, y, z, ..., num_classes]`. 41 | 42 | 43 | ##### Notes 44 | 45 | In a binary segmentation problem you could use `num_classes=1` with a sigmoid loss and in a three class 46 | problem you could use `num_classes=3` with a softmax loss. 
47 | 48 | 49 | ### Implementation details 50 | 51 | The `VNet` class with default parameters implements the network as in the 52 | [original paper](https://arxiv.org/abs/1606.04797), but with a bit more flexibility in the number of input and output 53 | channels: 54 | 1. The input can have more than one channel. If the input has more than one channel, then one more convolution is added 55 | at the input to increase the number of input channels to `num_channels`. If the input has only one channel, then 56 | it is tiled (repeated `num_channels` times) along the channel axis, so it can be used directly in the first skip connection. 57 | 2. The output does not need to have two channels like in the original architecture. 58 | 59 | The `VNet` class can be instantiated with the following arguments: 60 | 61 | * `num_classes`: Number of output classes. 62 | * `keep_prob`: Dropout keep probability used when `network_fn` is called with `is_training=True` (dropout is always disabled when `is_training=False`); set to 1.0 if no dropout is desired. 63 | * `num_channels`: The number of output channels in the first level; this will be doubled every level. 64 | * `num_levels`: The number of levels in the encoder and decoder of the network. Default is 4 as in the paper. 65 | * `num_convolutions`: An array with the number of convolutions at each level, i.e. if `num_convolutions = (1, 3, 4, 5)` 66 | then the third level of the encoder and decoder will have 4 convolutions. Its length must equal `num_levels`. 67 | * `bottom_convolutions`: The number of convolutions at the bottom level of the network. Must be given separately because 68 | of the odd symmetry of the network. 69 | * `activation_fn`: The activation function. Defaults to `tf.nn.relu`; a parametric ReLU (`prelu`) is also provided in `Layers.py`.
70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /VNet.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 Miguel Monteiro 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | import tensorflow as tf 24 | from Layers import convolution, down_convolution, up_convolution, get_num_channels 25 | 26 | 27 | def convolution_block(layer_input, num_convolutions, keep_prob, activation_fn): 28 | x = layer_input 29 | n_channels = get_num_channels(x) 30 | for i in range(num_convolutions): 31 | with tf.variable_scope('conv_' + str(i+1)): 32 | x = convolution(x, [5, 5, 5, n_channels, n_channels]) 33 | if i == num_convolutions - 1: 34 | x = x + layer_input 35 | x = activation_fn(x) 36 | x = tf.nn.dropout(x, keep_prob) 37 | return x 38 | 39 | 40 | def convolution_block_2(layer_input, fine_grained_features, num_convolutions, keep_prob, activation_fn): 41 | 42 | x = tf.concat((layer_input, fine_grained_features), axis=-1) 43 | n_channels = get_num_channels(layer_input) 44 | if num_convolutions == 1: 45 | with tf.variable_scope('conv_' + str(1)): 46 | x = convolution(x, [5, 5, 5, n_channels * 2, n_channels]) 47 | x = x + layer_input 48 | x = activation_fn(x) 49 | x = tf.nn.dropout(x, keep_prob) 50 | return x 51 | 52 | with tf.variable_scope('conv_' + str(1)): 53 | x = convolution(x, [5, 5, 5, n_channels * 2, n_channels]) 54 | x = activation_fn(x) 55 | x = tf.nn.dropout(x, keep_prob) 56 | 57 | for i in range(1, num_convolutions): 58 | with tf.variable_scope('conv_' + str(i+1)): 59 | x = convolution(x, [5, 5, 5, n_channels, n_channels]) 60 | if i == num_convolutions - 1: 61 | x = x + layer_input 62 | x = activation_fn(x) 63 | x = tf.nn.dropout(x, keep_prob) 64 | 65 | return x 66 | 67 | 68 | class VNet(object): 69 | def __init__(self, 70 | num_classes, 71 | keep_prob=1.0, 72 | num_channels=16, 73 | num_levels=4, 74 | num_convolutions=(1, 2, 3, 3), 75 | bottom_convolutions=3, 76 | activation_fn=tf.nn.relu): 77 | """ 78 | Implements VNet architecture https://arxiv.org/abs/1606.04797 79 | :param num_classes: Number of output classes. 80 | :param keep_prob: Dropout keep probability, set to 1.0 if not training or if no dropout is desired. 
81 | :param num_channels: The number of output channels in the first level, this will be doubled every level. 82 | :param num_levels: The number of levels in the network. Default is 4 as in the paper. 83 | :param num_convolutions: An array with the number of convolutions at each level. 84 | :param bottom_convolutions: The number of convolutions at the bottom level of the network. 85 | :param activation_fn: The activation function. 86 | """ 87 | self.num_classes = num_classes 88 | self.keep_prob = keep_prob 89 | self.num_channels = num_channels 90 | assert num_levels == len(num_convolutions) 91 | self.num_levels = num_levels 92 | self.num_convolutions = num_convolutions 93 | self.bottom_convolutions = bottom_convolutions 94 | self.activation_fn = activation_fn 95 | 96 | def network_fn(self, x, is_training): 97 | 98 | keep_prob = self.keep_prob if is_training else 1.0 99 | # if the input has more than 1 channel it has to be expanded because broadcasting only works for 1 input 100 | # channel 101 | input_channels = int(x.get_shape()[-1]) 102 | with tf.variable_scope('vnet/input_layer'): 103 | if input_channels == 1: 104 | x = tf.tile(x, [1, 1, 1, 1, self.num_channels]) 105 | else: 106 | x = self.activation_fn(convolution(x, [5, 5, 5, input_channels, self.num_channels])) 107 | 108 | features = list() 109 | 110 | for l in range(self.num_levels): 111 | with tf.variable_scope('vnet/encoder/level_' + str(l + 1)): 112 | x = convolution_block(x, self.num_convolutions[l], keep_prob, activation_fn=self.activation_fn) 113 | features.append(x) 114 | with tf.variable_scope('down_convolution'): 115 | x = self.activation_fn(down_convolution(x, factor=2, kernel_size=[2, 2, 2])) 116 | 117 | with tf.variable_scope('vnet/bottom_level'): 118 | x = convolution_block(x, self.bottom_convolutions, keep_prob, activation_fn=self.activation_fn) 119 | 120 | for l in reversed(range(self.num_levels)): 121 | with tf.variable_scope('vnet/decoder/level_' + str(l + 1)): 122 | f = features[l] 123 | 
with tf.variable_scope('up_convolution'): 124 | x = self.activation_fn(up_convolution(x, tf.shape(f), factor=2, kernel_size=[2, 2, 2])) 125 | 126 | x = convolution_block_2(x, f, self.num_convolutions[l], keep_prob, activation_fn=self.activation_fn) 127 | 128 | with tf.variable_scope('vnet/output_layer'): 129 | logits = convolution(x, [1, 1, 1, self.num_channels, self.num_classes]) 130 | 131 | return logits 132 | 133 | 134 | -------------------------------------------------------------------------------- /VNetDiagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiguelMonteiro/VNet-Tensorflow/e102e2c950f56658b028cfc65394a00800498eb3/VNetDiagram.png --------------------------------------------------------------------------------