├── .gitignore
├── Capsule_Network.ipynb
├── LICENSE
├── README.md
├── assets
│   ├── capsule_decoder.png
│   ├── capsule_encoder.png
│   ├── cat_face_2.png
│   ├── complete_caps_net.png
│   ├── coupling_coeff.png
│   ├── dynamic_routing.png
│   └── perturbed_reconstructions.png
└── helpers.py
/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Cezanne Camacho

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Capsule Network

A readable implementation of a Capsule Network, as described in "Dynamic Routing Between Capsules" [Sabour et al.].

In this notebook, I'll be building a simple Capsule Network that aims to classify MNIST images.
This is an implementation in PyTorch, and this notebook assumes that you are already familiar with [convolutional and fully-connected layers](https://cezannec.github.io/Convolutional_Neural_Networks/).

### What are Capsules?

A capsule is a small group of neurons that has a few key traits:
* Each neuron in a capsule represents a different property of a particular image part, such as that part's color or width.
* Every capsule **outputs a vector**, which has some magnitude (representing the probability that a part **exists**) and orientation (representing a part's generalized pose); see the sketch below.
* A capsule network is made of multiple layers of capsules; during training, this network aims to learn the spatial relationships between the parts and the whole of an object (e.g., how the positions of eyes and a nose relate to the position of a whole face in an image).
* Capsules represent the relationships between parts of a whole object by using **dynamic routing**: weighting the connections between one layer of capsules and the next, so that strong connections form between spatially-related object parts.
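
As a concrete illustration of the second trait, here is a minimal sketch (my own, not taken from the notebook) of the "squashing" nonlinearity from the paper, which scales a capsule's output vector so that its magnitude lies in (0, 1) and can be read as an existence probability; the tensor shapes below are illustrative assumptions:

```python
import torch

def squash(s, dim=-1, eps=1e-8):
    # shrink short vectors toward zero and long vectors toward unit length,
    # preserving orientation: v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1.0 + squared_norm)
    return scale * s / torch.sqrt(squared_norm + eps)

# a toy batch of 10 capsule outputs, each a 16-dimensional vector
u = torch.randn(1, 10, 16)
v = squash(u)
existence = v.norm(dim=-1)  # magnitudes in (0, 1): how likely each part is present
```
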
You can read more about all of these traits in [my blog post about capsules and dynamic routing](https://cezannec.github.io/Capsule_Networks/).

### Representing Relationships Between Parts

All of these traits allow capsules to communicate with each other and determine how data moves through them.
Through this dynamic communication, during the training process, a capsule network learns the **spatial relationships** between visual parts and their wholes (e.g., between the eyes, nose, and mouth on a face); the routing loop that makes this possible is sketched below.
Compared to a vanilla CNN, this knowledge of spatial relationships makes it easier for a capsule network to identify an object no matter what orientation it is in.
These networks are also, generally, better able to identify multiple, overlapping objects, and to learn from smaller training sets!
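
To make that routing step concrete, below is a condensed sketch of the dynamic-routing loop described in the paper (a simplified rendition, not the notebook's exact code). It assumes prediction vectors `u_hat` of shape `(batch, num_in, num_out, out_dim)` and reuses the `squash()` sketch from above:

```python
import torch
import torch.nn.functional as F

def dynamic_routing(u_hat, num_iterations=3):
    # routing logits b start at zero: every input capsule is initially
    # coupled equally to every output capsule
    b = torch.zeros_like(u_hat[..., :1])
    for _ in range(num_iterations):
        # coupling coefficients: softmax over the output capsules
        c = F.softmax(b, dim=2)
        # weighted sum of predictions over the input capsules
        s = (c * u_hat).sum(dim=1, keepdim=True)
        # squash() as defined in the sketch above
        v = squash(s, dim=-1)
        # agreement update: predictions that align with the current output
        # get a larger routing weight on the next iteration
        b = b + (u_hat * v).sum(dim=-1, keepdim=True)
    return v.squeeze(1)  # (batch, num_out, out_dim)
```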

---
## Model Architecture

The Capsule Network that I'll define is made of two main parts:
1. A convolutional encoder
2. A fully-connected, linear decoder

![The complete Capsule Network: a convolutional encoder followed by a fully-connected, linear decoder.](assets/complete_caps_net.png)

The above image was taken from the original [Capsule Network paper (Sabour et al.)](https://arxiv.org/pdf/1710.09829.pdf). The notebook follows the architecture described in that paper and tries to replicate some of the experiments, such as feature visualization, that the authors pursued.

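As a rough sketch of the second part, here is what a minimal version of the fully-connected decoder can look like, assuming the layer sizes reported in the paper (a 16-dimensional vector for each of 10 digit capsules, hidden layers of 512 and 1024 units, and a 28x28 reconstruction); the class name is mine, and the notebook's actual implementation may differ in details:

```python
import torch.nn as nn

class LinearDecoder(nn.Module):
    # reconstructs a 28x28 MNIST image from the (masked) digit-capsule
    # vectors, mirroring the three fully-connected layers in the paper
    def __init__(self, input_dim=16 * 10, h1=512, h2=1024, out_dim=28 * 28):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, h1), nn.ReLU(inplace=True),
            nn.Linear(h1, h2), nn.ReLU(inplace=True),
            nn.Linear(h2, out_dim), nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def forward(self, x):
        # x: flattened capsule outputs with all but the target class masked
        # out, shape (batch, 160)
        return self.layers(x)
```
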
---
## Running Code Locally

If you're interested in running this code on your own computer, there are thorough instructions on setting up Anaconda and installing PyTorch and the other necessary libraries in the [readme of Udacity's deep learning repo](https://github.com/udacity/deep-learning-v2-pytorch/blob/master/README.md). After installing the necessary libraries, you can clone and run this code as usual.
--------------------------------------------------------------------------------
/assets/capsule_decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/capsule_decoder.png
--------------------------------------------------------------------------------
/assets/capsule_encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/capsule_encoder.png
--------------------------------------------------------------------------------
/assets/cat_face_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/cat_face_2.png
--------------------------------------------------------------------------------
/assets/complete_caps_net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/complete_caps_net.png
--------------------------------------------------------------------------------
/assets/coupling_coeff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/coupling_coeff.png
--------------------------------------------------------------------------------
/assets/dynamic_routing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/dynamic_routing.png
--------------------------------------------------------------------------------
/assets/perturbed_reconstructions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/perturbed_reconstructions.png
--------------------------------------------------------------------------------
/helpers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F

def softmax(input_tensor, dim=1):
    # softmax along an arbitrary dimension `dim` (a workaround from when
    # F.softmax did not reliably support a `dim` argument)
    # move the target dimension to the last axis
    transposed_input = input_tensor.transpose(dim, len(input_tensor.size()) - 1)
    # flatten to 2D and compute softmax over the last dimension
    softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
    # restore the original shape and dimension order
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input_tensor.size()) - 1)
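
# Example usage (hypothetical shapes, mirroring the routing logits b of a
# capsule layer with 1152 input and 10 output capsules): softmax(b, dim=2)
# gives coupling coefficients that sum to 1 across dim 2, and on modern
# PyTorch matches F.softmax(b, dim=2) directly:
#
#   b = torch.zeros(20, 1152, 10, 1)
#   c = softmax(b, dim=2)
#   assert torch.allclose(c, F.softmax(b, dim=2))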
--------------------------------------------------------------------------------