├── .gitignore ├── requirements.txt ├── README.md ├── LICENSE ├── main.py └── capsnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | *DS_Store 2 | *idea/ 3 | data/ 4 | model/ 5 | *pyc 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm==4.42.1 2 | torchvision==0.5.0 3 | torch==1.4.0 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CapsNet 2 | > A pytorch implementation of capsule network. 3 | > 4 | > Reference: https://arxiv.org/abs/1710.09829 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Riroaki 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
# IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# --------------------------------------------------------------------------------
# /main.py:
# --------------------------------------------------------------------------------
import os

import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from capsnet import CapsNet, CapsuleLoss

# Check cuda availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main():
    """Train CapsNet on MNIST, report test accuracy and save the weights.

    The model is trained for a fixed number of epochs with Adam and an
    exponentially decaying learning rate, evaluated once on the MNIST test
    split, and its ``state_dict`` is written under ``./model``.
    """
    # Load model
    model = CapsNet().to(device)
    criterion = CapsuleLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.96)

    # Load data
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    train_transform = transforms.Compose([
        # shift by 2 pixels in either direction with zero padding.
        transforms.RandomCrop(28, padding=2),
        transforms.ToTensor(),
        normalize,
    ])
    # The random shift is a training-time augmentation only; applying it to
    # the test split would make the reported accuracy non-deterministic.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    DATA_PATH = './data'
    BATCH_SIZE = 128
    train_loader = DataLoader(
        dataset=MNIST(root=DATA_PATH, download=True, train=True, transform=train_transform),
        batch_size=BATCH_SIZE,
        num_workers=4,
        shuffle=True)
    test_loader = DataLoader(
        dataset=MNIST(root=DATA_PATH, download=True, train=False, transform=test_transform),
        batch_size=BATCH_SIZE,
        num_workers=4,
        shuffle=False)

    # Lookup table for categorical (one-hot) encoding of the 10 digit classes;
    # built once instead of once per batch.
    one_hot = torch.eye(10)

    # Train
    EPOCHES = 50
    model.train()
    for ep in range(EPOCHES):
        correct, total, total_loss = 0, 0, 0.
        for batch_id, (images, targets) in enumerate(train_loader, start=1):
            optimizer.zero_grad()
            images = images.to(device)
            labels = one_hot.index_select(dim=0, index=targets).to(device)
            logits, reconstruction = model(images)

            # Compute loss & accuracy
            loss = criterion(images, labels, logits, reconstruction)
            correct += torch.sum(
                torch.argmax(logits, dim=1) == torch.argmax(labels, dim=1)).item()
            total += len(labels)
            accuracy = correct / total
            # Accumulate a Python float: adding the tensor itself would keep
            # every batch's autograd graph alive for the whole epoch.
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            print('Epoch {}, batch {}, loss: {}, accuracy: {}'.format(ep + 1,
                                                                      batch_id,
                                                                      total_loss / batch_id,
                                                                      accuracy))
        # Passing the epoch index to step() is deprecated; the argument-free
        # call decays the learning rate once per finished epoch.
        scheduler.step()
        print('Total loss for epoch {}: {}'.format(ep + 1, total_loss))

    # Eval
    model.eval()
    correct, total = 0, 0
    # No gradients are needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for images, targets in test_loader:
            images = images.to(device)
            logits, _ = model(images)
            pred_labels = torch.argmax(logits, dim=1)
            correct += torch.sum(pred_labels == targets.to(device)).item()
            total += len(targets)
    print('Accuracy: {}'.format(correct / total))

    # Save model (create the target directory first so that a missing folder
    # does not throw away a finished training run).
    os.makedirs('./model', exist_ok=True)
    torch.save(model.state_dict(), './model/capsnet_ep{}_acc{}.pt'.format(EPOCHES, correct / total))


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /capsnet.py:
# --------------------------------------------------------------------------------
import torch
from torch import nn

# Available device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def squash(x, dim=-1):
    """Apply the capsule "squashing" non-linearity along ``dim``.

    Each vector along ``dim`` keeps its direction while its length ``n`` is
    mapped to ``n^2 / (1 + n^2)``, i.e. into the range [0, 1).

    Args:
        x: Input tensor.
        dim: Axis holding the capsule vector components.

    Returns:
        Tensor of the same shape as ``x``.
    """
    squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    # The small epsilon avoids a division by zero for all-zero vectors.
    return scale * x / (squared_norm.sqrt() + 1e-8)


