├── LICENSE ├── README.md ├── capsnet.py ├── data_loader.py ├── result.jpg └── test_capsnet.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 jindongwang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Pytorch-CapsuleNet 2 | 3 | A flexible and easy-to-follow PyTorch implementation of Hinton's Capsule Network. 4 | 5 | There are already many repos containing code for CapsNet. However, most of them are hard to customize. Moreover, the original paper reports results mainly on *MNIST*, with only brief experiments on other datasets. We clearly want to do more. 6 | 7 | This repo is designed to hold other datasets and configurations.
And the most important thing is, we want to make the code **flexible**. Then, we can *tailor* the network according to our needs. 8 | 9 | Currently, the code supports both **MNIST and CIFAR-10** datasets. 10 | 11 | ## Requirements 12 | 13 | - Python 3.x 14 | - Pytorch 0.3.0 or above 15 | - Numpy 16 | - tqdm (to make display better, of course you can replace it with 'print') 17 | 18 | ## Run 19 | 20 | Just run `Python test_capsnet.py` in your terminal. That's all. If you want to change the dataset (MNIST or CIFAR-10), you can easily set the `dataset` variable. 21 | 22 | It is better to run the code on a server with GPUs. Capsule network demands good computing devices. For instance, on my device (Nvidia K80), it will take about 5 minutes for one epoch of the MNIST datasets (batch size = 100). 23 | 24 | ## More details 25 | 26 | There are 3 `.py` files: 27 | - `capsnet.py`: the main class for capsule network 28 | - `data_loader.py`: the class to hold many classes 29 | - `test_capsnet.py`: the training and testing code 30 | 31 | The results on your device may look like the following picture: 32 | 33 | ![](https://raw.githubusercontent.com/jindongwang/Pytorch-CapsuleNet/master/result.jpg) 34 | 35 | ## Acknowledgements 36 | 37 | - [Capsule-Network-Tutorial](https://github.com/higgsfield/Capsule-Network-Tutorial) 38 | - The original paper of Capsule Net by Geoffrey Hinton: [Dynamic routing between capsules](http://papers.nips.cc/paper/6975-dynamic-routing-between-capsules) 39 | -------------------------------------------------------------------------------- /capsnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | USE_CUDA = True if torch.cuda.is_available() else False 7 | 8 | 9 | class ConvLayer(nn.Module): 10 | def __init__(self, in_channels=1, out_channels=256, kernel_size=9): 11 | super(ConvLayer, 
self).__init__() 12 | 13 | self.conv = nn.Conv2d(in_channels=in_channels, 14 | out_channels=out_channels, 15 | kernel_size=kernel_size, 16 | stride=1 17 | ) 18 | 19 | def forward(self, x): 20 | return F.relu(self.conv(x)) 21 | 22 | 23 | class PrimaryCaps(nn.Module): 24 | def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6): 25 | super(PrimaryCaps, self).__init__() 26 | self.num_routes = num_routes 27 | self.capsules = nn.ModuleList([ 28 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=0) 29 | for _ in range(num_capsules)]) 30 | 31 | def forward(self, x): 32 | u = [capsule(x) for capsule in self.capsules] 33 | u = torch.stack(u, dim=1) 34 | u = u.view(x.size(0), self.num_routes, -1) 35 | return self.squash(u) 36 | 37 | def squash(self, input_tensor): 38 | squared_norm = (input_tensor ** 2).sum(-1, keepdim=True) 39 | output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm)) 40 | return output_tensor 41 | 42 | 43 | class DigitCaps(nn.Module): 44 | def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16): 45 | super(DigitCaps, self).__init__() 46 | 47 | self.in_channels = in_channels 48 | self.num_routes = num_routes 49 | self.num_capsules = num_capsules 50 | 51 | self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels)) 52 | 53 | def forward(self, x): 54 | batch_size = x.size(0) 55 | x = torch.stack([x] * self.num_capsules, dim=2).unsqueeze(4) 56 | 57 | W = torch.cat([self.W] * batch_size, dim=0) 58 | u_hat = torch.matmul(W, x) 59 | 60 | b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1)) 61 | if USE_CUDA: 62 | b_ij = b_ij.cuda() 63 | 64 | num_iterations = 3 65 | for iteration in range(num_iterations): 66 | c_ij = F.softmax(b_ij, dim=1) 67 | c_ij = torch.cat([c_ij] * batch_size, dim=0).unsqueeze(4) 68 | 69 | s_j = (c_ij * 
u_hat).sum(dim=1, keepdim=True) 70 | v_j = self.squash(s_j) 71 | 72 | if iteration < num_iterations - 1: 73 | a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim=1)) 74 | b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True) 75 | 76 | return v_j.squeeze(1) 77 | 78 | def squash(self, input_tensor): 79 | squared_norm = (input_tensor ** 2).sum(-1, keepdim=True) 80 | output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm)) 81 | return output_tensor 82 | 83 | 84 | class Decoder(nn.Module): 85 | def __init__(self, input_width=28, input_height=28, input_channel=1): 86 | super(Decoder, self).__init__() 87 | self.input_width = input_width 88 | self.input_height = input_height 89 | self.input_channel = input_channel 90 | self.reconstraction_layers = nn.Sequential( 91 | nn.Linear(16 * 10, 512), 92 | nn.ReLU(inplace=True), 93 | nn.Linear(512, 1024), 94 | nn.ReLU(inplace=True), 95 | nn.Linear(1024, self.input_height * self.input_width * self.input_channel), 96 | nn.Sigmoid() 97 | ) 98 | 99 | def forward(self, x, data): 100 | classes = torch.sqrt((x ** 2).sum(2)) 101 | classes = F.softmax(classes, dim=0) 102 | 103 | _, max_length_indices = classes.max(dim=1) 104 | masked = Variable(torch.sparse.torch.eye(10)) 105 | if USE_CUDA: 106 | masked = masked.cuda() 107 | masked = masked.index_select(dim=0, index=Variable(max_length_indices.squeeze(1).data)) 108 | t = (x * masked[:, :, None, None]).view(x.size(0), -1) 109 | reconstructions = self.reconstraction_layers(t) 110 | reconstructions = reconstructions.view(-1, self.input_channel, self.input_width, self.input_height) 111 | return reconstructions, masked 112 | 113 | 114 | class CapsNet(nn.Module): 115 | def __init__(self, config=None): 116 | super(CapsNet, self).__init__() 117 | if config: 118 | self.conv_layer = ConvLayer(config.cnn_in_channels, config.cnn_out_channels, config.cnn_kernel_size) 119 | self.primary_capsules = PrimaryCaps(config.pc_num_capsules, 
config.pc_in_channels, config.pc_out_channels, 120 | config.pc_kernel_size, config.pc_num_routes) 121 | self.digit_capsules = DigitCaps(config.dc_num_capsules, config.dc_num_routes, config.dc_in_channels, 122 | config.dc_out_channels) 123 | self.decoder = Decoder(config.input_width, config.input_height, config.cnn_in_channels) 124 | else: 125 | self.conv_layer = ConvLayer() 126 | self.primary_capsules = PrimaryCaps() 127 | self.digit_capsules = DigitCaps() 128 | self.decoder = Decoder() 129 | 130 | self.mse_loss = nn.MSELoss() 131 | 132 | def forward(self, data): 133 | output = self.digit_capsules(self.primary_capsules(self.conv_layer(data))) 134 | reconstructions, masked = self.decoder(output, data) 135 | return output, reconstructions, masked 136 | 137 | def loss(self, data, x, target, reconstructions): 138 | return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions) 139 | 140 | def margin_loss(self, x, labels, size_average=True): 141 | batch_size = x.size(0) 142 | 143 | v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True)) 144 | 145 | left = F.relu(0.9 - v_c).view(batch_size, -1) 146 | right = F.relu(v_c - 0.1).view(batch_size, -1) 147 | 148 | loss = labels * left + 0.5 * (1.0 - labels) * right 149 | loss = loss.sum(dim=1).mean() 150 | 151 | return loss 152 | 153 | def reconstruction_loss(self, data, reconstructions): 154 | loss = self.mse_loss(reconstructions.view(reconstructions.size(0), -1), data.view(reconstructions.size(0), -1)) 155 | return loss * 0.0005 156 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | 5 | class Dataset: 6 | def __init__(self, dataset, _batch_size): 7 | super(Dataset, self).__init__() 8 | if dataset == 'mnist': 9 | dataset_transform = transforms.Compose([ 10 | transforms.ToTensor(), 11 | 
transforms.Normalize((0.1307,), (0.3081,)) 12 | ]) 13 | 14 | train_dataset = datasets.MNIST('/data/mnist', train=True, download=True, 15 | transform=dataset_transform) 16 | test_dataset = datasets.MNIST('/data/mnist', train=False, download=True, 17 | transform=dataset_transform) 18 | 19 | self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=_batch_size, shuffle=True) 20 | self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=_batch_size, shuffle=False) 21 | 22 | elif dataset == 'cifar10': 23 | data_transform = transforms.Compose([ 24 | transforms.ToTensor(), 25 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 26 | ]) 27 | train_dataset = datasets.CIFAR10( 28 | '/data/cifar', train=True, download=True, transform=data_transform) 29 | test_dataset = datasets.CIFAR10( 30 | '/data/cifar', train=False, download=True, transform=data_transform) 31 | 32 | self.train_loader = torch.utils.data.DataLoader( 33 | train_dataset, batch_size=_batch_size, shuffle=True) 34 | 35 | self.test_loader = torch.utils.data.DataLoader( 36 | test_dataset, batch_size=_batch_size, shuffle=False) 37 | elif dataset == 'office-caltech': 38 | pass 39 | elif dataset == 'office31': 40 | pass 41 | -------------------------------------------------------------------------------- /result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jindongwang/Pytorch-CapsuleNet/c5031b719dcd2e67bbc6ed4d1557af4362c1d9d0/result.jpg -------------------------------------------------------------------------------- /test_capsnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | from torchvision import datasets, transforms 7 | from capsnet import CapsNet 8 | from data_loader import Dataset 9 | from tqdm import tqdm 10 | 11 | 
USE_CUDA = torch.cuda.is_available()
BATCH_SIZE = 100
N_EPOCHS = 30
# NOTE: LEARNING_RATE and MOMENTUM are not used below; the optimizer is Adam
# with its default learning rate.
LEARNING_RATE = 0.01
MOMENTUM = 0.9
NUM_CLASSES = 10


