├── README.md ├── LICENSE ├── digit.py ├── tf_capsnet.py └── keras_capsnet.py /README.md: -------------------------------------------------------------------------------- 1 | # Capsule-Network 2 | 3 | [![License](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](https://github.com/zhanpenghe/Capsule-Network/blob/master/LICENSE) 4 | 5 | Keras and TensorFlow implementations of utilities for building capsule networks. 6 | 7 | ## TODO 8 | - add examples and run experiments 9 | - build networks with different structures from the capsule layers 10 | 11 | 12 | Reference: 13 | [Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017](https://arxiv.org/abs/1710.09829) 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Zhanpeng He 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /digit.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | The digit caps model in the paper. 4 | """ 5 | 6 | import numpy as np 7 | from keras.models import Sequential, Model 8 | from keras.datasets import mnist 9 | from keras.utils import to_categorical 10 | from keras.layers import Input, Conv2D, Dense, Reshape 11 | from keras_capsnet import margin_loss, conv2d_caps, DenseCapsule, Mask, CapsuleLength 12 | 13 | img_dim = (28, 28, 1) 14 | output_shape = 10 15 | 16 | 17 | def load_mnist_data(): 18 | 19 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 20 | x_train = x_train.reshape(-1, 28, 28, 1).astype('float32')/255.0 21 | x_test = x_test.reshape(-1, 28, 28, 1).astype('float32')/255.0 22 | y_train = to_categorical(y_train.astype('float32')) 23 | y_test = to_categorical(y_test.astype('float32')) 24 | 25 | return x_train, y_train, x_test, y_test 26 | 27 | 28 | def build_models(): 29 | global img_dim, output_shape 30 | 31 | input_layer = Input(shape=(img_dim[0], img_dim[1], img_dim[2])) 32 | conv1 = Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu')(input_layer) 33 | primary_caps = conv2d_caps( 34 | input_layer=conv1, 35 | nb_filters=32, 36 | capsule_size=8, 37 | kernel_size=9, 38 | strides=2 39 | ) 40 | digit_caps = DenseCapsule(capsule_size=16, nb_capsules=output_shape, name='densecaps')(primary_caps) 41 | output_layer = CapsuleLength(name='capsnet')(digit_caps) 42 | 43 | # Decoder part.. 
44 | true_labels_input = Input(shape=(output_shape,)) # use the true label to mask the capsule 45 | masked_by_true_label = Mask()([digit_caps, true_labels_input]) # for training process 46 | masked_by_max = Mask()(digit_caps) # for prediction process 47 | 48 | decoder = Sequential(name='decoder_out') 49 | decoder.add(Dense(512, activation='relu', input_dim=16*output_shape)) 50 | decoder.add(Dense(1024, activation='relu')) 51 | decoder.add(Dense(np.prod(img_dim), activation='sigmoid')) 52 | decoder.add(Reshape(target_shape=img_dim)) 53 | 54 | train_decoder = decoder(masked_by_true_label) 55 | eval_decoder = decoder(masked_by_max) 56 | 57 | train_model = Model( 58 | inputs=[input_layer, true_labels_input], 59 | outputs=[output_layer, train_decoder] 60 | ) 61 | 62 | eval_model = Model( 63 | inputs=input_layer, 64 | outputs=[eval_decoder, output_layer] 65 | ) 66 | 67 | return train_model, eval_model 68 | 69 | 70 | def main(): 71 | 72 | x_train, y_train, x_test, y_test = load_mnist_data() 73 | train_model, eval_model = build_models() 74 | print(train_model.summary()) 75 | train_model.compile(loss=[margin_loss, 'mse'], optimizer='adam') 76 | 77 | train_model.fit([x_train, y_train], [y_train, x_train], batch_size=128, epochs=5, validation_data=[[x_test, y_test], [y_test, x_test]]) 78 | 79 | 80 | if __name__ == '__main__': 81 | main() 82 | -------------------------------------------------------------------------------- /tf_capsnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | 5 | def squash(vector, epsilon=10e-5): 6 | 7 | vector += epsilon 8 | norm = tf.reduce_sum(tf.square(vector), -2, keepdims=True) 9 | scalar_factor = norm / (1 + norm) / tf.sqrt(norm) 10 | squashed = scalar_factor * vector 11 | return (squashed) 12 | 13 | 14 | # Convolution->Reshape 15 | def conv2d_caps(input_layer, nb_filters, kernel, capsule_size, strides=2): 16 | 17 | conv = tf.layers.conv2d( 18 | 
inputs=input_layer, 19 | filters=nb_filters*capsule_size, 20 | kernel_size=kernel, 21 | strides=strides, 22 | padding='valid' 23 | ) 24 | shape = conv.get_shape().as_list() 25 | capsules = tf.reshape(conv, shape=[-1, np.prod(shape[1:3]) * nb_filters, capsule_size, 1]) 26 | return squash(capsules) 27 | 28 | def dynamic_routing(u_hat, b_ij, nb_capsules, prev_nb_capsules, iterations=5): 29 | """ 30 | The dynamic routing algorithm from paper: https://arxiv.org/pdf/1710.09829.pdf 31 | """ 32 | for i in range(iterations): 33 | with tf.variable_scope('routing'+str(i)): 34 | 35 | c_ij = tf.nn.softmax(b_ij, axis=2) 36 | s_j =tf.reduce_sum(tf.multiply(c_ij, u_hat), axis=1, keepdims=True) 37 | v_j = squash(s_j) 38 | 39 | if i < iterations-1: 40 | b_ij = b_ij + tf.reduce_sum( 41 | tf.matmul(u_hat, tf.tile(v_j, [1, prev_nb_capsules, 1, 1, 1]), transpose_a=True), 42 | axis=0, 43 | keepdims=True 44 | ) 45 | 46 | return tf.squeeze(v_j, axis=1) 47 | 48 | 49 | def dense_capsule(input_layer, capsule_size, nb_capsules, iterations=5): 50 | """ 51 | Take the output from a layer of capsules and perform the following computations: 52 | ...... 
53 | """ 54 | prev_shape = input_layer.get_shape().as_list() 55 | init = tf.random_normal_initializer(stddev=0.01, seed=0) 56 | w_shape = [prev_shape[1], nb_capsules, capsule_size, prev_shape[2]] 57 | w_ij = tf.get_variable('weight', shape=w_shape, dtype=tf.float32, initializer=init) 58 | 59 | 60 | # Expand dimension to allow the dot product 61 | # Dimension change from [None, prev_nb_capsules, prev_capsule_size, 1] 62 | # to [None, prev_nb_capsules, prev_capsule_size, nb_capsules, 1] 63 | expanded_input_layer = tf.expand_dims(input=input_layer, axis=2) 64 | 65 | # Make nb_capsule of copies of previous capsules to proform multiplication 66 | expanded_input_layer = tf.tile(expanded_input_layer, [1, 1, nb_capsules, 1, 1]) 67 | 68 | u_hat = tf.einsum('abdc,iabcf->iabdf', w_ij, expanded_input_layer) 69 | b_ij = tf.zeros(shape=[prev_shape[1], nb_capsules, 1, 1], dtype=np.float32) 70 | 71 | return dynamic_routing(u_hat, b_ij, nb_capsules, prev_shape[1], iterations) 72 | 73 | def test(): 74 | # Some testing.. 75 | input_layer = tf.placeholder(tf.float32, shape=[None, 28, 28, 1]) 76 | caps1 = conv2d_caps(input_layer, 64, [3, 3], 8) 77 | print(caps1.shape) 78 | caps2 = dense_capsule(caps1, 4, 100) 79 | print(caps2.shape) 80 | 81 | if __name__ == '__main__': 82 | test() -------------------------------------------------------------------------------- /keras_capsnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import keras.backend as K 3 | from keras.engine.topology import Layer 4 | from keras.layers import Conv2D, Reshape, Lambda 5 | from keras.activations import softmax 6 | 7 | 8 | def margin_loss(y_true, y_pred, 9 | m_plus=0.9, m_minus=0.1, down_weighting=0.5): 10 | """ 11 | The margin loss defined in the paper. 12 | The default parameters are those used in the paper. 
13 | """ 14 | L = y_true*K.square(K.maximum(0.0, m_plus-y_pred)) + down_weighting*(1-y_true)*K.square(K.maximum(0.0, y_pred-m_minus)) 15 | 16 | return K.mean(K.sum(L, 1)) 17 | 18 | 19 | def squash(vector, epsilon=K.epsilon()): 20 | 21 | vector += epsilon 22 | norm = K.sum(K.square(vector), -1, keepdims=True) 23 | scalar_factor = norm / (1 + norm) / K.sqrt(norm) 24 | squashed = scalar_factor * vector 25 | return squashed 26 | 27 | 28 | def conv2d_caps(input_layer, nb_filters, kernel_size, capsule_size, strides=2): 29 | 30 | conv = Conv2D( 31 | filters=nb_filters*capsule_size, 32 | kernel_size=kernel_size, 33 | strides=strides, 34 | padding='valid' 35 | )(input_layer) 36 | 37 | conv_shape = conv.shape 38 | nb_capsules= int(conv_shape[1]*conv_shape[2]*nb_filters) 39 | 40 | capsules = Reshape(target_shape=(nb_capsules, capsule_size))(conv) 41 | 42 | return Lambda(squash, name='primarycap_squash')(capsules) 43 | 44 | 45 | class CapsuleLength(Layer): 46 | 47 | def call(self, inputs, **kwargs): 48 | input_shape = inputs.get_shape().as_list() 49 | x = K.reshape(inputs, shape=[-1, input_shape[1], input_shape[2]]) 50 | return K.sqrt(K.sum(K.square(x), -1)) 51 | 52 | def compute_output_shape(self, input_shape): 53 | return input_shape[:-2] 54 | 55 | 56 | class Mask(Layer): 57 | """ 58 | A mask layer for decoder to minimize the marginal loss. 
59 | """ 60 | def call(self, inputs): 61 | 62 | if type(inputs) is list: 63 | 64 | assert len(inputs) == 2 65 | inputs, mask = inputs[0], inputs[1] 66 | 67 | assert mask.get_shape().as_list()[1] == inputs.get_shape().as_list()[1] 68 | 69 | else: 70 | length = K.sqrt(K.sum(K.square(inputs), axis=-1)) 71 | mask = K.one_hot( 72 | indices=K.argmax(length, 1), 73 | num_classes=inputs.get_shape().as_list()[1] 74 | ) 75 | 76 | mask = K.expand_dims(mask, -1) 77 | 78 | # [None, nb_classes, 1] 79 | masked = K.batch_flatten(inputs*mask) 80 | return masked 81 | 82 | def compute_output_shape(self, input_shape): 83 | 84 | if type(input_shape[0]) is tuple: 85 | return tuple([None, input_shape[0][1]]) 86 | else: 87 | return tuple([None, input_shape[1]]) 88 | 89 | 90 | class DenseCapsule(Layer): 91 | 92 | """ 93 | A fully connected capsule layer which is similar to 94 | the dense layer but replace the neurons to capsules 95 | """ 96 | 97 | def __init__(self, capsule_size, nb_capsules, kernel_initializer='glorot_uniform', iterations=5, **kwargs): 98 | super(DenseCapsule, self).__init__(**kwargs) 99 | self.nb_capsules = nb_capsules 100 | self.iterations = iterations 101 | self.capsule_size = capsule_size 102 | self.initializer = kernel_initializer 103 | 104 | def build(self, input_shape): 105 | self.prev_shape = input_shape 106 | self.w_ij = self.add_weight( 107 | name='w_ij', 108 | shape=( self.nb_capsules, input_shape[1], self.capsule_size, input_shape[2]), 109 | initializer=self.initializer 110 | ) 111 | self.built = True 112 | 113 | def batch_dot(self, X, w, axis): 114 | return K.map_fn(lambda x: K.batch_dot(x, w, axis), elems=X) 115 | 116 | def _dynamic_routing(self, u_hat, b_ij): 117 | 118 | for i in range(self.iterations): 119 | 120 | c_ij = softmax(b_ij, axis=1) 121 | s_j = K.batch_dot(c_ij, u_hat, [2, 2]) 122 | v_j = squash(s_j) 123 | 124 | if i < self.iterations-1: 125 | b_ij += K.batch_dot(v_j, u_hat, [2, 3]) 126 | # b_ij = b_ij + K.batch_dot(K.tile(v_j, [1, 
self.prev_shape[1], 1, 1, 1]), u_hat, [3, 4]) 127 | 128 | # return K.squeeze(v_j, axis=1) 129 | return v_j 130 | 131 | def call(self, inputs): 132 | 133 | expanded_input = K.expand_dims(inputs, 1) 134 | print(expanded_input.shape) 135 | expanded_input = K.tile(expanded_input, [1, self.nb_capsules, 1, 1]) 136 | print(expanded_input.shape) 137 | u_hat = K.map_fn(lambda x: K.batch_dot(x, self.w_ij, [2, 3]), elems=expanded_input) 138 | 139 | b_ij = K.zeros(shape=[K.shape(u_hat)[0], self.nb_capsules, self.prev_shape[1]], dtype=np.float32) 140 | 141 | return self._dynamic_routing(u_hat, b_ij) 142 | 143 | def compute_output_shape(self, input_shape): 144 | return tuple([None, self.nb_capsules, self.capsule_size, 1]) 145 | 146 | 147 | def test(): 148 | from keras.layers import Input 149 | input_layer = Input(shape=(28, 28, 1)) 150 | caps1 = conv2d_caps(input_layer, 64, (3, 3), 8) 151 | print(caps1.shape) 152 | caps2 = DenseCapsule(4, 100)(caps1) 153 | print(caps2.shape) 154 | 155 | if __name__ == '__main__': 156 | test() --------------------------------------------------------------------------------