├── LICENSE
├── README.md
└── capsnet.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Nishant Nikhil
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CapsNet-PyTorch
2 | My attempt at implementing CapsNet from the paper *Dynamic Routing Between Capsules*.
3 | [Link to paper](https://arxiv.org/abs/1710.09829)
4 | Authors of the paper: Sara Sabour, Nicholas Frosst, Geoffrey E. Hinton
5 |
6 |
7 | The code is still buggy; the reconstruction loss from the paper has not been added yet.
8 | Training and testing have not been run end to end.
9 | Suggestions and contributions are welcome.
10 |
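11 | To try the current training loop (it downloads MNIST into `../data`), run `python capsnet.py`.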
--------------------------------------------------------------------------------
/capsnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Hyperparameters. batch_size is fixed at 1 because the forward pass below
# reshapes the conv outputs assuming a single example.
batch_size = 1
test_batch_size = 1
epochs = 10
lr = 0.01
momentum = 0.5
no_cuda = True
seed = 1
log_interval = 10

cuda = not no_cuda and torch.cuda.is_available()

torch.manual_seed(seed)

if cuda:
    torch.cuda.manual_seed(seed)


kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=test_batch_size, shuffle=True, **kwargs)

class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        # Conv1: 256 feature maps with 9x9 kernels (28x28 -> 20x20).
        self.conv1 = nn.Conv2d(1, 256, 9)
        # PrimaryCaps: 32 capsule maps, each a conv producing 8 channels at
        # stride 2 (20x20 -> 6x6), i.e. 32 x 6 x 6 capsules of dimension 8.
        conv_caps = [nn.Conv2d(256, 8, 9, stride=2) for i in range(32)]
        self.conv_caps = nn.ModuleList(conv_caps)
        # One 8 -> 16 linear map per primary capsule, used for all 10 digit
        # capsules (note: the paper uses a separate W_ij per digit capsule j).
        self.weight_matrices = nn.ModuleList(
            [nn.ModuleList([nn.ModuleList([nn.Linear(8, 16) for i in range(6)])
                            for i in range(6)]) for i in range(32)])
        # Routing logits b_ij; re-initialised to zero at the start of every forward pass.
        self.bij = Variable(torch.FloatTensor(32, 6, 6, 10).zero_())

    def forward(self, x):
        x = F.relu(self.conv1(x))
        # PrimaryCaps: each conv gives (1, 8, 6, 6); with batch size 1 this is
        # reshaped to a 6x6 grid of 8-D capsule vectors.
        prim_caps_layer = [self.conv_caps[i](x).view(8, 6, 6).permute(1, 2, 0) for i in range(32)]
        # Squash (Eq. 1) every 8-D primary-capsule vector.
        squashed = []
        for k in range(len(prim_caps_layer)):
            rows = []
            for i in range(prim_caps_layer[k].size(0)):
                rows.append(torch.stack([self.non_linearity(prim_caps_layer[k][i, j])
                                         for j in range(prim_caps_layer[k].size(1))]))
            squashed.append(torch.stack(rows))
        u = torch.stack(squashed)                       # (32, 6, 6, 8)
        # Prediction vectors u_hat_{j|i} = W_ij u_i, assembled with torch.stack
        # so the graph back to the weight matrices is kept without in-place writes.
        out = torch.stack([self.weight_matrices[i][j][k](u[i, j, k])
                           for i in range(32) for j in range(6) for k in range(6)]).view(32, 6, 6, 16)
        # Dynamic routing by agreement (Procedure 1 in the paper, which uses 3 iterations).
        # The routing logits b_ij are reset to zero for every forward pass.
        self.bij = Variable(torch.FloatTensor(32, 6, 6, 10).zero_())
        for iteration in range(3):
            si = Variable(torch.FloatTensor(10, 16).zero_())    # s_j, weighted sums of predictions
            for i in range(32):
                for j in range(6):
                    for k in range(6):
                        # Coupling coefficients c_ij: softmax of b_ij over the 10 digit capsules.
                        ci = F.softmax(self.bij[i, j, k].clone())
                        for m in range(10):
                            si[m] = si[m].clone() + ci[m].clone() * out[i, j, k].clone()
            # v_j = squash(s_j)
            for i in range(10):
                si[i] = self.non_linearity(si[i].clone())
            # Agreement update: b_ij <- b_ij + u_hat_{j|i} . v_j
            for i in range(32):
                for j in range(6):
                    for k in range(6):
                        for m in range(10):
                            self.bij[i, j, k, m] = self.bij[i, j, k, m].clone() + si[m].dot(out[i, j, k].clone())
        # The length of each digit-capsule vector encodes the probability that the digit is present.
        norms = torch.stack([si[i].norm() for i in range(10)]).view(10)
        return norms, self.bij

    def non_linearity(self, vec):
        # Squashing function (Eq. 1): v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||),
        # which keeps the direction and maps the length into [0, 1).
        nm = vec.norm()
        nm2 = nm ** 2
        vec = vec * nm2 / ((1 + nm2) * nm)
        return vec

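# Quick sanity check of the squashing function (an illustration added here, not part
# of the original script): a capsule vector of norm 1 squashes to norm
# 1^2 / (1 + 1^2) = 0.5 with its direction unchanged, e.g.
#   CapsNet().non_linearity(Variable(torch.FloatTensor([1, 0, 0, 0, 0, 0, 0, 0])))
#   -> [0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
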
model = CapsNet()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
model.train()

# Margin-loss constants from Eq. 4 of the paper: m+ = 0.9, m- = 0.1, lambda = 0.5.
point_nine = Variable(torch.FloatTensor([0.9]))   # m+
point_one = Variable(torch.FloatTensor([0.1]))    # m-
point_five = Variable(torch.FloatTensor([0.5]))   # lambda, down-weights the loss for absent digits

# One pass over the training set with the margin loss (Eq. 4). The reconstruction
# loss from the paper is not included yet; a sketch is given at the end of this file.
for batch_idx, (data, target) in enumerate(train_loader):
    data = Variable(data)
    target = target[0]                      # batch_size is 1, so take the single label
    optimizer.zero_grad()
    norms, _ = model(data)
    total_loss = 0
    for i in range(10):
        if i == target:
            # T_k = 1: penalise the correct capsule for being shorter than m+.
            loss = torch.max(Variable(torch.zeros(1)), point_nine - norms[i]) ** 2
        else:
            # T_k = 0: penalise absent capsules for being longer than m-.
            loss = point_five * torch.max(Variable(torch.zeros(1)), norms[i] - point_one) ** 2
        total_loss = total_loss + loss
    # A single backward pass and optimizer step on the summed loss per example.
    total_loss.backward()
    optimizer.step()
    if batch_idx % log_interval == 0:
        print(batch_idx, total_loss.data[0])

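# ---------------------------------------------------------------------------
# The reconstruction loss from Section 4.1 of the paper is still missing (as
# noted in the README). The sketch below is one way it could look; it is NOT
# wired into the training loop above, and the class/function names are our own.
# The decoder receives the 10 x 16 digit-capsule vectors with every capsule
# except the true class masked to zero, regresses the 784 pixel values, and the
# summed squared error is scaled by 0.0005 so it does not dominate the margin
# loss. Using it would require CapsNet.forward to also return the squashed
# digit-capsule vectors (si above); note too that the data loaders normalise
# the images, so the raw [0, 1] pixels would be the proper reconstruction target.
# ---------------------------------------------------------------------------
class ReconstructionDecoder(nn.Module):
    def __init__(self):
        super(ReconstructionDecoder, self).__init__()
        self.fc1 = nn.Linear(16 * 10, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 784)

    def forward(self, masked_caps):
        h = F.relu(self.fc1(masked_caps))
        h = F.relu(self.fc2(h))
        return F.sigmoid(self.fc3(h))          # 784 reconstructed pixel intensities


def reconstruction_loss(decoder, digit_caps, target, image, scale=0.0005):
    # digit_caps: (10, 16) squashed digit-capsule vectors; image: the input digit.
    mask = torch.zeros(10, 16)
    mask[target] = 1.0                         # keep only the true class's capsule
    reconstruction = decoder((digit_caps * Variable(mask)).view(-1))
    return scale * ((reconstruction - image.view(-1)) ** 2).sum()
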
--------------------------------------------------------------------------------