├── .gitignore
├── CapsNet.ipynb
├── LICENSE
├── README.md
└── images
    ├── capsulearch.png
    └── reconsArch.png
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # pytorch
104 | *.pth
105 | processed/
106 | raw/
107 |
--------------------------------------------------------------------------------
/CapsNet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Introduction"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "This is a [pytorch](http://pytorch.org/) implementation of CapsNet, described in the paper [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829) - by [Sara Sabour](https://arxiv.org/find/cs/1/au:+Sabour_S/0/1/0/all/0/1), [Nicholas Frosst](https://arxiv.org/find/cs/1/au:+Frosst_N/0/1/0/all/0/1) and [Geoffrey E Hinton](https://arxiv.org/find/cs/1/au:+Hinton_G/0/1/0/all/0/1).\n",
15 | "\n",
16 | "All images and text in the following sections are extracted directly from the paper."
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {},
22 | "source": [
23 | "# Import Dependencies"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": null,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "import numpy as np\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "import torch\n",
35 | "import torchvision\n",
36 | "import torch.nn.functional as F\n",
37 | "from torch.autograd import Variable\n",
38 | "from tqdm.auto import tqdm\n",
39 | "from collections import defaultdict"
40 | ]
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {},
45 | "source": [
46 | "# Load MNIST"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "Training is performed on 28 x 28 MNIST images that have been shifted by up to 2 pixels in each direction with zero padding. No other data augmentation/deformation is used."
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "INPUT_SIZE = (1, 28, 28)\n",
63 | "transforms = torchvision.transforms.Compose([\n",
64 | " torchvision.transforms.RandomCrop(INPUT_SIZE[1:], padding=2),\n",
65 | " torchvision.transforms.ToTensor(),\n",
66 | "])"
67 | ]
68 | },
69 | {
70 | "cell_type": "markdown",
71 | "metadata": {},
72 | "source": [
 73 |     "The dataset has 60K training images and 10K test images."
74 | ]
75 | },
76 | {
77 | "cell_type": "code",
78 | "execution_count": null,
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "trn_dataset = torchvision.datasets.MNIST('.', train=True, download=True, transform=transforms)\n",
83 | "tst_dataset = torchvision.datasets.MNIST('.', train=False, download=True, transform=transforms)\n",
84 | "print('Images for training: %d' % len(trn_dataset))\n",
85 | "print('Images for testing: %d' % len(tst_dataset))"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "BATCH_SIZE = 128 # Batch size not specified in the paper\n",
95 | "trn_loader = torch.utils.data.DataLoader(trn_dataset, BATCH_SIZE, shuffle=True)\n",
96 | "tst_loader = torch.utils.data.DataLoader(tst_dataset, BATCH_SIZE, shuffle=False)"
97 | ]
98 | },
99 | {
100 | "cell_type": "markdown",
101 | "metadata": {},
102 | "source": [
103 | "# Define CapsNet"
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "metadata": {},
109 | "source": [
110 | "## Conv1"
111 | ]
112 | },
113 | {
114 | "cell_type": "markdown",
115 | "metadata": {},
116 | "source": [
117 | "Conv1 has 256, 9 x 9 convolution kernels with a stride of 1 and ReLU activation. This layer converts pixel intensities to the activities of local feature detectors that are then used as inputs to the *primary* capsules."
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": null,
123 | "metadata": {},
124 | "outputs": [],
125 | "source": [
126 | "class Conv1(torch.nn.Module):\n",
127 | " def __init__(self, in_channels, out_channels=256, kernel_size=9):\n",
128 | " super(Conv1, self).__init__()\n",
129 | " self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size)\n",
130 | " self.activation = torch.nn.ReLU()\n",
131 | " \n",
132 | " def forward(self, x):\n",
133 | " x = self.conv(x)\n",
134 | " x = self.activation(x)\n",
135 | " return x"
136 | ]
137 | },
138 | {
139 | "cell_type": "markdown",
140 | "metadata": {},
141 | "source": [
142 | "## Primary Capsules"
143 | ]
144 | },
145 | {
146 | "cell_type": "markdown",
147 | "metadata": {},
148 | "source": [
149 | "The second layer (PrimaryCapsules) is a convolutional capsule layer with 32 channels of convolutional 8D capsules (*i.e.* each primary capsule contains 8 convolutional units with a $[9 \\times 9]$ kernel and a stride of 2). Each primary capsule output sees the outputs of all $[256 \\times 81]$ Conv1 units whose receptive fields overlap with the location of the center of the capsule. In total PrimaryCapsules has $[32 \\times 6 \\times 6]$ capsule outputs (each output is an 8D vector) and each capsule in the $[6 \\times 6]$ grid is sharing their weights with each other. One can see PrimaryCapsules as a Convolution layer with Eq. 1 as its block non-linearity."
150 | ]
151 | },
152 | {
153 | "cell_type": "code",
154 | "execution_count": null,
155 | "metadata": {},
156 | "outputs": [],
157 | "source": [
158 | "class PrimaryCapsules(torch.nn.Module):\n",
159 | " def __init__(self, input_shape=(256, 20, 20), capsule_dim=8,\n",
160 | " out_channels=32, kernel_size=9, stride=2):\n",
161 | " super(PrimaryCapsules, self).__init__()\n",
162 | " self.input_shape = input_shape\n",
163 | " self.capsule_dim = capsule_dim\n",
164 | " self.out_channels = out_channels\n",
165 | " self.kernel_size = kernel_size\n",
166 | " self.stride = stride\n",
167 | " self.in_channels = self.input_shape[0]\n",
168 | " \n",
169 | " self.conv = torch.nn.Conv2d(\n",
170 | " self.in_channels,\n",
171 | " self.out_channels * self.capsule_dim,\n",
172 | " self.kernel_size,\n",
173 | " self.stride\n",
174 | " )\n",
175 | " \n",
176 | " def forward(self, x):\n",
177 | " x = self.conv(x)\n",
178 | " x = x.permute(0, 2, 3, 1).contiguous()\n",
179 | " x = x.view(-1, x.size()[1], x.size()[2], self.out_channels, self.capsule_dim)\n",
180 | " return x"
181 | ]
182 | },
183 | {
184 | "cell_type": "markdown",
185 | "metadata": {},
186 | "source": [
187 | "## Routing"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "metadata": {},
193 | "source": [
194 | "We want the length of the output vector of a capsule to represent the probability that the entity represented by the capsule is present in the current input. We therefore use a non-linear \"squashing\" function to ensure that short vectors get shrunk to almost zero length and long vectors get shrunk to a length slightly below 1. We leave it to discriminative learning to make good use of this non-linearity.\n",
195 | "\n",
196 | "\\begin{equation*}\n",
197 | "\\mathbf{v}_j = \\frac{||\\mathbf{s}_j||^2}{1 + ||\\mathbf{s}_j||^2} \\frac{\\mathbf{s}_j}{||\\mathbf{s}_j||}\n",
198 | "\\end{equation*}\n",
199 | "\n",
200 | "where $\\mathbf{v}_j$ is the vector output of capsule $j$ and $\\mathbf{s}_j$ is its total input.\n",
201 | "\n",
202 | "For all but the first layer of capsules, the total input to a capsule $\\mathbf{s}_j$ is a weighted sum over all \"prediction vectors\" $\\mathbf{\\hat u}_{j|i}$ from the capsules in the layer below and is produced by multiplying the output $\\mathbf{u}_i$ of a capsule in the layer below by a weight matrix $\\mathbf{W}_{ij}$\n",
203 | "\n",
204 | "\\begin{equation*}\n",
205 | "\\mathbf{s}_j = \\sum_i c_{ij} \\mathbf{\\hat u}_{j|i}, \\quad \\mathbf{\\hat u}_{j|i} = \\mathbf{W}_{ij} \\mathbf{u}_i\n",
206 | "\\end{equation*}\n",
207 | "\n",
208 | "where the $c_{ij}$ are coupling coefficients that are determined by the iterative dynamic routing process.\n",
209 | "\n",
210 | "The coupling coefficients between capsule $i$ and all the capsules in the layer above sum to 1 and are determined by a \"routing softmax\" whose initial logits $b_{ij}$ are the log prior probabilities that capsule $i$ should be coupled to capsule $j$.\n",
211 | "\n",
212 | "\\begin{equation*}\n",
213 | "c_{ij} = \\frac{\\exp(b_{ij})}{\\sum_k \\exp(b_{ik})}\n",
214 | "\\end{equation*}\n",
215 | "\n",
216 | "The log priors can be learned discriminatively at the same time as all the other weights. They depend on the location and type of the two capsules but not on the current input image. The initial coupling coefficients are then iteratively refined by measuring the agreement between the current output $\\mathbf{v}_j$ of each capsule, $j$, in the layer above and the prediction $\\mathbf{\\hat u}_{j|i}$ made by capsule $i$.\n",
217 | "\n",
218 | "The agreement is simply the scalar product $a_{ij} = \\mathbf{v}_j \\cdot \\mathbf{\\hat u}_{j|i}$. This agreement is treated as if it was a log likelihood and is added to the initial logit, $b_{ij}$ before computing the new values for all the coupling coefficients linking capsule $i$ to higher level capsules.\n",
219 | "\n",
220 | "In convolutional capsule layers, each capsule outputs a local grid of vectors to each type of capsule in the layer above using different transformation matrices for each member of the grid as well as for each type of capsule."
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "class Routing(torch.nn.Module):\n",
230 | " def __init__(self, caps_dim_before=8, caps_dim_after=16,\n",
231 | " n_capsules_before=(6 * 6 * 32), n_capsules_after=10):\n",
232 | " super(Routing, self).__init__()\n",
233 | " self.n_capsules_before = n_capsules_before\n",
234 | " self.n_capsules_after = n_capsules_after\n",
235 | " self.caps_dim_before = caps_dim_before\n",
236 | " self.caps_dim_after = caps_dim_after\n",
237 | " \n",
238 | " # Parameter initialization not specified in the paper\n",
239 | " n_in = self.n_capsules_before * self.caps_dim_before\n",
240 | " variance = 2 / (n_in)\n",
241 | " std = np.sqrt(variance)\n",
242 | " self.W = torch.nn.Parameter(\n",
243 | " torch.randn(\n",
244 | " self.n_capsules_before,\n",
245 | " self.n_capsules_after,\n",
246 | " self.caps_dim_after,\n",
247 | " self.caps_dim_before) * std,\n",
248 | " requires_grad=True)\n",
249 | " \n",
250 | " # Equation (1)\n",
251 | " @staticmethod\n",
252 | " def squash(s):\n",
253 | " s_norm = torch.norm(s, p=2, dim=-1, keepdim=True)\n",
254 | " s_norm2 = torch.pow(s_norm, 2)\n",
255 | " v = (s_norm2 / (1.0 + s_norm2)) * (s / s_norm)\n",
256 | " return v\n",
257 | " \n",
258 | " # Equation (2)\n",
259 | " def affine(self, x):\n",
260 |     "        x = self.W @ x.unsqueeze(2).expand(-1, -1, self.n_capsules_after, -1).unsqueeze(-1)\n",
261 | " return x.squeeze()\n",
262 | " \n",
263 | " # Equation (3)\n",
264 | " @staticmethod\n",
265 | " def softmax(x, dim=-1):\n",
266 | " exp = torch.exp(x)\n",
267 | " return exp / torch.sum(exp, dim, keepdim=True)\n",
268 | " \n",
269 | " # Procedure 1 - Routing algorithm.\n",
270 | " def routing(self, u, r, l):\n",
271 | " b = Variable(torch.zeros(u.size()[0], l[0], l[1]), requires_grad=False).cuda() # torch.Size([?, 1152, 10])\n",
272 | " \n",
273 | " for iteration in range(r):\n",
274 | " c = Routing.softmax(b) # torch.Size([?, 1152, 10])\n",
275 |     "            s = (c.unsqueeze(-1).expand(-1, -1, -1, u.size()[-1]) * u).sum(1) # torch.Size([?, 10, 16])\n",
276 | " v = Routing.squash(s) # torch.Size([?, 10, 16])\n",
277 | " b += (u * v.unsqueeze(1).expand(-1, l[0], -1, -1)).sum(-1)\n",
278 | " return v\n",
279 | " \n",
280 | " def forward(self, x, n_routing_iter):\n",
281 | " x = x.view((-1, self.n_capsules_before, self.caps_dim_before))\n",
282 | " x = self.affine(x) # torch.Size([?, 1152, 10, 16])\n",
283 | " x = self.routing(x, n_routing_iter, (self.n_capsules_before, self.n_capsules_after))\n",
284 | " return x"
285 | ]
286 | },
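{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check (not from the paper) that the `squash` implementation behaves as Eq. 1 describes: a vector with a small norm is shrunk to almost zero length, while a vector with a large norm keeps its direction but is mapped to a length just below 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of Eq. 1 (not part of the original model code).\n",
"v_short = torch.ones(1, 8) * 0.01\n",
"v_long = torch.ones(1, 8) * 10.0\n",
"print(torch.norm(Routing.squash(v_short), p=2, dim=-1)) # close to 0\n",
"print(torch.norm(Routing.squash(v_long), p=2, dim=-1)) # slightly below 1"
]
},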
287 | {
288 | "cell_type": "markdown",
289 | "metadata": {},
290 | "source": [
291 | "The final Layer (DigitCaps) has one 16D capsule per digit class and each of these capsules receives input from all the capsules in the layer below.\n",
292 | "\n",
293 | "We have routing only between two consecutive capsule layers (e.g. PrimaryCapsules and DigitCaps).\n",
294 | "Since Conv1 output is 1D, there is no orientation in its space to agree on. Therefore, no routing is used between Conv1 and PrimaryCapsules. All the routing logits ($b_{ij}$) are initialized to zero. Therefore, initially a capsule output ($\\mathbf{u}_i$) is sent to all parent capsules ($\\mathbf{v}_0...\\mathbf{v}_9$) with equal probability ($c_{ij}$)."
295 | ]
296 | },
297 | {
298 | "cell_type": "markdown",
299 | "metadata": {},
300 | "source": [
301 | "## Norm"
302 | ]
303 | },
304 | {
305 | "cell_type": "markdown",
306 | "metadata": {},
307 | "source": [
308 | "We are using the length of the instantiation vector to represent the probability that a capsule’s entity exists. We would like the top-level capsule for digit class $k$ to have a long instantiation vector if and only if that digit is present in the image."
309 | ]
310 | },
311 | {
312 | "cell_type": "code",
313 | "execution_count": null,
314 | "metadata": {},
315 | "outputs": [],
316 | "source": [
317 | "class Norm(torch.nn.Module):\n",
318 | " def __init__(self):\n",
319 | " super(Norm, self).__init__()\n",
320 | " \n",
321 | " def forward(self, x):\n",
322 | " x = torch.norm(x, p=2, dim=-1)\n",
323 | " return x"
324 | ]
325 | },
326 | {
327 | "cell_type": "markdown",
328 | "metadata": {},
329 | "source": [
330 | "## Decoder"
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {},
336 | "source": [
337 | "During training, we mask out all but the activity vector of the correct digit capsule. Then we use this activity vector to reconstruct the input image. The output of the digit capsule is fed into a decoder consisting of 3 fully connected layers that model the pixel intensities (...).\n",
338 | "\n",
339 |     "![Reconstruction architecture](images/reconsArch.png)"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "metadata": {},
346 | "outputs": [],
347 | "source": [
348 | "class Decoder(torch.nn.Module):\n",
349 | " def __init__(self, in_features, out_features, output_size=INPUT_SIZE):\n",
350 | " super(Decoder, self).__init__()\n",
351 | " self.decoder = self.assemble_decoder(in_features, out_features)\n",
352 | " self.output_size = output_size\n",
353 | " \n",
354 | " def assemble_decoder(self, in_features, out_features):\n",
355 | " HIDDEN_LAYER_FEATURES = [512, 1024]\n",
356 | " return torch.nn.Sequential(\n",
357 | " torch.nn.Linear(in_features, HIDDEN_LAYER_FEATURES[0]),\n",
358 | " torch.nn.ReLU(),\n",
359 | " torch.nn.Linear(HIDDEN_LAYER_FEATURES[0], HIDDEN_LAYER_FEATURES[1]),\n",
360 | " torch.nn.ReLU(),\n",
361 | " torch.nn.Linear(HIDDEN_LAYER_FEATURES[1], out_features),\n",
362 | " torch.nn.Sigmoid(),\n",
363 | " )\n",
364 | " \n",
365 | " def forward(self, x, y):\n",
366 | " x = x[np.arange(0, x.size()[0]), y.cpu().data.numpy(), :].cuda()\n",
367 | " x = self.decoder(x)\n",
368 | " x = x.view(*((-1,) + self.output_size))\n",
369 | " return x"
370 | ]
371 | },
372 | {
373 | "cell_type": "markdown",
374 | "metadata": {},
375 | "source": [
376 | "## CapsNet"
377 | ]
378 | },
379 | {
380 | "cell_type": "markdown",
381 | "metadata": {},
382 | "source": [
383 | "The architecture is shallow with only two convolutional layers and one fully connected layer.\n",
384 | "\n",
385 |     "![CapsNet architecture](images/capsulearch.png)"
386 | ]
387 | },
388 | {
389 | "cell_type": "code",
390 | "execution_count": null,
391 | "metadata": {},
392 | "outputs": [],
393 | "source": [
394 | "class CapsNet(torch.nn.Module):\n",
395 | " def __init__(self, input_shape=INPUT_SIZE, n_routing_iter=3, use_reconstruction=True):\n",
396 | " super(CapsNet, self).__init__()\n",
397 | " assert len(input_shape) == 3\n",
398 | " \n",
399 | " self.input_shape = input_shape\n",
400 | " self.n_routing_iter = n_routing_iter\n",
401 | " self.use_reconstruction = use_reconstruction\n",
402 | " \n",
403 | " self.conv1 = Conv1(input_shape[0], 256, 9)\n",
404 | " self.primary_capsules = PrimaryCapsules(\n",
405 | " input_shape=(256, 20, 20),\n",
406 | " capsule_dim=8,\n",
407 | " out_channels=32,\n",
408 | " kernel_size=9,\n",
409 | " stride=2\n",
410 | " )\n",
411 | " self.routing = Routing(\n",
412 | " caps_dim_before=8,\n",
413 | " caps_dim_after=16,\n",
414 | " n_capsules_before=6 * 6 * 32,\n",
415 | " n_capsules_after=10\n",
416 | " )\n",
417 | " self.norm = Norm()\n",
418 | " \n",
419 | " if (self.use_reconstruction):\n",
420 | " self.decoder = Decoder(16, int(np.prod(input_shape)))\n",
421 | " \n",
422 | " def n_parameters(self):\n",
423 | " return np.sum([np.prod(x.size()) for x in self.parameters()])\n",
424 | " \n",
425 | " def forward(self, x, y=None):\n",
426 | " conv1 = self.conv1(x)\n",
427 | " primary_capsules = self.primary_capsules(conv1)\n",
428 | " digit_caps = self.routing(primary_capsules, self.n_routing_iter)\n",
429 | " scores = self.norm(digit_caps)\n",
430 | " \n",
431 | " if (self.use_reconstruction and y is not None):\n",
432 | " reconstruction = self.decoder(digit_caps, y).view((-1,) + self.input_shape)\n",
433 | " return scores, reconstruction\n",
434 | " \n",
435 | " return scores"
436 | ]
437 | },
438 | {
439 | "cell_type": "markdown",
440 | "metadata": {},
441 | "source": [
442 | "# Define Loss Functions"
443 | ]
444 | },
445 | {
446 | "cell_type": "markdown",
447 | "metadata": {},
448 | "source": [
449 | "## Margin Loss"
450 | ]
451 | },
452 | {
453 | "cell_type": "markdown",
454 | "metadata": {},
455 | "source": [
456 | "To allow for multiple digits, we use a separate margin loss, $L_k$ for each digit capsule, $k$:\n",
457 | "\n",
458 | "\\begin{equation*}\n",
459 | "L_k = T_k \\max(0, m^+ - ||\\mathbf{v}_k||)^2 + \\lambda (1 - T_k) \\max(0, ||\\mathbf{v}_k|| - m^-)^2\n",
460 | "\\end{equation*}\n",
461 | "\n",
462 | "where $T_k = 1$ iff a digit of class $k$ is present and $m^+ = 0.9$ and $m^- = 0.1$. The $\\lambda$ down-weighting of the loss for absent digit classes stops the initial learning from shrinking the lengths of the activity vectors of all the digit capsules. We use $\\lambda = 0.5$. The total loss is simply the sum of the losses of all digit capsules."
463 | ]
464 | },
465 | {
466 | "cell_type": "code",
467 | "execution_count": null,
468 | "metadata": {},
469 | "outputs": [],
470 | "source": [
471 | "def to_categorical(y, num_classes):\n",
472 | " \"\"\" 1-hot encodes a tensor \"\"\"\n",
473 | " new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]\n",
474 | " if (y.is_cuda):\n",
475 | " return new_y.cuda()\n",
476 | " return new_y"
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": null,
482 | "metadata": {},
483 | "outputs": [],
484 | "source": [
485 | "class MarginLoss(torch.nn.Module):\n",
486 | " def __init__(self, m_pos=0.9, m_neg=0.1, lamb=0.5):\n",
487 | " super(MarginLoss, self).__init__()\n",
488 | " self.m_pos = m_pos\n",
489 | " self.m_neg = m_neg\n",
490 | " self.lamb = lamb\n",
491 | " \n",
492 | " # Equation (4)\n",
493 | " def forward(self, scores, y):\n",
494 | " y = Variable(to_categorical(y, 10))\n",
495 | " \n",
496 | " Tc = y.float()\n",
497 | " loss_pos = torch.pow(torch.clamp(self.m_pos - scores, min=0), 2)\n",
498 | " loss_neg = torch.pow(torch.clamp(scores - self.m_neg, min=0), 2)\n",
499 | " loss = Tc * loss_pos + self.lamb * (1 - Tc) * loss_neg\n",
500 | " loss = loss.sum(-1)\n",
501 | " return loss.mean()"
502 | ]
503 | },
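{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (not from the paper): scores above $m^+$ for the true digit and below $m^-$ for every other digit should give a margin loss of exactly zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: confident, correct scores yield a margin loss of zero.\n",
"y_true = torch.arange(10).long()\n",
"scores_true = to_categorical(y_true, 10) * 0.95 # 0.95 for the true class, 0 elsewhere\n",
"print(MarginLoss()(Variable(scores_true), Variable(y_true)))"
]
},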
504 | {
505 | "cell_type": "markdown",
506 | "metadata": {},
507 | "source": [
508 | "## Reconstruction Loss"
509 | ]
510 | },
511 | {
512 | "cell_type": "markdown",
513 | "metadata": {},
514 | "source": [
515 | "We use an additional reconstruction loss to encourage the digit capsules to encode the instantiation parameters of the input digit. (...) We minimize the sum of squared differences between the outputs of the logistic units and the pixel intensities."
516 | ]
517 | },
518 | {
519 | "cell_type": "code",
520 | "execution_count": null,
521 | "metadata": {},
522 | "outputs": [],
523 | "source": [
524 | "class SumSquaredDifferencesLoss(torch.nn.Module):\n",
525 | " def __init__(self):\n",
526 | " super(SumSquaredDifferencesLoss, self).__init__()\n",
527 | " \n",
528 | " def forward(self, x_reconstruction, x):\n",
529 | " loss = torch.pow(x - x_reconstruction, 2).sum(-1).sum(-1)\n",
530 | " return loss.mean()"
531 | ]
532 | },
533 | {
534 | "cell_type": "markdown",
535 | "metadata": {},
536 | "source": [
537 | "## Total Loss"
538 | ]
539 | },
540 | {
541 | "cell_type": "markdown",
542 | "metadata": {},
543 | "source": [
544 | "We scale down this reconstruction loss by $0.0005$ so that it does not dominate the margin loss during training."
545 | ]
546 | },
547 | {
548 | "cell_type": "code",
549 | "execution_count": null,
550 | "metadata": {},
551 | "outputs": [],
552 | "source": [
553 | "class CapsNetLoss(torch.nn.Module):\n",
554 | " def __init__(self, reconstruction_loss_scale=0.0005):\n",
555 | " super(CapsNetLoss, self).__init__()\n",
556 | " self.digit_existance_criterion = MarginLoss()\n",
557 | " self.digit_reconstruction_criterion = SumSquaredDifferencesLoss()\n",
558 | " self.reconstruction_loss_scale = reconstruction_loss_scale\n",
559 | " \n",
560 | " def forward(self, x, y, x_reconstruction, scores):\n",
561 |     "        margin_loss = self.digit_existance_criterion(scores, y)\n",
562 | " reconstruction_loss = self.reconstruction_loss_scale *\\\n",
563 | " self.digit_reconstruction_criterion(x_reconstruction, x)\n",
564 | " loss = margin_loss + reconstruction_loss\n",
565 | " return loss, margin_loss, reconstruction_loss"
566 | ]
567 | },
568 | {
569 | "cell_type": "markdown",
570 | "metadata": {},
571 | "source": [
572 | "# Train"
573 | ]
574 | },
575 | {
576 | "cell_type": "markdown",
577 | "metadata": {},
578 | "source": [
579 | "## Model"
580 | ]
581 | },
582 | {
583 | "cell_type": "code",
584 | "execution_count": null,
585 | "metadata": {},
586 | "outputs": [],
587 | "source": [
588 | "model = CapsNet().cuda()\n",
589 | "model"
590 | ]
591 | },
592 | {
593 | "cell_type": "markdown",
594 | "metadata": {},
595 | "source": [
596 | "CapsNet has 8.2M parameters and 6.8M parameters without the reconstruction subnetwork."
597 | ]
598 | },
599 | {
600 | "cell_type": "code",
601 | "execution_count": null,
602 | "metadata": {},
603 | "outputs": [],
604 | "source": [
605 | "print('Number of Parameters: %d' % model.n_parameters())"
606 | ]
607 | },
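{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick shape check (not part of the paper), a small random batch can be pushed through the untrained model: with reconstruction enabled, the forward pass should return class scores of shape `(batch, 10)` and reconstructions of shape `(batch, 1, 28, 28)`. Like the rest of the notebook, this assumes a CUDA device is available."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative shape check on a random batch (assumes a CUDA device).\n",
"x_check = Variable(torch.rand(4, *INPUT_SIZE)).cuda()\n",
"y_check = Variable(torch.LongTensor([0, 1, 2, 3])).cuda()\n",
"scores_check, reconstruction_check = model(x_check, y_check)\n",
"print(scores_check.size(), reconstruction_check.size())"
]
},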
608 | {
609 | "cell_type": "markdown",
610 | "metadata": {},
611 | "source": [
612 | "## Criterion"
613 | ]
614 | },
615 | {
616 | "cell_type": "code",
617 | "execution_count": null,
618 | "metadata": {},
619 | "outputs": [],
620 | "source": [
621 | "criterion = CapsNetLoss()"
622 | ]
623 | },
624 | {
625 | "cell_type": "markdown",
626 | "metadata": {},
627 | "source": [
628 | "## Optimizer"
629 | ]
630 | },
631 | {
632 | "cell_type": "markdown",
633 | "metadata": {},
634 | "source": [
635 | "(...) we use the Adam optimizer with its TensorFlow default parameters, including the exponentially decaying learning rate, to minimize the sum of the margin losses in Eq. 4."
636 | ]
637 | },
638 | {
639 | "cell_type": "code",
640 | "execution_count": null,
641 | "metadata": {},
642 | "outputs": [],
643 | "source": [
644 | "def exponential_decay(optimizer, learning_rate, global_step, decay_steps, decay_rate, staircase=False):\n",
645 | " if (staircase):\n",
646 | " decayed_learning_rate = learning_rate * np.power(decay_rate, global_step // decay_steps)\n",
647 | " else:\n",
648 | " decayed_learning_rate = learning_rate * np.power(decay_rate, global_step / decay_steps)\n",
649 | " \n",
650 | " for param_group in optimizer.param_groups:\n",
651 | " param_group['lr'] = decayed_learning_rate\n",
652 | " \n",
653 | " return optimizer"
654 | ]
655 | },
656 | {
657 | "cell_type": "code",
658 | "execution_count": null,
659 | "metadata": {},
660 | "outputs": [],
661 | "source": [
662 | "LEARNING_RATE = 0.001\n",
663 | "optimizer = torch.optim.Adam(\n",
664 | " model.parameters(),\n",
665 | " lr=LEARNING_RATE,\n",
666 | " betas=(0.9, 0.999),\n",
667 | " eps=1e-08\n",
668 | ")"
669 | ]
670 | },
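{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: the `exponential_decay` helper above mimics TensorFlow's learning rate schedule. A sketch of an equivalent alternative (not used by the training loop below) is PyTorch's built-in `torch.optim.lr_scheduler.ExponentialLR`, stepped once per epoch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch of an equivalent built-in scheduler (not used below): stepping once per\n",
"# epoch matches exponential_decay(optimizer, LEARNING_RATE, global_epoch, 1, 0.90).\n",
"scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)\n",
"# scheduler.step() # would be called once per epoch"
]
},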
671 | {
672 | "cell_type": "markdown",
673 | "metadata": {},
674 | "source": [
675 | "## Training"
676 | ]
677 | },
678 | {
679 | "cell_type": "code",
680 | "execution_count": null,
681 | "metadata": {},
682 | "outputs": [],
683 | "source": [
684 | "def save_checkpoint(epoch, train_accuracy, test_accuracy, model, optimizer, path=None):\n",
685 | " if (path is None):\n",
686 | " path = 'checkpoint-%f-%04d.pth' % (test_accuracy, epoch)\n",
687 | " state = {\n",
688 | " 'epoch': epoch,\n",
689 | " 'train_accuracy': train_accuracy,\n",
690 | " 'test_accuracy': test_accuracy,\n",
691 | " 'model_state_dict': model.state_dict(),\n",
692 | " 'optimizer_state_dict': optimizer.state_dict(),\n",
693 | " }\n",
694 | " torch.save(state, path)"
695 | ]
696 | },
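{
"cell_type": "markdown",
"metadata": {},
"source": [
"A checkpoint written by `save_checkpoint` can be restored along the lines of the sketch below (this helper is an addition, not part of the original training code; `path` is whichever file `save_checkpoint` produced)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative counterpart to save_checkpoint for resuming from a saved state.\n",
"def load_checkpoint(path, model, optimizer):\n",
"    state = torch.load(path)\n",
"    model.load_state_dict(state['model_state_dict'])\n",
"    optimizer.load_state_dict(state['optimizer_state_dict'])\n",
"    return state['epoch'], state['train_accuracy'], state['test_accuracy']"
]
},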
697 | {
698 | "cell_type": "code",
699 | "execution_count": null,
700 | "metadata": {},
701 | "outputs": [],
702 | "source": [
703 | "def show_example(model, x, y, x_reconstruction, y_pred):\n",
704 | " x = x.squeeze().cpu().data.numpy()\n",
705 | " y = y.cpu().data.numpy()\n",
706 | " x_reconstruction = x_reconstruction.squeeze().cpu().data.numpy()\n",
707 | " _, y_pred = torch.max(y_pred, -1)\n",
708 | " y_pred = y_pred.cpu().data.numpy()\n",
709 | " \n",
710 | " fig, ax = plt.subplots(1, 2)\n",
711 | " ax[0].imshow(x, cmap='Greys')\n",
712 | " ax[0].set_title('Input: %d' % y)\n",
713 | " ax[1].imshow(x_reconstruction, cmap='Greys')\n",
714 | " ax[1].set_title('Output: %d' % y_pred)\n",
715 | " plt.show()"
716 | ]
717 | },
718 | {
719 | "cell_type": "code",
720 | "execution_count": null,
721 | "metadata": {},
722 | "outputs": [],
723 | "source": [
724 | "def test(model, loader):\n",
725 | " metrics = defaultdict(lambda:list())\n",
726 | " for batch_id, (x, y) in tqdm(enumerate(loader), total=len(loader)):\n",
727 | " x = Variable(x).float().cuda()\n",
728 | " y = Variable(y).cuda()\n",
729 | " y_pred, x_reconstruction = model(x, y)\n",
730 | " _, y_pred = torch.max(y_pred, -1)\n",
731 | " metrics['accuracy'].append((y_pred == y).cpu().data.numpy())\n",
732 | " metrics['accuracy'] = np.concatenate(metrics['accuracy']).mean()\n",
733 | " return metrics"
734 | ]
735 | },
736 | {
737 | "cell_type": "code",
738 | "execution_count": null,
739 | "metadata": {},
740 | "outputs": [],
741 | "source": [
742 | "global_epoch = 0\n",
743 | "global_step = 0\n",
744 | "best_tst_accuracy = 0.0\n",
745 | "history = defaultdict(lambda:list())\n",
746 | "COMPUTE_TRN_METRICS = False"
747 | ]
748 | },
749 | {
750 | "cell_type": "code",
751 | "execution_count": null,
752 | "metadata": {},
753 | "outputs": [],
754 | "source": [
755 | "n_epochs = 1500 # Number of epochs not specified in the paper\n",
756 | "for epoch in range(n_epochs):\n",
757 | " print('Epoch %d (%d/%d):' % (global_epoch + 1, epoch + 1, n_epochs))\n",
758 | " \n",
759 | " for batch_id, (x, y) in tqdm(enumerate(trn_loader), total=len(trn_loader)):\n",
760 | " optimizer = exponential_decay(optimizer, LEARNING_RATE, global_epoch, 1, 0.90) # Configurations not specified in the paper\n",
761 | " \n",
762 | " x = Variable(x).float().cuda()\n",
763 | " y = Variable(y).cuda()\n",
764 | " \n",
765 | " y_pred, x_reconstruction = model(x, y)\n",
766 | " loss, margin_loss, reconstruction_loss = criterion(x, y, x_reconstruction, y_pred.cuda())\n",
767 | " \n",
768 | " history['margin_loss'].append(margin_loss.cpu().data.numpy())\n",
769 | " history['reconstruction_loss'].append(reconstruction_loss.cpu().data.numpy())\n",
770 | " history['loss'].append(loss.cpu().data.numpy())\n",
771 | " \n",
772 | " optimizer.zero_grad()\n",
773 | " loss.backward()\n",
774 | " optimizer.step()\n",
775 | " \n",
776 | " global_step += 1\n",
777 | "\n",
778 | " trn_metrics = test(model, trn_loader) if COMPUTE_TRN_METRICS else None\n",
779 | " tst_metrics = test(model, tst_loader)\n",
780 | " \n",
781 | " print('Margin Loss: %f' % history['margin_loss'][-1])\n",
782 | " print('Reconstruction Loss: %f' % history['reconstruction_loss'][-1])\n",
783 | " print('Loss: %f' % history['loss'][-1])\n",
784 | " print('Train Accuracy: %f' % (trn_metrics['accuracy'] if COMPUTE_TRN_METRICS else 0.0))\n",
785 | " print('Test Accuracy: %f' % tst_metrics['accuracy'])\n",
786 | " \n",
787 | " print('Example:')\n",
788 | " idx = np.random.randint(0, len(x))\n",
789 | " show_example(model, x[idx], y[idx], x_reconstruction[idx], y_pred[idx])\n",
790 | " \n",
791 | " if (tst_metrics['accuracy'] >= best_tst_accuracy):\n",
792 | " best_tst_accuracy = tst_metrics['accuracy']\n",
793 | " save_checkpoint(\n",
794 | " global_epoch + 1,\n",
795 | " trn_metrics['accuracy'] if COMPUTE_TRN_METRICS else 0.0,\n",
796 | " tst_metrics['accuracy'],\n",
797 | " model,\n",
798 | " optimizer\n",
799 | " )\n",
800 | " global_epoch += 1"
801 | ]
802 | },
803 | {
804 | "cell_type": "markdown",
805 | "metadata": {},
806 | "source": [
807 | "## Loss Curve"
808 | ]
809 | },
810 | {
811 | "cell_type": "code",
812 | "execution_count": null,
813 | "metadata": {},
814 | "outputs": [],
815 | "source": [
816 | "def compute_avg_curve(y, n_points_avg):\n",
817 | " avg_kernel = np.ones((n_points_avg,)) / n_points_avg\n",
818 | " rolling_mean = np.convolve(y, avg_kernel, mode='valid')\n",
819 | " return rolling_mean"
820 | ]
821 | },
822 | {
823 | "cell_type": "code",
824 | "execution_count": null,
825 | "metadata": {},
826 | "outputs": [],
827 | "source": [
828 | "n_points_avg = 10\n",
829 | "n_points_plot = 1000\n",
830 | "plt.figure(figsize=(20, 10))\n",
831 | "\n",
832 | "curve = np.asarray(history['loss'])[-n_points_plot:]\n",
833 | "avg_curve = compute_avg_curve(curve, n_points_avg)\n",
834 | "plt.plot(avg_curve, '-g')\n",
835 | "\n",
836 | "curve = np.asarray(history['margin_loss'])[-n_points_plot:]\n",
837 | "avg_curve = compute_avg_curve(curve, n_points_avg)\n",
838 | "plt.plot(avg_curve, '-b')\n",
839 | "\n",
840 | "curve = np.asarray(history['reconstruction_loss'])[-n_points_plot:]\n",
841 | "avg_curve = compute_avg_curve(curve, n_points_avg)\n",
842 | "plt.plot(avg_curve, '-r')\n",
843 | "\n",
844 | "plt.legend(['Total Loss', 'Margin Loss', 'Reconstruction Loss'])\n",
845 | "plt.show()"
846 | ]
847 | },
848 | {
849 | "cell_type": "markdown",
850 | "metadata": {},
851 | "source": [
852 | "Done!"
853 | ]
854 | }
855 | ],
856 | "metadata": {
857 | "kernelspec": {
858 | "display_name": "Python 3",
859 | "language": "python",
860 | "name": "python3"
861 | },
862 | "language_info": {
863 | "codemirror_mode": {
864 | "name": "ipython",
865 | "version": 3
866 | },
867 | "file_extension": ".py",
868 | "mimetype": "text/x-python",
869 | "name": "python",
870 | "nbconvert_exporter": "python",
871 | "pygments_lexer": "ipython3",
872 | "version": "3.7.7"
873 | },
874 | "toc": {
875 | "nav_menu": {
876 | "height": "177px",
877 | "width": "219px"
878 | },
879 | "number_sections": true,
880 | "sideBar": true,
881 | "skip_h1_title": false,
882 | "toc_cell": false,
883 | "toc_position": {
884 | "height": "659px",
885 | "left": "0px",
886 | "right": "1007.8px",
887 | "top": "133px",
888 | "width": "241px"
889 | },
890 | "toc_section_display": "block",
891 | "toc_window_display": true
892 | }
893 | },
894 | "nbformat": 4,
895 | "nbformat_minor": 4
896 | }
897 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Arthur Crippa Búrigo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CapsNet
2 | This is a [pytorch](http://pytorch.org/) implementation of CapsNet, described in the paper [Dynamic Routing Between Capsules](https://arxiv.org/abs/1710.09829) - by [Sara Sabour](https://arxiv.org/find/cs/1/au:+Sabour_S/0/1/0/all/0/1), [Nicholas Frosst](https://arxiv.org/find/cs/1/au:+Frosst_N/0/1/0/all/0/1) and [Geoffrey E Hinton](https://arxiv.org/find/cs/1/au:+Hinton_G/0/1/0/all/0/1).
3 |
4 | ## MNIST
5 | ### Accuracy
6 | Although the paper reports an accuracy of *99.75%*, the maximum accuracy achieved by this implementation was **99.68%**.
7 |
8 | ### Execution Speed
9 | 111 seconds per epoch on a single Titan Xp GPU.
10 |
11 | ### Number of Parameters
12 | The model has 8,141,840 parameters.
13 |
--------------------------------------------------------------------------------
/images/capsulearch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acburigo/CapsNet/d805642a7a0c89f71e57e405a3f8a7ce609af24e/images/capsulearch.png
--------------------------------------------------------------------------------
/images/reconsArch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acburigo/CapsNet/d805642a7a0c89f71e57e405a3f8a7ce609af24e/images/reconsArch.png
--------------------------------------------------------------------------------