class Config:
    """Hyper-parameters of the capsule network for a given dataset.

    Args:
        dataset: ``'mnist'`` or ``'cifar10'``. ``'your own dataset'`` is a
            template branch that sets no attributes.

    Raises:
        ValueError: if ``dataset`` is not one of the names above.
    """

    def __init__(self, dataset='mnist'):
        if dataset == 'mnist':
            # CNN (cnn)
            self.cnn_in_channels = 1
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 9

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 9
            self.pc_num_routes = 32 * 6 * 6

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 6 * 6
            self.dc_in_channels = 8
            self.dc_out_channels = 16

            # Decoder
            self.input_width = 28
            self.input_height = 28

        elif dataset == 'cifar10':
            # CNN (cnn)
            self.cnn_in_channels = 3
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 9

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 9
            self.pc_num_routes = 32 * 8 * 8

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 8 * 8
            self.dc_in_channels = 8
            self.dc_out_channels = 16

            # Decoder
            self.input_width = 32
            self.input_height = 32

        elif dataset == 'your own dataset':
            # Template: set the attributes above for a custom dataset.
            pass

        else:
            raise ValueError("Unknown dataset: {!r}".format(dataset))


def to_one_hot(labels, num_classes=NUM_CLASSES):
    """Return a float one-hot matrix of shape (len(labels), num_classes)."""
    return torch.eye(num_classes, device=labels.device).index_select(dim=0, index=labels)


