├── .gitignore
├── LICENSE
├── README.md
├── capsules.py
├── loss.py
├── main.py
├── model.py
├── reconstructed.png
└── trainer.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Jupyter Notebook
.ipynb_checkpoints

# Data folder
data/
checkpoints/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Daniel Havir

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Capsule Network #
[![License](https://img.shields.io/github/license/mashape/apistatus.svg?maxAge=2592000)](LICENSE)

#### PyTorch implementation of the following paper:
* [_Dynamic Routing Between Capsules_](https://arxiv.org/abs/1710.09829) by Sara Sabour, Nicholas Frosst and Geoffrey Hinton

### Official implementation
* [Official implementation](https://github.com/Sarasra/models/tree/master/research/capsules) (TensorFlow) by Sara Sabour

### Visual representation
![capsules_visual_representation](https://cdn-images-1.medium.com/max/1600/1*UPaxEd1A3N5ceckB85RIRg.jpeg)
> Image source: _Mike Ross_, [A Visual Representation of Capsule Network Computations](https://medium.com/@mike_ross/a-visual-representation-of-capsule-network-computations-83767d79e737)

### Run the experiment
* For details, run `python main.py --help`
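* For example, the default run (MNIST, reading the dataset location from the `$data` environment variable): `python main.py MNIST`
* Or, overriding a few of the options defined in [main.py](main.py): `python main.py CIFAR --epochs 100 -bs 64 --data_path ./data`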
### Example of reconstructed vs. original images
![reconstructed](reconstructed.png)

______

### Requirements:
* PyTorch (http://www.pytorch.org)
* NumPy (http://www.numpy.org/)
* GPU

### Default hyper-parameters (similar to the paper):
* Per-GPU `batch_size` = 128
* Initial `learning_rate` = 0.001
* Exponential `lr_decay` = 0.96
* Number of routing iterations (`num_routing`) = 3

#### Loss function hyper-parameters (see [loss.py](loss.py)):
* Lambda for Margin Loss = 0.5
* Scaling factor for reconstruction loss = 0.0005

### GPU Speed benchmarks:
(with the above-mentioned hyper-parameters)
* Single GeForce GTX 1080Ti - 35.6s per epoch
* Two GeForce GTX 1080Ti - 35.8s per epoch (twice the batch size -> half as many iterations)

--------------------------------------------------------------------------------
/capsules.py:
--------------------------------------------------------------------------------
########################################
#### Licensed under the MIT license ####
########################################

import torch
import torch.nn as nn
import torch.nn.functional as F


def squash(s, dim=-1):
    '''
    "Squashing" non-linearity that shrinks short vectors to almost zero length and long vectors to a length slightly below 1
    Eq. (1): v_j = ||s_j||^2 / (1 + ||s_j||^2) * s_j / ||s_j||

    Args:
        s: Vector before activation
        dim: Dimension along which to calculate the norm

    Returns:
        Squashed vector
    '''
    squared_norm = torch.sum(s**2, dim=dim, keepdim=True)
    return squared_norm / (1 + squared_norm) * s / (torch.sqrt(squared_norm) + 1e-8)


class PrimaryCapsules(nn.Module):
    def __init__(self, in_channels, out_channels, dim_caps,
                 kernel_size=9, stride=2, padding=0):
        """
        Initialize the layer.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            dim_caps: Dimensionality, i.e. length, of the output capsule vector.
            kernel_size, stride, padding: Convolution hyper-parameters (the defaults follow the paper).
        """
        super(PrimaryCapsules, self).__init__()
        self.dim_caps = dim_caps
        self._caps_channel = int(out_channels / dim_caps)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size(0), self._caps_channel, out.size(2), out.size(3), self.dim_caps)
        out = out.view(out.size(0), -1, self.dim_caps)
        return squash(out)
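# Illustrative shape walk-through, using the MNIST defaults from model.py:
# a (batch, 256, 20, 20) feature map fed to PrimaryCapsules(256, 256, 8) is
# convolved (kernel 9, stride 2) to (batch, 256, 6, 6), reshaped to
# (batch, 32, 6, 6, 8), and flattened to (batch, 1152, 8): that is,
# 1152 capsules of dimension 8, each squashed along its last axis.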
class RoutingCapsules(nn.Module):
    def __init__(self, in_dim, in_caps, num_caps, dim_caps, num_routing, device: torch.device):
        """
        Initialize the layer.

        Args:
            in_dim: Dimensionality (i.e. length) of each input capsule vector.
            in_caps: Number of input capsules (e.g. the number of primary capsules when this is the digits layer).
            num_caps: Number of capsules in this capsule layer.
            dim_caps: Dimensionality, i.e. length, of the output capsule vector.
            num_routing: Number of iterations of the routing algorithm.
        """
        super(RoutingCapsules, self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.num_caps = num_caps
        self.dim_caps = dim_caps
        self.num_routing = num_routing
        self.device = device

        self.W = nn.Parameter( 0.01 * torch.randn(1, num_caps, in_caps, dim_caps, in_dim ) )

    def __repr__(self):
        tab = '  '
        line = '\n'
        res = self.__class__.__name__ + '('
        res = res + line + tab + '(0): ' + 'CapsuleLinear('
        res = res + str(self.in_dim) + ', ' + str(self.dim_caps) + ')'
        res = res + line + tab + '(1): ' + 'Routing('
        res = res + 'num_routing=' + str(self.num_routing) + ')'
        res = res + line + ')'
        return res

    def forward(self, x):
        batch_size = x.size(0)
        # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
        x = x.unsqueeze(1).unsqueeze(4)
        #
        # W @ x =
        # (1, num_caps, in_caps, dim_caps, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
        # (batch_size, num_caps, in_caps, dim_caps, 1)
        u_hat = torch.matmul(self.W, x)
        # (batch_size, num_caps, in_caps, dim_caps)
        u_hat = u_hat.squeeze(-1)
        # detach u_hat during the routing iterations so that no gradients flow through the intermediate coupling updates
        temp_u_hat = u_hat.detach()

        '''
        Procedure 1: Routing algorithm
        '''
        b = torch.zeros(batch_size, self.num_caps, self.in_caps, 1).to(self.device)

        for route_iter in range(self.num_routing-1):
            # (batch_size, num_caps, in_caps, 1) -> softmax along num_caps
            c = F.softmax(b, dim=1)

            # element-wise multiplication
            # (batch_size, num_caps, in_caps, 1) * (batch_size, num_caps, in_caps, dim_caps) ->
            # (batch_size, num_caps, in_caps, dim_caps), summed across in_caps ->
            # (batch_size, num_caps, dim_caps)
            s = (c * temp_u_hat).sum(dim=2)
            # apply "squashing" non-linearity along dim_caps
            v = squash(s)
            # dot-product agreement between the current output v_j and the prediction u_hat_{j|i}
            # (batch_size, num_caps, in_caps, dim_caps) @ (batch_size, num_caps, dim_caps, 1)
            # -> (batch_size, num_caps, in_caps, 1)
            uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
            b += uv

        # the last iteration is run on the original u_hat, without updating the routing weights,
        # so that gradients flow back through the transformation matrices W
        c = F.softmax(b, dim=1)
        s = (c * u_hat).sum(dim=2)
        # apply "squashing" non-linearity along dim_caps
        v = squash(s)

        return v
--------------------------------------------------------------------------------
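A quick numeric and shape check of the two layers in capsules.py, as a minimal standalone sketch (not part of the repository; it assumes capsules.py is on the import path):

```python
import torch
from capsules import squash, RoutingCapsules

# short vectors shrink toward zero, long vectors approach unit length
print(squash(torch.tensor([[0.1, 0.0]])).norm(dim=-1))   # ~0.0099
print(squash(torch.tensor([[10.0, 0.0]])).norm(dim=-1))  # ~0.9901

# route 1152 primary capsules (dim 8) into 10 digit capsules (dim 16)
digits = RoutingCapsules(in_dim=8, in_caps=1152, num_caps=10, dim_caps=16,
                         num_routing=3, device=torch.device('cpu'))
print(digits(torch.rand(2, 1152, 8)).shape)  # torch.Size([2, 10, 16])
```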
/loss.py:
--------------------------------------------------------------------------------
########################################
#### Licensed under the MIT license ####
########################################

import torch
import torch.nn as nn
import torch.nn.functional as F

class MarginLoss(nn.Module):
    def __init__(self, size_average=False, loss_lambda=0.5):
        '''
        Margin loss for digit existence
        Eq. (4): L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2

        Args:
            size_average: should the losses be averaged (True) or summed (False) over observations for each minibatch.
            loss_lambda: parameter for down-weighting the loss for absent digit classes
        '''
        super(MarginLoss, self).__init__()
        self.size_average = size_average
        self.m_plus = 0.9
        self.m_minus = 0.1
        self.loss_lambda = loss_lambda

    def forward(self, inputs, labels):
        L_k = labels * F.relu(self.m_plus - inputs)**2 + self.loss_lambda * (1 - labels) * F.relu(inputs - self.m_minus)**2
        L_k = L_k.sum(dim=1)

        if self.size_average:
            return L_k.mean()
        else:
            return L_k.sum()

class CapsuleLoss(nn.Module):
    def __init__(self, loss_lambda=0.5, recon_loss_scale=5e-4, size_average=False):
        '''
        Combined margin loss and reconstruction loss. For the margin loss, see above.
        Sum squared error (SSE) is used as the reconstruction loss.

        Args:
            recon_loss_scale: parameter for scaling down the reconstruction loss
            size_average: if True, the reconstruction loss becomes MSE instead of SSE
        '''
        super(CapsuleLoss, self).__init__()
        self.size_average = size_average
        self.margin_loss = MarginLoss(size_average=size_average, loss_lambda=loss_lambda)
        # legacy reduction argument: False -> sum (SSE), True -> mean (MSE)
        self.reconstruction_loss = nn.MSELoss(size_average=size_average)
        self.recon_loss_scale = recon_loss_scale

    def forward(self, inputs, labels, images, reconstructions):
        margin_loss = self.margin_loss(inputs, labels)
        reconstruction_loss = self.reconstruction_loss(reconstructions, images)
        caps_loss = (margin_loss + self.recon_loss_scale * reconstruction_loss)

        return caps_loss
--------------------------------------------------------------------------------
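A minimal sketch of how CapsuleLoss is called, with random stand-in tensors of MNIST-like shapes (illustrative only; note that `size_average` is a legacy PyTorch loss argument kept for backward compatibility in newer releases):

```python
import torch
from loss import CapsuleLoss

criterion = CapsuleLoss(loss_lambda=0.5, recon_loss_scale=5e-4)
lengths = torch.rand(4, 10)                          # capsule lengths, one per class
labels = torch.eye(10)[torch.randint(0, 10, (4,))]   # one-hot targets
images = torch.rand(4, 1, 28, 28)
reconstructions = torch.rand(4, 1, 28, 28)
print(criterion(lengths, labels, images, reconstructions))
```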
/main.py:
--------------------------------------------------------------------------------
import os
import torch
import torchvision
import torchvision.transforms as transforms
from trainer import CapsNetTrainer
import argparse

# Default data directory; fall back to ./data when the $data environment variable is unset
DATA_PATH = os.environ.get('data', 'data')

# Collect arguments (if any)
parser = argparse.ArgumentParser()

# MNIST or CIFAR?
parser.add_argument('dataset', nargs='?', type=str, default='MNIST', help="'MNIST' or 'CIFAR' (case insensitive).")
# Batch size
parser.add_argument('-bs', '--batch_size', type=int, default=128, help='Batch size.')
# Epochs
parser.add_argument('-e', '--epochs', type=int, default=50, help='Number of epochs.')
# Learning rate
parser.add_argument('-lr', '--learning_rate', type=float, default=1e-3, help='Learning rate.')
# Number of routing iterations
parser.add_argument('--num_routing', type=int, default=3, help='Number of routing iterations in routing capsules.')
# Exponential learning rate decay
parser.add_argument('--lr_decay', type=float, default=0.96, help='Exponential learning rate decay.')
# Select device "cuda" for GPU or "cpu"
parser.add_argument('--device', type=str, default=("cuda" if torch.cuda.is_available() else "cpu"), choices=['cuda', 'cpu'], help='Device to use. Choose "cuda" for GPU or "cpu".')
# Use multiple GPUs?
parser.add_argument('--multi_gpu', action='store_true', help='Flag whether to use multiple GPUs.')
# Select GPU device
parser.add_argument('--gpu_device', type=int, default=None, help='ID of a GPU to use when multiple GPUs are available.')
# Data directory
parser.add_argument('--data_path', type=str, default=DATA_PATH, help='Path to the MNIST or CIFAR dataset. Alternatively, you can set the path as the environment variable $data.')
args = parser.parse_args()

device = torch.device(args.device)

if args.gpu_device is not None:
    torch.cuda.set_device(args.gpu_device)

if args.multi_gpu:
    args.batch_size *= torch.cuda.device_count()

datasets = {
    'MNIST': torchvision.datasets.MNIST,
    'CIFAR': torchvision.datasets.CIFAR10
}

if args.dataset.upper() == 'MNIST':
    args.data_path = os.path.join(args.data_path, 'MNIST')
    size = 28
    classes = list(range(10))
    mean, std = ( (0.1307,), (0.3081,) )
elif args.dataset.upper() == 'CIFAR':
    args.data_path = os.path.join(args.data_path, 'CIFAR')
    size = 32
    classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    mean, std = ( (0.5, 0.5, 0.5), (0.5, 0.5, 0.5) )
else:
    raise ValueError('Dataset must be either MNIST or CIFAR')

transform = transforms.Compose([
    # shift by up to 2 pixels in either direction, with zero padding
    transforms.RandomCrop(size, padding=2),
    transforms.ToTensor(),
    transforms.Normalize( mean, std )
])

loaders = {}
trainset = datasets[args.dataset.upper()](root=args.data_path, train=True, download=True, transform=transform)
loaders['train'] = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

testset = datasets[args.dataset.upper()](root=args.data_path, train=False, download=True, transform=transform)
loaders['test'] = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=2)
print(8*'#', f'Using {args.dataset.upper()} dataset', 8*'#')

# Run
caps_net = CapsNetTrainer(loaders, args.batch_size, args.learning_rate, args.num_routing, args.lr_decay, device=device, multi_gpu=args.multi_gpu)
caps_net.run(args.epochs, classes=classes)
--------------------------------------------------------------------------------
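A smoke test of the full CapsuleNetwork forward pass defined in model.py below, using the hyper-parameters that trainer.py passes in (an illustrative sketch, not part of the repository):

```python
import torch
from model import CapsuleNetwork

device = torch.device('cpu')
net = CapsuleNetwork(img_shape=(1, 28, 28), channels=256, primary_dim=8,
                     num_classes=10, out_dim=16, num_routing=3, device=device)
preds, reconstructions = net(torch.rand(2, 1, 28, 28))
print(preds.shape, reconstructions.shape)  # torch.Size([2, 10]) torch.Size([2, 1, 28, 28])
```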
/model.py:
--------------------------------------------------------------------------------
########################################
#### Licensed under the MIT license ####
########################################

import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy import prod
import capsules as caps

class CapsuleNetwork(nn.Module):
    def __init__(self, img_shape, channels, primary_dim, num_classes, out_dim, num_routing, device: torch.device, kernel_size=9):
        super(CapsuleNetwork, self).__init__()
        self.img_shape = img_shape
        self.num_classes = num_classes
        self.device = device

        self.conv1 = nn.Conv2d(img_shape[0], channels, kernel_size, stride=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        self.primary = caps.PrimaryCapsules(channels, channels, primary_dim, kernel_size)

        # number of primary capsules after the two kernel_size-9 convolutions (stride 1, then stride 2)
        primary_caps = int(channels / primary_dim * ( img_shape[1] - 2*(kernel_size-1) ) * ( img_shape[2] - 2*(kernel_size-1) ) / 4)
        self.digits = caps.RoutingCapsules(primary_dim, primary_caps, num_classes, out_dim, num_routing, device=self.device)

        self.decoder = nn.Sequential(
            nn.Linear(out_dim * num_classes, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, int(prod(img_shape)) ),
            nn.Sigmoid()
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.primary(out)
        out = self.digits(out)
        preds = torch.norm(out, dim=-1)

        # Reconstruct the *predicted* image
        _, max_length_idx = preds.max(dim=1)
        y = torch.eye(self.num_classes).to(self.device)
        y = y.index_select(dim=0, index=max_length_idx).unsqueeze(2)

        reconstructions = self.decoder( (out*y).view(out.size(0), -1) )
        reconstructions = reconstructions.view(-1, *self.img_shape)

        return preds, reconstructions
--------------------------------------------------------------------------------
/reconstructed.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/danielhavir/capsule-network/1c37c2f10d3485672fd609748502a45d40e54bd7/reconstructed.png
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
########################################
#### Licensed under the MIT license ####
########################################

import torch
import torch.nn as nn
import torch.optim as optim
import os
from numpy import prod
from datetime import datetime
from model import CapsuleNetwork
from loss import CapsuleLoss
from time import time

SAVE_MODEL_PATH = 'checkpoints/'
if not os.path.exists(SAVE_MODEL_PATH):
    os.mkdir(SAVE_MODEL_PATH)

class CapsNetTrainer:
    """
    Wrapper object for handling training and evaluation
    """
    def __init__(self, loaders, batch_size, learning_rate, num_routing=3, lr_decay=0.9, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), multi_gpu=(torch.cuda.device_count() > 1)):
        self.device = device
        self.multi_gpu = multi_gpu

        self.loaders = loaders
        img_shape = self.loaders['train'].dataset[0][0].numpy().shape

        self.net = CapsuleNetwork(img_shape=img_shape, channels=256, primary_dim=8, num_classes=10, out_dim=16, num_routing=num_routing, device=self.device).to(self.device)

        if self.multi_gpu:
            self.net = nn.DataParallel(self.net)

        self.criterion = CapsuleLoss(loss_lambda=0.5, recon_loss_scale=5e-4)
        self.optimizer = optim.Adam(self.net.parameters(), lr=learning_rate)
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=lr_decay)
        print(8*'#', 'PyTorch Model built'.upper(), 8*'#')
        print('Num params:', sum([prod(p.size()) for p in self.net.parameters()]))

    def __repr__(self):
        return repr(self.net)

    def run(self, epochs, classes):
        print(8*'#', 'Run started'.upper(), 8*'#')
        eye = torch.eye(len(classes)).to(self.device)

        for epoch in range(1, epochs+1):
            for phase in ['train', 'test']:
                print(f'{phase}ing...'.capitalize())
                if phase == 'train':
                    self.net.train()
                else:
                    self.net.eval()

                t0 = time()
                running_loss = 0.0
                correct = 0
                total = 0
                for i, (images, labels) in enumerate(self.loaders[phase]):
                    t1 = time()
                    images, labels = images.to(self.device), labels.to(self.device)
                    # One-hot encode labels
                    labels = eye[labels]

                    self.optimizer.zero_grad()

                    outputs, reconstructions = self.net(images)
                    loss = self.criterion(outputs, labels, images, reconstructions)

                    if phase == 'train':
                        loss.backward()
                        self.optimizer.step()

                    running_loss += loss.item()

                    _, predicted = torch.max(outputs, 1)
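                    # convert the one-hot labels back to class indices for the accuracy count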
                    _, labels = torch.max(labels, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum()
                    accuracy = float(correct) / float(total)

                    if phase == 'train':
                        print(f'Epoch {epoch}, Batch {i+1}, Loss {running_loss/(i+1)}',
                              f'Accuracy {accuracy} Time {round(time()-t1, 3)}s')

                print(f'{phase.upper()} Epoch {epoch}, Loss {running_loss/(i+1)}',
                      f'Accuracy {accuracy} Time {round(time()-t0, 3)}s')

            self.scheduler.step()

        now = str(datetime.now()).replace(" ", "-")
        error_rate = round((1-accuracy)*100, 2)
        torch.save(self.net.state_dict(), os.path.join(SAVE_MODEL_PATH, f'{error_rate}_{now}.pth.tar'))

        class_correct = list(0. for _ in classes)
        class_total = list(0. for _ in classes)
        for images, labels in self.loaders['test']:
            images, labels = images.to(self.device), labels.to(self.device)

            outputs, reconstructions = self.net(images)
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(labels.size(0)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

        for i in range(len(classes)):
            print('Accuracy of %5s : %2d %%' % (
                classes[i], 100 * class_correct[i] / class_total[i]))
--------------------------------------------------------------------------------
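Reloading one of the checkpoints written by trainer.py might look like the sketch below. The filename is hypothetical (trainer.py names files with the `{error_rate}_{timestamp}.pth.tar` pattern), and a model trained with `--multi_gpu` is saved through `nn.DataParallel`, so its keys carry a `module.` prefix that would need stripping first:

```python
import torch
from model import CapsuleNetwork

device = torch.device('cpu')
net = CapsuleNetwork(img_shape=(1, 28, 28), channels=256, primary_dim=8,
                     num_classes=10, out_dim=16, num_routing=3, device=device)
# hypothetical checkpoint name following trainer.py's naming pattern
state = torch.load('checkpoints/0.72_2017-11-28-12:00:00.000000.pth.tar', map_location=device)
net.load_state_dict(state)
net.eval()

with torch.no_grad():
    preds, _ = net(torch.rand(1, 1, 28, 28))  # stand-in for a normalized MNIST image
    print('Predicted class:', preds.argmax(dim=1).item())  # class = longest output capsule
```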