├── LICENSE
├── README.md
├── capsnet.py
├── capsule.py
├── config.py
├── data
│   └── .gitkeep
├── epochs
│   └── .gitkeep
├── loss.py
├── main.py
├── results
│   ├── confusion_matrix.png
│   ├── ground_truth.jpg
│   ├── reconstruction.jpg
│   ├── test_acc.png
│   ├── test_loss.png
│   ├── train_acc.png
│   └── train_loss.png
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 leftthomas
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CapsNet
2 | A PyTorch implementation of CapsNet based on NIPS 2017 paper [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829).
3 |
4 | ## Requirements
5 | - [Anaconda](https://www.anaconda.com/download/)
6 | - PyTorch
7 | ```
8 | conda install pytorch torchvision -c soumith
9 | conda install pytorch torchvision cuda80 -c soumith # use this command instead if you have CUDA installed
10 | ```
11 | - PyTorchNet
12 | ```
13 | pip install git+https://github.com/pytorch/tnt.git@master
14 | ```
15 |
16 | ## Usage
17 |
18 | ```
19 | git clone https://github.com/leftthomas/CapsNet.git
20 | cd CapsNet
21 | python -m visdom.server & python main.py
22 | ```
23 | Visdom can now be accessed at `127.0.0.1:8097` in your browser, or at your own host address if specified.
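
A minimal sketch for running inference with a saved model: `main.py` writes a checkpoint to `epochs/epoch_<n>.pt` at the end of every epoch, and the snippet below assumes such a file exists (the epoch number here is a placeholder) and is loaded in the same environment it was trained in (add `map_location` to `torch.load` if you load a GPU-trained checkpoint on a CPU-only machine).

```
import torch
from torch.autograd import Variable

from capsnet import CapsuleNet
from utils import get_iterator

model = CapsuleNet()
# 'epochs/epoch_30.pt' is a placeholder; use whichever checkpoint main.py has saved.
model.load_state_dict(torch.load('epochs/epoch_30.pt'))
model.eval()

# One batch from the MNIST test set, preprocessed the same way as in main.py.
images, labels = next(iter(get_iterator(False)))
data = Variable(images.unsqueeze(1).float() / 255.0)
if torch.cuda.is_available():
    # forward() builds its one-hot target on the GPU when CUDA is present,
    # so keep the model and data there as well.
    model, data = model.cuda(), data.cuda()

classes, reconstructions = model(data)
predictions = classes.max(dim=1)[1]  # predicted digit for each test image
```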
24 |
25 | ## Benchmarks
26 | The highest accuracy reached was 99.57%, after 30 epochs. The model may reach a higher accuracy with further training, as suggested by the trend of the loss/accuracy graphs below.
27 |
28 | | Train Loss | Train Accuracy |
29 | |:----------:|:--------------:|
30 | | ![train_loss](results/train_loss.png) | ![train_acc](results/train_acc.png) |
31 |
32 | | Test Loss | Test Accuracy |
33 | |:---------:|:--------------:|
34 | | ![test_loss](results/test_loss.png) | ![test_acc](results/test_acc.png) |
35 |
48 | The confusion matrix for the ten digit classes is shown below.
49 |
50 | ![confusion_matrix](results/confusion_matrix.png)
51 | The reconstructions of the digits are shown on the right, with the corresponding ground truth on the left.
52 |
53 | | Ground Truth | Reconstruction |
54 | |:------------:|:--------------:|
55 | | ![ground_truth](results/ground_truth.jpg) | ![reconstruction](results/reconstruction.jpg) |
56 |
63 | Default PyTorch Adam optimizer hyperparameters were used, with no learning-rate scheduling. One epoch with a batch size of 100 takes about 2 minutes on an NVIDIA GTX 1070 GPU.
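
For reference, the optimizer setup in `main.py` passes no hyperparameters, so it is equivalent to PyTorch's Adam defaults written out explicitly:

```
from torch.optim import Adam

# Equivalent to Adam(model.parameters()) as used in main.py (PyTorch defaults).
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0)
```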
64 |
65 | ## Other Implementations
66 | - [capsnet.pytorch](https://github.com/andreaazzini/capsnet.pytorch.git)
67 |
68 | - [CapsNet-Keras](https://github.com/XifengGuo/CapsNet-Keras.git)
69 |
70 | - [CapsNet-Tensorflow](https://github.com/naturomics/CapsNet-Tensorflow.git)
71 |
72 | ## Credits
73 | This project primarily references the following implementation:
74 | [PyTorch implementation by @Gram.AI](https://github.com/gram-ai/capsule-networks)
75 |
--------------------------------------------------------------------------------
/capsnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch.autograd import Variable
5 |
6 | import config
7 | from capsule import CapsuleLayer
8 |
9 |
10 | class CapsuleNet(nn.Module):
11 | def __init__(self):
12 | super(CapsuleNet, self).__init__()
13 |
14 | self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
15 | self.primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32,
16 | kernel_size=9, stride=2)
17 | self.digit_capsules = CapsuleLayer(num_capsules=config.NUM_CLASSES, num_route_nodes=32 * 6 * 6, in_channels=8,
18 | out_channels=16)
19 |
20 | self.decoder = nn.Sequential(
21 | nn.Linear(16 * config.NUM_CLASSES, 512),
22 | nn.ReLU(inplace=True),
23 | nn.Linear(512, 1024),
24 | nn.ReLU(inplace=True),
25 | nn.Linear(1024, 784),
26 | nn.Sigmoid()
27 | )
28 |
29 | def forward(self, x, y=None):
30 | x = F.relu(self.conv1(x), inplace=True)
31 | x = self.primary_capsules(x)
32 | x = self.digit_capsules(x).squeeze().transpose(0, 1)
33 |
34 | classes = (x ** 2).sum(dim=-1) ** 0.5
35 | classes = F.softmax(classes, dim=-1)
36 |
37 | if y is None:
38 | # In all batches, get the most active capsule.
39 | _, max_length_indices = classes.max(dim=1)
40 | if torch.cuda.is_available():
41 | y = Variable(torch.eye(config.NUM_CLASSES)).cuda().index_select(dim=0, index=max_length_indices)
42 | else:
43 | y = Variable(torch.eye(config.NUM_CLASSES)).index_select(dim=0, index=max_length_indices)
44 | reconstructions = self.decoder((x * y[:, :, None]).view(x.size(0), -1))
45 |
46 | return classes, reconstructions
47 |
48 |
49 | if __name__ == "__main__":
50 | model = CapsuleNet()
51 | print(model)
52 |
--------------------------------------------------------------------------------
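
As a quick sanity check of the shapes produced by `CapsuleNet`, here is a hypothetical smoke test on random MNIST-sized input (the batch size must be at least 2 because `forward` calls `squeeze()` on the digit-capsule output):

```
import torch
from torch.autograd import Variable

from capsnet import CapsuleNet

model = CapsuleNet()
x = Variable(torch.randn(4, 1, 28, 28))  # four fake 28x28 grayscale images
if torch.cuda.is_available():
    # forward() builds its one-hot target on the GPU when CUDA is present.
    model, x = model.cuda(), x.cuda()

classes, reconstructions = model(x)
print(classes.size())          # torch.Size([4, 10])  - softmaxed digit-capsule lengths
print(reconstructions.size())  # torch.Size([4, 784]) - flattened 28x28 reconstructions
```
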
/capsule.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from torch.autograd import Variable
5 |
6 | import config
7 |
8 |
9 | class CapsuleLayer(nn.Module):
10 | def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, kernel_size=None, stride=None,
11 | num_iterations=config.NUM_ROUTING_ITERATIONS):
12 | super(CapsuleLayer, self).__init__()
13 |
14 | self.num_route_nodes = num_route_nodes
15 | self.num_iterations = num_iterations
16 |
17 | self.num_capsules = num_capsules
18 |
19 | if num_route_nodes != -1:
20 | self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
21 | else:
22 | self.capsules = nn.ModuleList(
23 | [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0) for _ in
24 | range(num_capsules)])
25 |
26 | @staticmethod
27 | def squash(tensor, dim=-1):
28 | squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
29 | scale = squared_norm / (1 + squared_norm)
30 | return scale * tensor / torch.sqrt(squared_norm)
31 |
32 | def forward(self, x):
33 | if self.num_route_nodes != -1:
34 | priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]
35 | logits = Variable(torch.zeros(*priors.size()))
36 | if torch.cuda.is_available():
37 | logits = logits.cuda()
38 | for i in range(self.num_iterations):
39 | probs = F.softmax(logits, dim=2)
40 | outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))
41 |
42 | if i != self.num_iterations - 1:
43 | delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
44 | logits = logits + delta_logits
45 | else:
46 | outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
47 | outputs = torch.cat(outputs, dim=-1)
48 | outputs = self.squash(outputs)
49 |
50 | return outputs
51 |
52 |
53 | if __name__ == "__main__":
54 | primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32,
55 | kernel_size=9, stride=2)
56 | print(primary_capsules)
57 | digit_capsules = CapsuleLayer(num_capsules=config.NUM_CLASSES, num_route_nodes=32 * 6 * 6, in_channels=8,
58 | out_channels=16)
59 | print(digit_capsules)
60 |
--------------------------------------------------------------------------------
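
For reference, the `squash` non-linearity and the routing loop above correspond to the following equations from the paper (written in LaTeX); note that, following the referenced @Gram.AI implementation, the softmax here is taken over the input-capsule dimension (`dim=2`) rather than over the output capsules as in the paper:

```
% Squashing non-linearity applied to each capsule input s_j:
v_j = \frac{\lVert s_j \rVert^2}{1 + \lVert s_j \rVert^2} \, \frac{s_j}{\lVert s_j \rVert}

% Dynamic routing: coupling coefficients from a softmax over the logits b_{ij},
% weighted sum of prediction vectors, and the agreement-based logit update:
c_{ij} = \frac{\exp(b_{ij})}{\sum_k \exp(b_{ik})}, \qquad
s_j = \sum_i c_{ij}\, \hat{u}_{j|i}, \qquad
b_{ij} \leftarrow b_{ij} + \hat{u}_{j|i} \cdot v_j
```
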
/config.py:
--------------------------------------------------------------------------------
1 | BATCH_SIZE = 100
2 | NUM_CLASSES = 10
3 | NUM_EPOCHS = 100
4 | NUM_ROUTING_ITERATIONS = 3
5 |
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/data/.gitkeep
--------------------------------------------------------------------------------
/epochs/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/epochs/.gitkeep
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 |
4 |
5 | class CapsuleLoss(nn.Module):
6 | def __init__(self):
7 | super(CapsuleLoss, self).__init__()
8 | self.reconstruction_loss = nn.MSELoss(size_average=False)
9 |
10 | def forward(self, images, labels, classes, reconstructions):
11 | left = F.relu(0.9 - classes, inplace=True) ** 2
12 | right = F.relu(classes - 0.1, inplace=True) ** 2
13 |
14 | margin_loss = labels * left + 0.5 * (1. - labels) * right
15 | margin_loss = margin_loss.sum()
16 |
17 | reconstruction_loss = self.reconstruction_loss(reconstructions, images)
18 |
19 | return (margin_loss + 0.0005 * reconstruction_loss) / images.size(0)
20 |
21 |
22 | if __name__ == "__main__":
23 | digit_loss = CapsuleLoss()
24 | print(digit_loss)
25 |
--------------------------------------------------------------------------------
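
Written out, the loss above is the margin loss from the paper plus a down-weighted sum-of-squares reconstruction term, averaged over the batch of N images. With the constants used in the code (m+ = 0.9, m- = 0.1, lambda = 0.5, reconstruction weight 0.0005), and p_{nk} denoting the class score that CapsuleNet outputs for image n and class k (a softmax over capsule lengths in this implementation):

```
% Margin loss for image n and class k, with one-hot target T_{nk}:
L_{nk} = T_{nk}\,\max(0,\; 0.9 - p_{nk})^2 + 0.5\,(1 - T_{nk})\,\max(0,\; p_{nk} - 0.1)^2

% Total loss: margin loss plus down-weighted reconstruction error, averaged over the batch:
L = \frac{1}{N}\left( \sum_{n,k} L_{nk} \;+\; 0.0005 \sum_{n,i} (\hat{x}_{ni} - x_{ni})^2 \right)
```
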
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchnet as tnt
3 | from torch.autograd import Variable
4 | from torch.optim import Adam
5 | from torchnet.engine import Engine
6 | from torchnet.logger import VisdomPlotLogger, VisdomLogger
7 | from torchvision.utils import make_grid
8 | from tqdm import tqdm
9 |
10 | import config
11 | import utils
12 | from capsnet import CapsuleNet
13 | from loss import CapsuleLoss
14 |
15 |
16 | def processor(sample):
17 | data, labels, training = sample
18 |
19 | data = utils.augmentation(data.unsqueeze(1).float() / 255.0)
20 | labels = torch.eye(config.NUM_CLASSES).index_select(dim=0, index=labels)
21 |
22 | data = Variable(data)
23 | labels = Variable(labels)
24 | if torch.cuda.is_available():
25 | data = data.cuda()
26 | labels = labels.cuda()
27 |
28 | if training:
29 | classes, reconstructions = model(data, labels)
30 | else:
31 | classes, reconstructions = model(data)
32 |
33 | loss = capsule_loss(data, labels, classes, reconstructions)
34 |
35 | return loss, classes
36 |
37 |
38 | def on_sample(state):
39 | state['sample'].append(state['train'])
40 |
41 |
42 | def reset_meters():
43 | meter_accuracy.reset()
44 | meter_loss.reset()
45 | confusion_meter.reset()
46 |
47 |
48 | def on_forward(state):
49 | meter_accuracy.add(state['output'].data, state['sample'][1])
50 | confusion_meter.add(state['output'].data, state['sample'][1])
51 | meter_loss.add(state['loss'].data[0])
52 |
53 |
54 | def on_start_epoch(state):
55 | reset_meters()
56 | state['iterator'] = tqdm(state['iterator'])
57 |
58 |
59 | def on_end_epoch(state):
60 | print('[Epoch %d] Training Loss: %.4f (Accuracy: %.2f%%)' % (
61 | state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))
62 |
63 | train_loss_logger.log(state['epoch'], meter_loss.value()[0])
64 | train_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0])
65 |
66 | reset_meters()
67 |
68 | engine.test(processor, utils.get_iterator(False))
69 | test_loss_logger.log(state['epoch'], meter_loss.value()[0])
70 | test_accuracy_logger.log(state['epoch'], meter_accuracy.value()[0])
71 | confusion_logger.log(confusion_meter.value())
72 |
73 | print('[Epoch %d] Testing Loss: %.4f (Accuracy: %.2f%%)' % (
74 | state['epoch'], meter_loss.value()[0], meter_accuracy.value()[0]))
75 |
76 | torch.save(model.state_dict(), 'epochs/epoch_%d.pt' % state['epoch'])
77 |
78 | # reconstruction visualization
79 |
80 | test_sample = next(iter(utils.get_iterator(False)))
81 |
82 | ground_truth = (test_sample[0].unsqueeze(1).float() / 255.0)
83 | if torch.cuda.is_available():
84 | _, reconstructions = model(Variable(ground_truth).cuda())
85 | else:
86 | _, reconstructions = model(Variable(ground_truth))
87 | reconstruction = reconstructions.cpu().view_as(ground_truth).data
88 |
89 | ground_truth_logger.log(
90 | make_grid(ground_truth, nrow=int(config.BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())
91 | reconstruction_logger.log(
92 | make_grid(reconstruction, nrow=int(config.BATCH_SIZE ** 0.5), normalize=True, range=(0, 1)).numpy())
93 |
94 |
95 | if __name__ == "__main__":
96 | model = CapsuleNet()
97 | if torch.cuda.is_available():
98 | model.cuda()
99 |
100 | print("# parameters:", sum(param.numel() for param in model.parameters()))
101 |
102 | optimizer = Adam(model.parameters())
103 |
104 | engine = Engine()
105 | meter_loss = tnt.meter.AverageValueMeter()
106 | meter_accuracy = tnt.meter.ClassErrorMeter(accuracy=True)
107 | confusion_meter = tnt.meter.ConfusionMeter(config.NUM_CLASSES, normalized=True)
108 |
109 | train_loss_logger = VisdomPlotLogger('line', opts={'title': 'Train Loss'})
110 | train_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Train Accuracy'})
111 | test_loss_logger = VisdomPlotLogger('line', opts={'title': 'Test Loss'})
112 | test_accuracy_logger = VisdomPlotLogger('line', opts={'title': 'Test Accuracy'})
113 | confusion_logger = VisdomLogger('heatmap', opts={'title': 'Confusion Matrix',
114 | 'columnnames': list(range(config.NUM_CLASSES)),
115 | 'rownames': list(range(config.NUM_CLASSES))})
116 | ground_truth_logger = VisdomLogger('image', opts={'title': 'Ground Truth'})
117 | reconstruction_logger = VisdomLogger('image', opts={'title': 'Reconstruction'})
118 |
119 | capsule_loss = CapsuleLoss()
120 |
121 | engine.hooks['on_sample'] = on_sample
122 | engine.hooks['on_forward'] = on_forward
123 | engine.hooks['on_start_epoch'] = on_start_epoch
124 | engine.hooks['on_end_epoch'] = on_end_epoch
125 |
126 | engine.train(processor, utils.get_iterator(True), maxepoch=config.NUM_EPOCHS, optimizer=optimizer)
127 |
--------------------------------------------------------------------------------
/results/confusion_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/confusion_matrix.png
--------------------------------------------------------------------------------
/results/ground_truth.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/ground_truth.jpg
--------------------------------------------------------------------------------
/results/reconstruction.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/reconstruction.jpg
--------------------------------------------------------------------------------
/results/test_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/test_acc.png
--------------------------------------------------------------------------------
/results/test_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/test_loss.png
--------------------------------------------------------------------------------
/results/train_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/train_acc.png
--------------------------------------------------------------------------------
/results/train_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/CapsNet/5de2f45daadbe4377df4ccf8a4d31683d7f397bf/results/train_loss.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torchnet as tnt
4 | from torchvision.datasets.mnist import MNIST
5 |
6 | import config
7 |
8 |
9 | def augmentation(x, max_shift=2):
10 | _, _, height, width = x.size()
11 |
12 | h_shift, w_shift = np.random.randint(-max_shift, max_shift + 1, size=2)
13 | source_height_slice = slice(max(0, h_shift), h_shift + height)
14 | source_width_slice = slice(max(0, w_shift), w_shift + width)
15 | target_height_slice = slice(max(0, -h_shift), -h_shift + height)
16 | target_width_slice = slice(max(0, -w_shift), -w_shift + width)
17 |
18 | shifted_image = torch.zeros(*x.size())
19 | shifted_image[:, :, source_height_slice, source_width_slice] = x[:, :, target_height_slice, target_width_slice]
20 | return shifted_image.float()
21 |
22 |
23 | def get_iterator(mode):
24 | dataset = MNIST(root='./data', train=mode, download=True)
25 | data = getattr(dataset, 'train_data' if mode else 'test_data')
26 | labels = getattr(dataset, 'train_labels' if mode else 'test_labels')
27 | tensor_dataset = tnt.dataset.TensorDataset([data, labels])
28 |
29 | return tensor_dataset.parallel(batch_size=config.BATCH_SIZE, num_workers=4, shuffle=mode)
30 |
31 |
32 | if __name__ == "__main__":
33 | t = torch.rand(1, 1, 28, 28)
34 | print(t)
35 | y = augmentation(t)
36 | print(y)
37 |
--------------------------------------------------------------------------------