def count_correct(masked, target):
    """Count rows whose predicted one-hot ``masked`` matches one-hot ``target``."""
    return int((masked.argmax(dim=1) == target.argmax(dim=1)).sum().item())


def train(model, optimizer, train_loader, epoch):
    """Run one training epoch of ``model``, logging every 100 batches.

    Batches are moved to the device of the model's parameters.
    """
    model.train()
    device = next(model.parameters()).device
    # len(DataLoader) is the number of batches; no need to iterate the data.
    n_batch = len(train_loader)
    total_loss = 0.0
    for batch_id, (data, target) in enumerate(tqdm(train_loader)):
        target = to_one_hot(target)
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output, reconstructions, masked = model(data)
        loss = model.loss(data, output, target, reconstructions)
        loss.backward()
        optimizer.step()

        # The loss is already averaged over the batch.
        train_loss = loss.item()
        total_loss += train_loss
        if batch_id % 100 == 0:
            correct = count_correct(masked, target)
            tqdm.write("Epoch: [{}/{}], Batch: [{}/{}], train accuracy: {:.6f}, loss: {:.6f}".format(
                epoch,
                N_EPOCHS,
                batch_id + 1,
                n_batch,
                # The last batch may be smaller than BATCH_SIZE.
                correct / float(data.size(0)),
                train_loss
            ))
    # Mean of the per-batch losses, matching how test() reports its loss.
    tqdm.write('Epoch: [{}/{}], train loss: {:.6f}'.format(epoch, N_EPOCHS, total_loss / max(n_batch, 1)))


def test(capsule_net, test_loader, epoch):
    """Evaluate ``capsule_net`` on ``test_loader`` and log accuracy and mean batch loss."""
    capsule_net.eval()
    device = next(capsule_net.parameters()).device
    test_loss = 0.0
    correct = 0
    # No gradients are needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for data, target in test_loader:
            target = to_one_hot(target)
            data, target = data.to(device), target.to(device)

            output, reconstructions, masked = capsule_net(data)
            loss = capsule_net.loss(data, output, target, reconstructions)

            test_loss += loss.item()
            correct += count_correct(masked, target)

    tqdm.write(
        "Epoch: [{}/{}], test accuracy: {:.6f}, loss: {:.6f}".format(epoch, N_EPOCHS, correct / len(test_loader.dataset),
                                                                  test_loss / len(test_loader)))


# This script is named test_*.py. Without this flag, pytest would collect
# `test` as a test function and fail looking for fixtures named after its
# parameters.
test.__test__ = False


if __name__ == '__main__':
    torch.manual_seed(1)
    dataset = 'cifar10'
    # dataset = 'mnist'
    config = Config(dataset)
    data = Dataset(dataset, BATCH_SIZE)

    # The model is used unwrapped: train()/test() call CapsNet.loss directly,
    # which an nn.DataParallel wrapper would not expose.
    capsule_net = CapsNet(config)
    if USE_CUDA:
        capsule_net = capsule_net.cuda()

    optimizer = torch.optim.Adam(capsule_net.parameters())

    for e in range(1, N_EPOCHS + 1):
        train(capsule_net, optimizer, data.train_loader, e)
        test(capsule_net, data.test_loader, e)