├── LICENSE ├── README.md ├── capsnet.py ├── capsule.py ├── config.py ├── data └── .gitkeep ├── epochs └── .gitkeep ├── loss.py ├── main.py ├── results ├── confusion_matrix.png ├── ground_truth.jpg ├── reconstruction.jpg ├── test_acc.png ├── test_loss.png ├── train_acc.png └── train_loss.png └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 leftthomas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CapsNet 2 | A PyTorch implementation of CapsNet based on NIPS 2017 paper [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829). 3 | 4 | ## Requirements 5 | - [Anaconda](https://www.anaconda.com/download/) 6 | - PyTorch 7 | ``` 8 | conda install pytorch torchvision -c soumith 9 | conda install pytorch torchvision cuda80 -c soumith # install it if you have installed cuda 10 | ``` 11 | - PyTorchNet 12 | ``` 13 | pip install git+https://github.com/pytorch/tnt.git@master 14 | ``` 15 | 16 | ## Usage 17 | 18 | ``` 19 | git clone https://github.com/leftthomas/CapsNet.git 20 | cd CapsNet 21 | python -m visdom.server & python main.py 22 | ``` 23 | Visdom now can be accessed by going to `127.0.0.1:8097` in your browser, or your own host address if specified. 24 | 25 | ## Benchmarks 26 | Highest accuracy was 99.57% after 30 epochs. The model may achieve a higher accuracy as shown by the trend of the loss/accuracy graphs below. 27 | 28 | 29 | 32 | 35 | 36 |
30 | 31 | 33 | 34 |
37 | 38 | 39 | 42 | 45 | 46 |
40 | 41 | 43 | 44 |
47 | 48 | The confusion matrix of the digit numbers are showed below. 49 | 50 | 51 | The reconstructions of the digit numbers are showed at right and the ground truth at left. 52 | 53 | 54 | 57 | 60 | 61 |
55 | 56 | 58 | 59 |
62 | 63 | Default PyTorch Adam optimizer hyperparameters were used with no learning rate scheduling. Epochs with batch size of 100 takes ~2 minutes on a NVIDIA GTX 1070 GPU. 64 | 65 | ## Other Implementations 66 | - [capsnet.pytorch](https://github.com/andreaazzini/capsnet.pytorch.git) 67 | 68 | - [CapsNet-Keras](https://github.com/naturomics/XifengGuo/CapsNet-Keras.git) 69 | 70 | - [CapsNet-Tensorflow](https://github.com/naturomics/CapsNet-Tensorflow.git) 71 | 72 | ## Credits 73 | Primarily referenced this implementation: 74 | [PyTorch implementation by @Gram.AI](https://github.com/gram-ai/capsule-networks) 75 | -------------------------------------------------------------------------------- /capsnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.autograd import Variable 5 | 6 | import config 7 | from capsule import CapsuleLayer 8 | 9 | 10 | class CapsuleNet(nn.Module): 11 | def __init__(self): 12 | super(CapsuleNet, self).__init__() 13 | 14 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1) 15 | self.primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32, 16 | kernel_size=9, stride=2) 17 | self.digit_capsules = CapsuleLayer(num_capsules=config.NUM_CLASSES, num_route_nodes=32 * 6 * 6, in_channels=8, 18 | out_channels=16) 19 | 20 | self.decoder = nn.Sequential( 21 | nn.Linear(16 * config.NUM_CLASSES, 512), 22 | nn.ReLU(inplace=True), 23 | nn.Linear(512, 1024), 24 | nn.ReLU(inplace=True), 25 | nn.Linear(1024, 784), 26 | nn.Sigmoid() 27 | ) 28 | 29 | def forward(self, x, y=None): 30 | x = F.relu(self.conv1(x), inplace=True) 31 | x = self.primary_capsules(x) 32 | x = self.digit_capsules(x).squeeze().transpose(0, 1) 33 | 34 | classes = (x ** 2).sum(dim=-1) ** 0.5 35 | classes = F.softmax(classes, dim=-1) 36 | 37 | if y is None: 38 | # In all batches, get the most active capsule. 39 | _, max_length_indices = classes.max(dim=1) 40 | if torch.cuda.is_available(): 41 | y = Variable(torch.eye(config.NUM_CLASSES)).cuda().index_select(dim=0, index=max_length_indices) 42 | else: 43 | y = Variable(torch.eye(config.NUM_CLASSES)).index_select(dim=0, index=max_length_indices) 44 | reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1)) 45 | 46 | return classes, reconstructions 47 | 48 | 49 | if __name__ == "__main__": 50 | model = CapsuleNet() 51 | print(model) 52 | -------------------------------------------------------------------------------- /capsule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.autograd import Variable 5 | 6 | import config 7 | 8 | 9 | class CapsuleLayer(nn.Module): 10 | def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, kernel_size=None, stride=None, 11 | num_iterations=config.NUM_ROUTING_ITERATIONS): 12 | super(CapsuleLayer, self).__init__() 13 | 14 | self.num_route_nodes = num_route_nodes 15 | self.num_iterations = num_iterations 16 | 17 | self.num_capsules = num_capsules 18 | 19 | if num_route_nodes != -1: 20 | self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels)) 21 | else: 22 | self.capsules = nn.ModuleList( 23 | [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0) for _ in 24 | range(num_capsules)]) 25 | 26 | @staticmethod 27 | def squash(tensor, dim=-1): 28 | squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True) 29 | scale = squared_norm / (1 + squared_norm) 30 | return scale * tensor / torch.sqrt(squared_norm) 31 | 32 | def forward(self, x): 33 | if self.num_route_nodes != -1: 34 | priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :] 35 | logits = Variable(torch.zeros(*priors.size())) 36 | if torch.cuda.is_available(): 37 | logits = logits.cuda() 38 | for i in range(self.num_iterations): 39 | probs = F.softmax(logits, dim=2) 40 | outputs = self.squash((probs * priors).sum(dim=2, keepdim=True)) 41 | 42 | if i != self.num_iterations - 1: 43 | delta_logits = (priors * outputs).sum(dim=-1, keepdim=True) 44 | logits = logits + delta_logits 45 | else: 46 | outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules] 47 | outputs = torch.cat(outputs, dim=-1) 48 | outputs = self.squash(outputs) 49 | 50 | return outputs 51 | 52 | 53 | if __name__ == "__main__": 54 | primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32, 55 | kernel_size=9, stride=2) 56 | print(primary_capsules) 57 | digit_capsules = CapsuleLayer(num_capsules=config.NUM_CLASSES, num_route_nodes=32 * 6 * 6, in_channels=8, 58 | out_channels=16) 59 | print(digit_capsules) 60 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | BATCH_SIZE = 100 2 | NUM_CLASSES = 10 3 | NUM_EPOCHS = 100 4 | NUM_ROUTING_ITERATIONS = 3 5 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/data/.gitkeep -------------------------------------------------------------------------------- /epochs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/epochs/.gitkeep -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torch import nn 3 | 4 | 5 | class CapsuleLoss(nn.Module): 6 | def __init__(self): 7 | super(CapsuleLoss, self).__init__() 8 | self.reconstruction_loss = nn.MSELoss(size_average=False) 9 | 10 | def forward(self, images, labels, classes, reconstructions): 11 | left = F.relu(0.9 - classes, inplace=True) ** 2 12 | right = F.relu(classes - 0.1, inplace=True) ** 2 13 | 14 | margin_loss = labels * left + 0.5 * (1. - labels) * right 15 | margin_loss = margin_loss.sum() 16 | 17 | reconstruction_loss = self.reconstruction_loss(reconstructions, images) 18 | 19 | return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0) 20 | 21 | 22 | if __name__ == "__main__": 23 | digit_loss = CapsuleLoss() 24 | print(digit_loss) 25 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchnet as tnt 3 | from torch.autograd import Variable 4 | from torch.optim import Adam 5 | from torchnet.engine import Engine 6 | from torchnet.logger import VisdomPlotLogger, VisdomLogger 7 | from torchvision.utils import make_grid 8 | from tqdm import tqdm 9 | 10 | import config 11 | import utils 12 | from capsnet import CapsuleNet 13 | from loss import CapsuleLoss 14 | 15 | 16 | def processor(sample): 17 | data, labels, training = sample 18 | 19 | data = utils.augmentation(data.unsqueeze(1).float() / 255.0) 20 | labels = torch.eye(config.NUM_CLASSES).index_select(dim=0, index=labels) 21 | 22 | data = Variable(data) 23 | labels = Variable(labels) 24 | if torch.cuda.is_available(): 25 | data = data.cuda() 26 | labels = labels.cuda() 27 | 28 | if training: 29 | classes, reconstructions = model(data, labels) 30 | else: 31 | classes, reconstructions = model(data) 32 | 33 | loss = capsule_loss(data, labels, classes, reconstructions) 34 | 35 | return loss, classes 36 | 37 | 38 | def on_sample(state): 39 | state['sample'].append(state['train']) 40 | 41 | 42 | def reset_meters(): 43 | meter_accuracy.reset() 44 | meter_loss.reset() 45 | confusion_meter.reset() 46 | 47 | 48 | def on_forward(state): 49 | meter_accuracy.add(state['output'].data, state['sample'][1]) 50 | confusion_meter.add(state['output'].data, state['sample'][1]) 51 | meter_loss.add(state['loss'].data[0]) 52 | 53 | 54 | def on_start_epoch(state): 55 | reset_meters() 56 | state['iterator'] = tqdm(state['iterator']) 57 | 58 | 59 | def on_end_epoch(state): 60 | print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % ( 61 | state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0])) 62 | 63 | train_loss_logger.log(state['epoch'], meter_loss.value()[0]) 64 | train_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0]) 65 | 66 | reset_meters() 67 | 68 | engine.test(processor, utils.get_iterator(False)) 69 | test_loss_logger.log(state['epoch'], meter_loss.value()[0]) 70 | test_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0]) 71 | confusion_logger.log(confusion_meter.value()) 72 | 73 | print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % ( 74 | state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0])) 75 | 76 | torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % state['epoch']) 77 | 78 | # reconstruction visualization 79 | 80 | test_sample = next(iter(utils.get_iterator(False))) 81 | 82 | ground_truth = (test_sample[0].unsqueeze(1).float() / 255.0) 83 | if torch.cuda.is_available(): 84 | _, reconstructions = model(Variable(ground_truth).cuda()) 85 | else: 86 | _, reconstructions = model(Variable(ground_truth)) 87 | reconstruction = reconstructions.cpu().view_as(ground_truth).data 88 | 89 | ground_truth_logger.log( 90 | make_grid(ground_truth, nrow=int(config.BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy()) 91 | reconstruction_logger.log( 92 | make_grid(reconstruction, nrow=int(config.BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy()) 93 | 94 | 95 | if __name__ == "__main__": 96 | model = CapsuleNet() 97 | if torch.cuda.is_available(): 98 | model.cuda() 99 | 100 | print("# parameters:", sum(param.numel() for param in model.parameters())) 101 | 102 | optimizer = Adam(model.parameters()) 103 | 104 | engine = Engine() 105 | meter_loss = tnt.meter.AverageValueMeter() 106 | meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True) 107 | confusion_meter = tnt.meter.ConfusionMeter(config.NUM_CLASSES, normalized=True) 108 | 109 | train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'}) 110 | train_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'}) 111 | test_loss_logger = VisdomPlotLogger('line', opts={'title': 'Test Loss'}) 112 | test_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'}) 113 | confusion_logger = VisdomLogger('heatmap', opts={'title': 'Confusion Matrix', 114 | 'columnnames': list(range(config.NUM_CLASSES)), 115 | 'rownames': list(range(config.NUM_CLASSES))}) 116 | ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'}) 117 | reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'}) 118 | 119 | capsule_loss = CapsuleLoss() 120 | 121 | engine.hooks['on_sample'] = on_sample 122 | engine.hooks['on_forward'] = on_forward 123 | engine.hooks['on_start_epoch'] = on_start_epoch 124 | engine.hooks['on_end_epoch'] = on_end_epoch 125 | 126 | engine.train(processor, utils.get_iterator(True), maxepoch=config.NUM_EPOCHS, optimizer=optimizer) 127 | -------------------------------------------------------------------------------- /results/confusion_matrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/confusion_matrix.png -------------------------------------------------------------------------------- /results/ground_truth.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/ground_truth.jpg -------------------------------------------------------------------------------- /results/reconstruction.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/reconstruction.jpg -------------------------------------------------------------------------------- /results/test_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/test_acc.png -------------------------------------------------------------------------------- /results/test_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/test_loss.png -------------------------------------------------------------------------------- /results/train_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/train_acc.png -------------------------------------------------------------------------------- /results/train_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/train_loss.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchnet as tnt 4 | from torchvision.datasets.mnist import MNIST 5 | 6 | import config 7 | 8 | 9 | def augmentation(x, max_shift=2): 10 | _, _, height, width = x.size() 11 | 12 | h_shift, w_shift = np.random.randint(-max_shift, max_shift + 1, size=2) 13 | source_height_slice = slice(max(0, h_shift), h_shift + height) 14 | source_width_slice = slice(max(0, w_shift), w_shift + width) 15 | target_height_slice = slice(max(0, -h_shift), -h_shift + height) 16 | target_width_slice = slice(max(0, -w_shift), -w_shift + width) 17 | 18 | shifted_image = torch.zeros(*x.size()) 19 | shifted_image[:, :, source_height_slice, source_width_slice] = x[:, :, target_height_slice, target_width_slice] 20 | return shifted_image.float() 21 | 22 | 23 | def get_iterator(mode): 24 | dataset = MNIST(root='./data', train=mode, download=True) 25 | data = getattr(dataset, 'train_data' if mode else 'test_data') 26 | labels = getattr(dataset, 'train_labels' if mode else 'test_labels') 27 | tensor_dataset = tnt.dataset.TensorDataset([data, labels]) 28 | 29 | return tensor_dataset.parallel(batch_size=config.BATCH_SIZE, num_workers=4, shuffle=mode) 30 | 31 | 32 | if __name__ == "__main__": 33 | t = torch.rand(1, 1, 28, 28) 34 | print(t) 35 | y = augmentation(t) 36 | print(y) 37 | --------------------------------------------------------------------------------