class PrimaryCaps(nn.Module):
    """Primary capsule layer."""

    def __init__(self, num_conv_units, in_channels, out_channels, kernel_size, stride):
        """
        Initialize the layer.

        Args:
            num_conv_units: Number of capsule channels produced by the convolution.
            in_channels: Number of channels of the input feature map.
            out_channels: Dimensionality of each primary capsule vector.
            kernel_size: Kernel size of the convolution.
            stride: Stride of the convolution.
        """
        super(PrimaryCaps, self).__init__()

        # Each conv unit stands for a single capsule.
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels * num_conv_units,
                              kernel_size=kernel_size,
                              stride=stride)
        self.out_channels = out_channels

    def forward(self, x):
        """Convolve ``x`` and return squashed capsule vectors.

        Returns a tensor of shape (batch_size, num_capsules, out_channels).
        """
        # Shape of x: (batch_size, in_channels, height, width)
        # Shape of out: (batch_size, out_channels * num_conv_units, height, width)
        out = self.conv(x)
        # Flatten out: (batch_size, out_capsules * height * width, out_channels)
        # NOTE(review): this view groups `out_channels` consecutive elements of
        # the flattened (channel, height, width) output, i.e. neighbouring
        # spatial positions, not the responses of `out_channels` conv units at
        # one position. Kept as-is so existing checkpoints behave identically;
        # confirm whether this layout is intended.
        batch_size = out.shape[0]
        return squash(out.contiguous().view(batch_size, -1, self.out_channels), dim=-1)


class DigitCaps(nn.Module):
    """Digit capsule layer."""

    def __init__(self, in_dim, in_caps, out_caps, out_dim, num_routing):
        """
        Initialize the layer.

        Args:
            in_dim: Dimensionality of each input capsule vector.
            in_caps: Number of input capsules.
            out_caps: Number of capsules in this layer.
            out_dim: Dimensionality of each output capsule vector.
            num_routing: Number of iterations of the routing algorithm.
        """
        super(DigitCaps, self).__init__()
        self.in_dim = in_dim
        self.in_caps = in_caps
        self.out_caps = out_caps
        self.out_dim = out_dim
        self.num_routing = num_routing
        # Kept for backward compatibility; forward() follows the device of its
        # input instead of this module-level default.
        self.device = device
        self.W = nn.Parameter(0.01 * torch.randn(1, out_caps, in_caps, out_dim, in_dim),
                              requires_grad=True)

    def forward(self, x):
        """Route input capsules to output capsules by agreement.

        Args:
            x: Tensor of shape (batch_size, in_caps, in_dim).

        Returns:
            Tensor of shape (batch_size, out_caps, out_dim).
        """
        batch_size = x.size(0)
        # (batch_size, in_caps, in_dim) -> (batch_size, 1, in_caps, in_dim, 1)
        x = x.unsqueeze(1).unsqueeze(4)
        # W @ x =
        # (1, out_caps, in_caps, out_dim, in_dim) @ (batch_size, 1, in_caps, in_dim, 1) =
        # (batch_size, out_caps, in_caps, out_dims, 1)
        u_hat = torch.matmul(self.W, x)
        # (batch_size, out_caps, in_caps, out_dim)
        u_hat = u_hat.squeeze(-1)
        # detach u_hat during routing iterations to prevent gradients from flowing
        temp_u_hat = u_hat.detach()

        # Routing logits live on the same device and use the same dtype as the
        # predictions, so the layer works wherever the model has been moved.
        b = temp_u_hat.new_zeros(batch_size, self.out_caps, self.in_caps, 1)

        for route_iter in range(self.num_routing - 1):
            # (batch_size, out_caps, in_caps, 1) -> Softmax along out_caps
            c = b.softmax(dim=1)

            # element-wise multiplication
            # (batch_size, out_caps, in_caps, 1) * (batch_size, out_caps, in_caps, out_dim) ->
            # (batch_size, out_caps, in_caps, out_dim) sum across in_caps ->
            # (batch_size, out_caps, out_dim)
            s = (c * temp_u_hat).sum(dim=2)
            # apply "squashing" non-linearity along out_dim
            v = squash(s)
            # dot product agreement between the current output vj and the prediction uj|i
            # (batch_size, out_caps, in_caps, out_dim) @ (batch_size, out_caps, out_dim, 1)
            # -> (batch_size, out_caps, in_caps, 1)
            uv = torch.matmul(temp_u_hat, v.unsqueeze(-1))
            b += uv

        # last iteration is done on the original u_hat, without the routing weights update
        c = b.softmax(dim=1)
        s = (c * u_hat).sum(dim=2)
        # apply "squashing" non-linearity along out_dim
        v = squash(s)

        return v


