├── .gitignore
├── Capsule_Network.ipynb
├── LICENSE
├── README.md
├── assets
    ├── capsule_decoder.png
    ├── capsule_encoder.png
    ├── cat_face_2.png
    ├── complete_caps_net.png
    ├── coupling_coeff.png
    ├── dynamic_routing.png
    └── perturbed_reconstructions.png
└── helpers.py

/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Cezanne Camacho

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Capsule Network

Readable implementation of a Capsule Network as described in "Dynamic Routing Between Capsules" [Sabour, Frosst, and Hinton]

In this notebook, I'll be building a simple Capsule Network that aims to classify MNIST images.
This is an implementation in PyTorch, and this notebook assumes that you are already familiar with [convolutional and fully-connected layers](https://cezannec.github.io/Convolutional_Neural_Networks/).

### What are Capsules?

A capsule is a small group of neurons with a few key traits:
* Each neuron in a capsule represents a property of a particular image part, such as the part's color or width.
* Every capsule **outputs a vector**, whose magnitude represents the part's **existence** and whose orientation represents the part's generalized pose.
* A capsule network is made of multiple layers of capsules; during training, the network aims to learn the spatial relationships between the parts and the whole of an object (e.g., how the positions of the eyes and nose relate to the position of the whole face in an image).
* Capsules represent part-whole relationships by using **dynamic routing** to weight the connections between one layer of capsules and the next, creating strong connections between spatially related object parts.
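
The last two traits, a vector output whose length encodes existence and dynamic routing between capsule layers, can be sketched in PyTorch as follows. This is a minimal sketch, not the notebook's code: the function names `squash` and `dynamic_routing`, the `(batch, num_lower, num_upper, out_dim)` layout of the predictions `u_hat`, and the `eps` guard are assumptions. The squash formula and the default of three routing iterations follow the paper. The coupling coefficients are normalized with `F.softmax(..., dim=2)`; the repository's `helpers.softmax` computes the same normalization by moving the chosen axis to the end.

```python
import torch
import torch.nn.functional as F


def squash(s: torch.Tensor, dim: int = -1, eps: float = 1e-8) -> torch.Tensor:
    """Scale vectors so their length lies in [0, 1) while keeping their direction.

    Implements v = (|s|^2 / (1 + |s|^2)) * (s / |s|); `eps` avoids dividing by zero.
    """
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = sq_norm / (1.0 + sq_norm)
    return scale * s / torch.sqrt(sq_norm + eps)


def dynamic_routing(u_hat: torch.Tensor, num_iterations: int = 3) -> torch.Tensor:
    """Route predictions from lower capsules i to upper capsules j by agreement.

    u_hat: predictions u_hat[j|i], shape (batch, num_lower, num_upper, out_dim).
    Returns the upper-capsule outputs v_j, shape (batch, num_upper, out_dim).
    """
    if num_iterations < 1:
        raise ValueError("num_iterations must be at least 1")
    batch, num_lower, num_upper, _ = u_hat.shape
    # Routing logits b_ij start at zero, so every lower capsule initially
    # spreads its output evenly over all upper capsules.
    b = torch.zeros(batch, num_lower, num_upper, 1,
                    dtype=u_hat.dtype, device=u_hat.device)
    for iteration in range(num_iterations):
        # Coupling coefficients c_ij: softmax over the upper capsules j.
        c = F.softmax(b, dim=2)
        # s_j = sum_i c_ij * u_hat[j|i], then squash to get v_j.
        s = (c * u_hat).sum(dim=1)                  # (batch, num_upper, out_dim)
        v = squash(s, dim=-1)
        if iteration < num_iterations - 1:
            # Agreement: dot product of each prediction with the current output v_j.
            agreement = (u_hat * v.unsqueeze(1)).sum(dim=-1, keepdim=True)
            b = b + agreement
    return v


if __name__ == "__main__":
    torch.manual_seed(0)
    # 1152 primary capsules predicting 10 digit capsules of 16 dimensions each
    u_hat = torch.randn(2, 1152, 10, 16)
    v = dynamic_routing(u_hat)
    print(v.shape)             # torch.Size([2, 10, 16])
    print(v.norm(dim=-1))      # capsule lengths, each below 1
```

Each iteration raises the logit `b_ij` when lower capsule i's prediction agrees (has a large dot product) with upper capsule j's current output, so predictions that agree end up with larger coupling coefficients.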

You can read more about all of these traits in [my blog post about capsules and dynamic routing](https://cezannec.github.io/Capsule_Networks/).

### Representing Relationships Between Parts

Together, these traits allow capsules to communicate with each other and determine how data moves through them.
Through this dynamic communication, a capsule network learns, during training, the **spatial relationships** between visual parts and their wholes (e.g., between the eyes, nose, and mouth on a face).
Compared to a vanilla CNN, this knowledge of spatial relationships is meant to make it easier for a capsule network to identify an object regardless of its orientation.
Capsule networks can also be better at identifying multiple, overlapping objects, and at learning from smaller sets of training data.

---
## Model Architecture

The Capsule Network that I'll define is made of two main parts:
1. A convolutional encoder
2. A fully-connected, linear decoder
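
The two parts listed above can be sketched as PyTorch modules. This is a hedged sketch, not the notebook's implementation: the class names `ConvEncoderSketch` and `DecoderSketch` are hypothetical, and so is the way the primary-capsule channels are grouped into vectors. The layer sizes follow the MNIST configuration described in the paper: a 9x9 convolution with 256 channels, 32 channels of 8-D primary capsules from a stride-2 9x9 convolution, 10 digit capsules of 16 dimensions, and a 512-1024-784 fully-connected decoder. The squash nonlinearity and the routing step into the digit capsules are left out, so the decoder is fed a random stand-in for the digit-capsule output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvEncoderSketch(nn.Module):
    """Hypothetical encoder front end: conv layer plus primary capsules (MNIST sizes)."""

    def __init__(self):
        super().__init__()
        # 1x28x28 image -> 256x20x20 feature maps (9x9 kernel, stride 1)
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        # 256x20x20 -> (32 * 8)x6x6: 32 capsule channels of 8-D vectors (9x9 kernel, stride 2)
        self.primary = nn.Conv2d(256, 32 * 8, kernel_size=9, stride=2)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        features = F.relu(self.conv1(images))
        u = self.primary(features)                       # (batch, 256, 6, 6)
        batch = u.size(0)
        # Assumed grouping: each run of 8 consecutive channels forms one capsule vector.
        u = u.view(batch, 32, 8, 6, 6).permute(0, 1, 3, 4, 2)
        # 32 * 6 * 6 = 1152 primary capsules, each an 8-D vector (squash omitted here)
        return u.reshape(batch, 32 * 6 * 6, 8)


class DecoderSketch(nn.Module):
    """Hypothetical decoder: reconstructs a 28x28 image from the digit capsules."""

    def __init__(self, num_classes: int = 10, caps_dim: int = 16):
        super().__init__()
        self.num_classes = num_classes
        self.layers = nn.Sequential(
            nn.Linear(num_classes * caps_dim, 512), nn.ReLU(),
            nn.Linear(512, 1024), nn.ReLU(),
            nn.Linear(1024, 28 * 28), nn.Sigmoid(),    # pixel intensities in [0, 1]
        )

    def forward(self, digit_caps: torch.Tensor):
        # digit_caps: (batch, num_classes, caps_dim); the longest vector is the prediction
        pred = digit_caps.norm(dim=-1).argmax(dim=-1)    # (batch,)
        mask = F.one_hot(pred, self.num_classes).unsqueeze(-1).to(digit_caps.dtype)
        # Zero out every capsule except the predicted one, then flatten to (batch, 160)
        masked = (digit_caps * mask).flatten(start_dim=1)
        recon = self.layers(masked).view(-1, 1, 28, 28)
        return recon, pred


if __name__ == "__main__":
    images = torch.rand(4, 1, 28, 28)
    primary_caps = ConvEncoderSketch()(images)
    print(primary_caps.shape)        # torch.Size([4, 1152, 8])
    # Random stand-in for the digit-capsule output that routing would produce
    digit_caps = torch.randn(4, 10, 16)
    recon, pred = DecoderSketch()(digit_caps)
    print(recon.shape, pred.shape)   # torch.Size([4, 1, 28, 28]) torch.Size([4])
```

Here the decoder keeps only the longest digit capsule (the predicted class) before reconstructing; the paper masks with the true label instead during training.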

The architecture diagram for this model was taken from the original [Capsule Network paper (Sabour et al.)](https://arxiv.org/pdf/1710.09829.pdf). The notebook follows the architecture described in that paper and tries to replicate some of the authors' experiments, such as feature visualization.

---
## Running Code Locally

If you're interested in running this code on your own computer, the [readme of Udacity's deep learning repo](https://github.com/udacity/deep-learning-v2-pytorch/blob/master/README.md) has thorough instructions for setting up Anaconda and installing PyTorch and the other necessary libraries. After installing them, you can clone and run this code as usual.

--------------------------------------------------------------------------------
/assets/capsule_decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/capsule_decoder.png

--------------------------------------------------------------------------------
/assets/capsule_encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/capsule_encoder.png

--------------------------------------------------------------------------------
/assets/cat_face_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/cat_face_2.png

--------------------------------------------------------------------------------
/assets/complete_caps_net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/complete_caps_net.png

--------------------------------------------------------------------------------
/assets/coupling_coeff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/coupling_coeff.png

--------------------------------------------------------------------------------
/assets/dynamic_routing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/dynamic_routing.png

--------------------------------------------------------------------------------
/assets/perturbed_reconstructions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/607b7ab12daaebc9d8ae8aef305dc76cbc765787/assets/perturbed_reconstructions.png

--------------------------------------------------------------------------------
/helpers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def softmax(input_tensor, dim=1):
    """Apply softmax along `dim` of `input_tensor`.

    Equivalent to `F.softmax(input_tensor, dim=dim)`; written out here by moving
    `dim` to the last axis, applying softmax over a 2-D view, and restoring the shape.
    """
    # move `dim` to the last axis
    transposed_input = input_tensor.transpose(dim, len(input_tensor.size()) - 1)
    # flatten to (N, size of dim) and apply softmax across each row
    softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
    # reshape back and swap the axes to restore the original layout
    return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input_tensor.size()) - 1)

--------------------------------------------------------------------------------