├── LICENSE
├── .gitignore
├── README.md
├── CapsuleNet.py
├── main.py
└── CapsuleLayer.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Songyang Zhang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CapsuleNet-Gluon

Implementation of CapsuleNet from the paper [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829), written with MXNet Gluon.

# Run

1. `pip install --pre mxnet-cu80 -i https://pypi.douban.com/simple --user` (the CUDA 8.0 build; pick the `mxnet-cu*` package that matches your CUDA version)
2. `pip install tqdm`
3. `python main.py`
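
After installation, a forward pass on random data is a cheap sanity check. This is a minimal sketch (the CPU context and the dummy batch are my own example, not part of the repo):

```python
import mxnet as mx
from mxnet import nd, init
from CapsuleNet import CapsuleNet

ctx = mx.cpu()  # or mx.gpu(0)
net = CapsuleNet()
net.initialize(ctx=ctx, init=init.Xavier())

x = nd.random_uniform(shape=(2, 1, 28, 28), ctx=ctx)  # fake MNIST batch
y = nd.array([3, 7], ctx=ctx)                         # labels used for reconstruction
prob, lengths, recon = net(x, y)
print(prob.shape, lengths.shape, recon.shape)         # (2, 10) (2, 10) (2, 784)
```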
# Results

Results to be added.

# Issues

I train CapsuleNet with `SGD`; when I use `Adam` as described in the original paper, I get very low test accuracy.

I also find that a large training batch size gives unacceptable results; I haven't figured out why yet.
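
To reproduce the `Adam` behaviour, swapping the optimizer string in `main.py`'s `gluon.Trainer` call is enough (a sketch; the learning rate here is only an example value):

```python
trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': 0.001})
```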
Any PR is welcome.

# Other Implementations

- Kaggle (this version as a self-contained notebook):
  - [MNIST Dataset](https://www.kaggle.com/kmader/capsulenet-on-mnist) running on the standard MNIST and predicting for test data
  - [MNIST Fashion](https://www.kaggle.com/kmader/capsulenet-on-fashion-mnist) running on the more challenging Fashion images
- TensorFlow:
  - [naturomics/CapsNet-Tensorflow](https://github.com/naturomics/CapsNet-Tensorflow.git)
    Very good implementation; I referred to this repository in my code.
  - [InnerPeace-Wu/CapsNet-tensorflow](https://github.com/InnerPeace-Wu/CapsNet-tensorflow)
    I referred to its use of `tf.scan` when optimizing my CapsuleLayer.
  - [LaoDar/tf_CapsNet_simple](https://github.com/LaoDar/tf_CapsNet_simple)
- PyTorch:
  - [tonysy/CapsuleNet-PyTorch](https://github.com/tonysy/CapsuleNet-PyTorch.git)
  - [timomernick/pytorch-capsule](https://github.com/timomernick/pytorch-capsule)
  - [gram-ai/capsule-networks](https://github.com/gram-ai/capsule-networks)
  - [andreaazzini/capsnet.pytorch](https://github.com/andreaazzini/capsnet.pytorch.git)
  - [leftthomas/CapsNet](https://github.com/leftthomas/CapsNet)
- MXNet:
  - [AaronLeong/CapsNet_Mxnet](https://github.com/AaronLeong/CapsNet_Mxnet)
- Lasagne (Theano):
  - [DeniskaMazur/CapsNet-Lasagne](https://github.com/DeniskaMazur/CapsNet-Lasagne)
- Chainer:
  - [soskek/dynamic_routing_between_capsules](https://github.com/soskek/dynamic_routing_between_capsules)
--------------------------------------------------------------------------------
/CapsuleNet.py:
--------------------------------------------------------------------------------
from mxnet import nd
from mxnet.gluon import nn
from mxnet.gluon.loss import Loss, _apply_weighting
from CapsuleLayer import CapsuleConv, CapsuleDense

class CapsuleNet(nn.HybridBlock):
    """CapsuleNet for MNIST: Conv1 -> PrimaryCaps -> DigitCaps, plus a
    fully connected decoder that reconstructs the input image."""
    def __init__(self, **kwargs):
        super(CapsuleNet, self).__init__(**kwargs)

        with self.name_scope():

            conv1 = nn.HybridSequential()
            conv1.add(
                # Conv1: 256 feature maps with 9x9 kernels.
                # NOTE: the paper uses stride 1 here, which yields the
                # 6x6x32 = 1152 primary capsules referenced in the shape
                # comments of CapsuleLayer.py; this implementation uses
                # stride 2.
                nn.Conv2D(256, kernel_size=9, strides=2, activation='relu')
            )

            primary = nn.HybridSequential()
            primary.add(
                # PrimaryCaps: 32 channels of 8-D capsules
                CapsuleConv(dim_vector=8, out_channels=32,
                            kernel_size=9, strides=2)
            )

            digit = nn.HybridSequential()
            digit.add(
                # DigitCaps: one 16-D capsule per class, 3 routing iterations
                CapsuleDense(dim_vector=16, dim_input_vector=8,
                             out_channels=10, num_routing_iter=3)
            )

            # reconstruction decoder, as in the paper
            decoder_module = nn.HybridSequential()
            decoder_module.add(
                nn.Dense(512, activation='relu'),
                nn.Dense(1024, activation='relu'),
                nn.Dense(784, activation='sigmoid'))

            self.net = nn.HybridSequential()
            self.net.add(conv1, primary, digit, decoder_module)

    def hybrid_forward(self, F, X, y=None):
        X = self.net[0](X)  # Conv1
        X = self.net[1](X)  # Primary Capsule
        X = self.net[2](X)  # Digit Capsule
        # (batch_size, 1, num_classes, 1, dim_vector) -> (batch_size, num_classes, dim_vector)
        X = X.reshape((X.shape[0], X.shape[2], X.shape[4]))
        # length of each digit capsule vector, used for the margin loss
        X_l2norm = nd.sqrt((X ** 2).sum(axis=-1))
        prob = nd.softmax(X_l2norm, axis=-1)

        # Reconstruct from the capsule of the true class during training;
        # fall back to the longest (most probable) capsule at test time.
        if y is not None:
            max_len_indices = y
        else:
            max_len_indices = nd.argmax(prob, axis=-1)

        indices_tile = nd.tile(max_len_indices.expand_dims(axis=1),
                               reps=(1, X.shape[-1]))
        batch_activated_capsules = nd.pick(X, indices_tile, axis=1, keepdims=True)

        reconstructions = self.net[3](batch_activated_capsules)

        return prob, X_l2norm, reconstructions

class CapsuleMarginLoss(Loss):
    r"""Margin loss for CapsuleNet (Eq. 4 of the paper):

    .. math::

        L_k = T_k \max(0, m^{+} - \lVert v_k \rVert)^2
              + \lambda (1 - T_k) \max(0, \lVert v_k \rVert - m^{-})^2

    where :math:`T_k = 1` iff class :math:`k` is present, with
    :math:`m^{+} = 0.9`, :math:`m^{-} = 0.1` and :math:`\lambda = 0.5`.
    The per-sample loss is the sum of the :math:`L_k` over all classes.

    Parameters
    ----------
    weight : float or None
        Global scalar weight for loss.
    sample_weight : Symbol or None
        Per sample weighting. Must be broadcastable to
        the same shape as loss. For example, if loss has
        shape (64, 10) and you want to weight each sample
        in the batch, `sample_weight` should have shape (64, 1).
    batch_axis : int, default 0
        The axis that represents mini-batch.
    """
    def __init__(self, weight=1., batch_axis=0, **kwargs):
        super(CapsuleMarginLoss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, images, num_classes, labels, X_l2norm,
                       lambda_value=0.5, sample_weight=None):
        # `images` only drives HybridBlock's NDArray/Symbol dispatch
        labels_onehot = F.one_hot(labels, num_classes)
        first_term_base = F.square(F.maximum(0.9 - X_l2norm, 0))   # present classes
        second_term_base = F.square(F.maximum(X_l2norm - 0.1, 0))  # absent classes
        margin_loss = labels_onehot * first_term_base \
            + lambda_value * (1 - labels_onehot) * second_term_base
        margin_loss = margin_loss.sum(axis=1)

        # self._weight / 2 keeps the original scaling of this implementation
        loss = _apply_weighting(F, margin_loss, self._weight / 2, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)
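
if __name__ == '__main__':
    # Smoke test (an illustrative sketch, not part of the original code):
    # push one random batch through the network and the margin loss and
    # check the output shapes.
    import mxnet as mx
    from mxnet import init

    net = CapsuleNet()
    net.initialize(ctx=mx.cpu(), init=init.Xavier())
    x = nd.random_uniform(shape=(2, 1, 28, 28))
    y = nd.array([3, 7])
    prob, x_l2norm, recon = net(x, y)
    print(prob.shape, x_l2norm.shape, recon.shape)  # (2, 10) (2, 10) (2, 784)
    loss = CapsuleMarginLoss()
    print(loss(x, 10, y, x_l2norm))  # per-sample margin loss, shape (2,)
--------------------------------------------------------------------------------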
/main.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import os
import argparse
import numpy as np
import mxnet as mx
from tqdm import tqdm
from mxnet.gluon.loss import L2Loss
from mxnet import nd, autograd, gluon, init
from CapsuleNet import CapsuleNet, CapsuleMarginLoss

os.environ['PYTHONUNBUFFERED'] = '1'
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0'
os.environ['MXNET_ENABLE_GPU_P2P'] = '0'

def parse_args():
    parser = argparse.ArgumentParser(description='CapsuleNet Gluon MNIST Example')
    parser.add_argument('--batch-size', type=int, default=32, metavar='N',
                        help='input batch size for training (default: 32)')
    parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                        help='input batch size for testing (default: 128)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
                        help='learning rate (default: 0.05)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5, currently unused)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status (currently unused)')
    args = parser.parse_args()

    return args

def transform(data, label):
    # HWC uint8 in [0, 255] -> CHW float32 in [0, 1]
    return nd.transpose(data.astype(np.float32), (2, 0, 1)) / 255, label.astype(np.float32)

def train(net, epochs, ctx, train_data, test_data,
          margin_loss, reconstructions_loss,
          batch_size, scale_factor, lr):
    num_classes = 10
    trainer = gluon.Trainer(
        net.collect_params(), 'sgd', {'learning_rate': lr, 'wd': 5e-4})

    for epoch in range(epochs):
        train_loss = 0.0
        for batch_idx, (data, label) in tqdm(enumerate(train_data), total=len(train_data),
                                             ncols=70, leave=False, unit='b'):
            label = label.as_in_context(ctx)
            data = data.as_in_context(ctx)
            with autograd.record():
                prob, X_l2norm, reconstructions = net(data, label)
                loss1 = margin_loss(data, num_classes, label, X_l2norm)
                loss2 = reconstructions_loss(reconstructions, data)
                # total loss = margin loss + down-weighted reconstruction loss
                loss = loss1 + scale_factor * loss2
            loss.backward()
            trainer.step(batch_size)
            train_loss += nd.mean(loss).asscalar()
        test_acc = test(test_data, net, ctx)
        print('Epoch:{}, TrainLoss:{:.5f}, TestAcc:{}'.format(
            epoch, train_loss / len(train_data), test_acc))

def test(data_iterator, net, ctx):
    acc = mx.metric.Accuracy()
    for i, (data, label) in tqdm(enumerate(data_iterator), total=len(data_iterator),
                                 ncols=70, leave=False, unit='b'):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        prob, _, _ = net(data, label)
        predictions = nd.argmax(prob, axis=1)
        acc.update(preds=predictions, labels=label)
    return acc.get()[1]

def main():
    args = parse_args()
    mx.random.seed(args.seed)
    ctx = mx.cpu() if args.no_cuda else mx.gpu(0)
    # reconstruction loss is scaled down so it does not dominate the margin loss
    scale_factor = 0.0005
    ##############################################################
    ###                      Load Dataset                      ###
    ##############################################################
    train_data = gluon.data.DataLoader(gluon.data.vision.MNIST(train=True,
                                       transform=transform), args.batch_size,
                                       shuffle=True)
    test_data = gluon.data.DataLoader(gluon.data.vision.MNIST(train=False,
                                      transform=transform),
                                      args.test_batch_size,
                                      shuffle=False)
    ##############################################################
    ##             Load network and set optimizer              ##
    ##############################################################
    capsule_net = CapsuleNet()
    capsule_net.initialize(ctx=ctx, init=init.Xavier())
    margin_loss = CapsuleMarginLoss()
    reconstructions_loss = L2Loss()
    # capsule_net.hybridize() would convert the net to a static graph for a
    # speedup, but the forward pass relies on ndarray-only features (.shape),
    # so it cannot be hybridized as written
    train(capsule_net, args.epochs, ctx, train_data, test_data, margin_loss,
          reconstructions_loss, args.batch_size, scale_factor, args.lr)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/CapsuleLayer.py:
--------------------------------------------------------------------------------
import mxnet as mx
from mxnet.gluon import nn
import mxnet.ndarray as nd

class CapsuleConv(nn.HybridBlock):
    """Primary capsule layer: one Conv2D per component of the capsule vector."""
    def __init__(self, dim_vector, out_channels, kernel_size,
                 strides=1, padding=0):
        super(CapsuleConv, self).__init__()

        self.capsules_index = ['dim_' + str(i) for i in range(dim_vector)]
        for idx in self.capsules_index:
            setattr(self, idx, nn.Conv2D(out_channels,
                                         kernel_size=kernel_size, strides=strides,
                                         padding=padding))

    def squash(self, tensor):
        r"""Batch squashing function (Eq. 1 of the paper):

        .. math::

            v = \frac{\lVert s \rVert^2}{1 + \lVert s \rVert^2}
                \frac{s}{\lVert s \rVert}

        Args:
            tensor : 5-D, (batch_size, num_channel, height, width, dim_vector)

        Return:
            tensor_squashed : same shape, with every capsule vector's length
            squashed into [0, 1)
        """
        epsilon = 1e-9
        tensor_l2norm = (tensor ** 2).sum(axis=-1).expand_dims(axis=-1)
        scale_factor = tensor_l2norm / (1 + tensor_l2norm)
        tensor_squashed = tensor * (scale_factor / (tensor_l2norm + epsilon) ** 0.5)

        return tensor_squashed

    def concat_vectors_in_list(self, vec_list, axis):
        concat_vec = vec_list[0]
        for i in range(1, len(vec_list)):
            concat_vec = nd.concat(concat_vec, vec_list[i], dim=axis)

        return concat_vec

    def hybrid_forward(self, F, X):
        # run every per-dimension convolution, then stack along a new last axis
        outputs = [getattr(self, idx)(X).expand_dims(axis=-1)
                   for idx in self.capsules_index]

        outputs_cat = self.concat_vectors_in_list(outputs, axis=4)
        outputs_squashed = self.squash(outputs_cat)
        return outputs_squashed

class CapsuleDense(nn.HybridBlock):
    """Digit capsule layer with dynamic routing (Procedure 1 of the paper)."""
    def __init__(self, dim_vector, dim_input_vector, out_channels,
                 num_routing_iter=1):
        super(CapsuleDense, self).__init__()

        self.dim_vector = dim_vector
        self.dim_input_vector = dim_input_vector
        self.out_channels = out_channels
        self.num_routing_iter = num_routing_iter
        self.routing_weight_initial = True

    def squash(self, tensor):
        """Batch squashing function; see CapsuleConv.squash."""
        epsilon = 1e-9
        tensor_l2norm = (tensor ** 2).sum(axis=-1).expand_dims(axis=-1)
        scale_factor = tensor_l2norm / (1 + tensor_l2norm)
        tensor_squashed = tensor * (scale_factor / (tensor_l2norm + epsilon) ** 0.5)

        return tensor_squashed

    def hybrid_forward(self, F, X):
        # (batch_size, num_channel_prev, h, w, dim_vector)
        # -> (batch_size, num_capsules_prev, 1, 1, dim_vector)
        X = X.reshape((0, -1, 1, 1, 0))

        self.num_capsules_prev = X.shape[1]
        self.batch_size = X.shape[0]
        # (batch_size, num_capsules_prev, out_channels, 1, dim_input_vector)
        X_tile = nd.tile(X, reps=(1, 1, self.out_channels, 1, 1))

        if self.routing_weight_initial:
            # NOTE: the transformation matrices W_ij are sampled once on the
            # first forward pass and are not registered as gluon Parameters,
            # so the trainer never updates them. The weight is created on the
            # input's context rather than a hard-coded GPU.
            self.routing_weight = nd.random_normal(
                shape=(1, self.num_capsules_prev, self.out_channels,
                       self.dim_input_vector, self.dim_vector),
                name='routing_weight').as_in_context(X.context)
            self.routing_weight_initial = False
        # (batch_size, num_capsules_prev, out_channels, dim_input_vector, dim_vector)
        # e.g. (64, 1152, 10, 8, 16) in the paper's MNIST setup
        W_tile = nd.tile(self.routing_weight, reps=(self.batch_size, 1, 1, 1, 1))
        # u_hat = W u: prediction vectors from every lower-level capsule
        linear_combination_3d = nd.batch_dot(
            X_tile.reshape((-1, X_tile.shape[-2], X_tile.shape[-1])),
            W_tile.reshape((-1, W_tile.shape[-2], W_tile.shape[-1])))
        # e.g. (64, 1152, 10, 1, 16)
        linear_combination = linear_combination_3d.reshape(
            (self.batch_size, self.num_capsules_prev, self.out_channels,
             1, self.dim_vector))

        # routing logits b_ij, e.g. (1, 1152, 10, 1, 1), shared across the batch
        priors = nd.zeros((1, self.num_capsules_prev, self.out_channels, 1, 1),
                          ctx=X.context)
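
        # Dynamic routing (Procedure 1 in the paper), sketched here as an
        # outline of the loop below (comments added for orientation):
        #   for num_routing_iter iterations:
        #       c_ij <- softmax(b_ij)            (over output capsules j)
        #       s_j  <- sum_i c_ij * u_hat_j|i
        #       v_j  <- squash(s_j)
        #       b_ij <- b_ij + u_hat_j|i . v_j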
        ############################################################################
        ##                                Routing                                 ##
        ############################################################################
        for iter_index in range(self.num_routing_iter):
            # NOTE: RoutingAlgorithm-line 4
            # coupling coefficients c_ij: softmax of b_ij over output capsules
            softmax_prior = nd.softmax(priors, axis=2)
            # NOTE: RoutingAlgorithm-line 5
            # weighted prediction vectors, e.g. (64, 1152, 10, 1, 16)
            output = softmax_prior * linear_combination

            # s_J, e.g. (64, 1, 10, 1, 16)
            output_sum = output.sum(axis=1, keepdims=True)

            # NOTE: RoutingAlgorithm-line 6
            # v_J, e.g. (64, 1, 10, 1, 16)
            output_squashed = self.squash(output_sum)

            # NOTE: RoutingAlgorithm-line 7
            # agreement u_hat . v_J between predictions and current outputs
            # e.g. (64, 1152, 10, 1, 16)
            output_tile = nd.tile(output_squashed,
                                  reps=(1, self.num_capsules_prev, 1, 1, 1))
            # (64, 1152, 10, 1, 16) x (64, 1152, 10, 1, 16) (transposed on the
            # last two axes) ==> (64, 1152, 10, 1, 1)
            U_times_v = nd.batch_dot(
                linear_combination.reshape((-1, 1, self.dim_vector)),
                output_tile.reshape((-1, 1, self.dim_vector)),
                transpose_b=True)
            U_times_v = U_times_v.reshape((self.batch_size, self.num_capsules_prev,
                                           self.out_channels, 1, 1))
            # the agreement is summed over the batch because b_ij is shared
            # across samples in this implementation
            priors = priors + U_times_v.sum(axis=0).expand_dims(axis=0)

        return output_squashed  # v_J
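
if __name__ == '__main__':
    # Sanity check (an illustrative sketch, not part of the original code;
    # the (2, 256, 20, 20) feature map mimics the paper's Conv1 output).
    primary = CapsuleConv(dim_vector=8, out_channels=32, kernel_size=9, strides=2)
    primary.initialize()
    feat = nd.random_uniform(shape=(2, 256, 20, 20))
    caps = primary(feat)  # (2, 32, 6, 6, 8)
    # squashing bounds every capsule length below 1
    print(caps.shape, nd.max(nd.sqrt((caps ** 2).sum(axis=-1))).asscalar())
    # dynamic routing: 6*6*32 = 1152 8-D capsules -> 10 16-D digit capsules
    digit = CapsuleDense(dim_vector=16, dim_input_vector=8, out_channels=10,
                         num_routing_iter=3)
    digit.initialize()
    print(digit(caps).shape)  # (2, 1, 10, 1, 16)
--------------------------------------------------------------------------------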