class CapsNet(nn.Module):
    """Basic implementation of capsule network layer."""

    def __init__(self):
        super(CapsNet, self).__init__()

        # Conv2d layer
        self.conv = nn.Conv2d(1, 256, 9)
        self.relu = nn.ReLU(inplace=True)

        # Primary capsule
        self.primary_caps = PrimaryCaps(num_conv_units=32,
                                        in_channels=256,
                                        out_channels=8,
                                        kernel_size=9,
                                        stride=2)

        # Digit capsule
        self.digit_caps = DigitCaps(in_dim=8,
                                    in_caps=32 * 6 * 6,
                                    out_caps=10,
                                    out_dim=16,
                                    num_routing=3)

        # Reconstruction layer
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid())

    def forward(self, x):
        """Classify a batch of images and reconstruct them from the winning capsule.

        Args:
            x: Image batch; the layer sizes above are chosen for
                (batch_size, 1, 28, 28) inputs.

        Returns:
            Tuple ``(logits, reconstruction)``: ``logits`` holds the capsule
            lengths with shape (batch_size, out_capsules), ``reconstruction``
            has shape (batch_size, 784).
        """
        out = self.relu(self.conv(x))
        out = self.primary_caps(out)
        out = self.digit_caps(out)

        # Shape of logits: (batch_size, out_capsules)
        logits = torch.norm(out, dim=-1)
        # One-hot mask of the predicted class, created on the same device and
        # with the same dtype as the logits rather than on a module-level
        # default device.
        identity = torch.eye(logits.size(1), device=logits.device, dtype=logits.dtype)
        pred = identity.index_select(dim=0, index=torch.argmax(logits, dim=1))

        # Reconstruction: zero out every capsule except the predicted one.
        batch_size = out.shape[0]
        reconstruction = self.decoder((out * pred.unsqueeze(2)).contiguous().view(batch_size, -1))

        return logits, reconstruction


class CapsuleLoss(nn.Module):
    """Combine margin loss & reconstruction loss of capsule network."""

    def __init__(self, upper_bound=0.9, lower_bound=0.1, lmda=0.5):
        """
        Initialize the loss.

        Args:
            upper_bound: Capsule length above which the true class is not penalized.
            lower_bound: Capsule length below which a wrong class is not penalized.
            lmda: Weight of the penalty for wrong classes.
        """
        super(CapsuleLoss, self).__init__()
        self.upper = upper_bound
        self.lower = lower_bound
        self.lmda = lmda
        self.reconstruction_loss_scalar = 5e-4
        self.mse = nn.MSELoss(reduction='sum')

    def forward(self, images, labels, logits, reconstructions):
        """Return margin loss plus scaled reconstruction loss, summed over the batch.

        Args:
            images: Input images the reconstructions are compared against.
            labels: One-hot labels of shape (batch_size, num_classes).
            logits: Capsule lengths of shape (batch_size, num_classes).
            reconstructions: Decoder output; reshaped to ``images.shape``.
        """
        # Shape of left / right / labels: (batch_size, num_classes)
        # Penalty when the capsule of the true class is too short (missed class).
        left = (self.upper - logits).relu() ** 2
        # Penalty when the capsule of a wrong class is too long (false positive).
        right = (logits - self.lower).relu() ** 2
        margin_loss = torch.sum(labels * left) + self.lmda * torch.sum((1 - labels) * right)

        # Reconstruction loss
        reconstruction_loss = self.mse(reconstructions.contiguous().view(images.shape), images)

        # Combine two losses
        return margin_loss + self.reconstruction_loss_scalar * reconstruction_loss

# --------------------------------------------------------------------------------