├── LICENSE
├── README.md
└── capsnet.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 Nishant Nikhil

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# CapsNet-PyTorch

My attempt at implementing CapsNet from the paper *Dynamic Routing Between Capsules*.
[Link to paper](https://arxiv.org/abs/1710.09829)
Paper authors: Sara Sabour, Nicholas Frosst, Geoffrey E. Hinton
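
For reference, the squash non-linearity and the margin loss defined in the paper (the quantities that `non_linearity` and the training loop in `capsnet.py` aim to implement) are:

```latex
% Squash non-linearity applied to the raw capsule output s_j
v_j = \frac{\lVert s_j \rVert^2}{1 + \lVert s_j \rVert^2}
      \cdot \frac{s_j}{\lVert s_j \rVert}

% Margin loss for digit class k, with m^+ = 0.9, m^- = 0.1, \lambda = 0.5
L_k = T_k \max(0, m^+ - \lVert v_k \rVert)^2
      + \lambda (1 - T_k) \max(0, \lVert v_k \rVert - m^-)^2
```

During routing-by-agreement, the coupling coefficients `c_ij` are the softmax of the logits `b_ij`, and each logit is updated by adding the agreement (dot product) between a primary capsule's prediction vector and the digit capsule's output.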

The code is still buggy, and the reconstruction loss has not been added yet (a sketch of the paper's decoder is appended after `capsnet.py` below).
Training and testing have not been run yet.
Suggestions and contributions are welcome.
--------------------------------------------------------------------------------

/capsnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Hyperparameters
batch_size = 1
test_batch_size = 1
epochs = 10
lr = 0.01
momentum = 0.5
no_cuda = True
seed = 1
log_interval = 10

cuda = not no_cuda and torch.cuda.is_available()

torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

# MNIST loaders (images normalized with the usual MNIST mean/std)
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=test_batch_size, shuffle=True, **kwargs)


class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        # First convolution: 28x28x1 -> 20x20x256
        self.conv1 = nn.Conv2d(1, 256, 9)
        # Primary capsules: 32 convolutional capsule maps, each producing a
        # 6x6 grid of 8-dimensional capsules (20x20x256 -> 8x6x6 per map)
        conv_caps = [nn.Conv2d(256, 8, 9, stride=2) for i in range(32)]
        self.conv_caps = nn.ModuleList(conv_caps)
        # Transformation from each primary capsule (8-d) to a 16-d prediction
        # vector. Note: a single 8->16 map is shared across all 10 digit
        # capsules here, whereas the paper uses a separate W_ij per digit capsule.
        self.weight_matrices = nn.ModuleList(
            [nn.ModuleList(
                [nn.ModuleList([nn.Linear(8, 16) for i in range(6)])
                 for i in range(6)])
             for i in range(32)])
        # Routing logits b_ij: 32 capsule maps x 6x6 positions x 10 digit capsules
        self.bij = Variable(torch.FloatTensor(32, 6, 6, 10).zero_())

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # Primary capsules: one (6, 6, 8) grid per capsule map
        # (assumes batch_size == 1, since the batch dimension is dropped here)
        prim_caps_layer = [self.conv_caps[i](x).view(8, 6, 6).permute(1, 2, 0)
                           for i in range(32)]
        # Squash every 8-d primary capsule vector
        for k in range(len(prim_caps_layer)):
            for i in range(prim_caps_layer[k].size()[0]):
                for j in range(prim_caps_layer[k].size()[1]):
                    tmp = self.non_linearity(prim_caps_layer[k][i, j].clone())
                    prim_caps_layer[k][i, j] = tmp
        tmp = torch.stack(prim_caps_layer)
        # Prediction vectors u_hat: (32, 6, 6, 16)
        out = Variable(torch.FloatTensor(32, 6, 6, 16))
        for i in range(32):
            for j in range(6):
                for k in range(6):
                    t = self.weight_matrices[i][j][k](tmp[i, j, k].clone())
                    out[i, j, k] = t
        # Routing by agreement: reset the logits at the start of every
        # forward pass, as in the paper
        self.bij = Variable(torch.FloatTensor(32, 6, 6, 10).zero_())
        for loop in range(10):  # routing iterations (the paper uses 3)
            si = Variable(torch.FloatTensor(10, 16).zero_())
            for i in range(32):
                for j in range(6):
                    for k in range(6):
                        # Coupling coefficients: softmax of the logits over
                        # the 10 digit capsules
                        ci = F.softmax(self.bij[i, j, k].clone(), dim=0)
                        for m in range(10):
                            t = si[m].clone() + ci[m].clone() * out[i, j, k].clone()
                            si[m] = t
            # Squash each digit capsule
            for i in range(10):
                tmp = self.non_linearity(si[i].clone())
                si[i] = tmp
            # Update logits with the agreement u_hat . v
            for i in range(32):
                for j in range(6):
                    for k in range(6):
                        for m in range(10):
                            tmp = self.bij[i, j, k, m].clone() + si[m].dot(out[i, j, k].clone())
                            self.bij[i, j, k, m] = tmp
        # The length of each digit capsule encodes the class probability
        norms = Variable(torch.FloatTensor(10))
        for i in range(10):
            norms[i] = si[i].norm()
        return norms, self.bij

    def non_linearity(self, vec):
        # Squash: v = (|s|^2 / (1 + |s|^2)) * (s / |s|)
        nm = vec.norm()
        nm2 = nm ** 2
        vec = vec * nm2 / ((1 + nm2) * nm)
        return vec


model = CapsNet()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
model.train()

# Margin-loss constants from the paper: m+ = 0.9, m- = 0.1, lambda = 0.5
point_nine = Variable(torch.FloatTensor([0.9]))
point_one = Variable(torch.FloatTensor([0.1]))
point_five = Variable(torch.FloatTensor([0.5]))

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = Variable(data)
        target = target[0]
        optimizer.zero_grad()
        norms, _ = model(data)
        # Margin loss, summed over the 10 digit capsules
        total_loss = 0
        for i in range(10):
            if i == target:
                # Present class: push the capsule length above 0.9
                loss = torch.max(Variable(torch.zeros(1)), point_nine - norms[i]) ** 2
            else:
                # Absent class: push the capsule length below 0.1, down-weighted by 0.5
                loss = point_five * torch.max(Variable(torch.zeros(1)), norms[i] - point_one) ** 2
            total_loss = total_loss + loss
        # Single backward pass and parameter update per image
        total_loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('epoch {} batch {} loss {:.4f}'.format(epoch, batch_idx, total_loss.data[0]))
--------------------------------------------------------------------------------
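
The README notes that the reconstruction loss from the paper has not been added yet. As a starting point, below is a minimal sketch of the decoder the paper describes (fully connected layers of 512, 1024, and 784 units, with a sum-of-squares reconstruction error scaled by 0.0005). The names `Decoder` and `reconstruction_loss` are illustrative and not part of this repository; wiring this in would also require `forward` to return the 16-d digit-capsule vectors `si`, not just their norms.

```python
import torch.nn as nn
import torch.nn.functional as F


class Decoder(nn.Module):
    """Illustrative decoder from the paper: reconstructs the 28x28 input
    from the 16-d activity vector of the correct digit capsule."""

    def __init__(self):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(16, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 784)  # 28 * 28 pixels

    def forward(self, digit_capsule):
        h = F.relu(self.fc1(digit_capsule))
        h = F.relu(self.fc2(h))
        return F.sigmoid(self.fc3(h))  # pixel intensities in [0, 1]


def reconstruction_loss(reconstruction, image, scale=0.0005):
    # Sum-of-squares error between the reconstruction and the flattened
    # input image, scaled down (0.0005 in the paper) so it does not
    # dominate the margin loss during training.
    return scale * ((reconstruction - image.view(-1)) ** 2).sum()
```

The per-image training loss would then be the margin loss computed in `capsnet.py` plus this scaled reconstruction term.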