├── .gitignore
├── assets
│   ├── cat_face_2.png
│   ├── capsule_decoder.png
│   ├── capsule_encoder.png
│   ├── coupling_coeff.png
│   ├── dynamic_routing.png
│   ├── complete_caps_net.png
│   └── perturbed_reconstructions.png
├── helpers.py
├── LICENSE
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | 
--------------------------------------------------------------------------------
/assets/cat_face_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/HEAD/assets/cat_face_2.png
--------------------------------------------------------------------------------
/assets/capsule_decoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/HEAD/assets/capsule_decoder.png
--------------------------------------------------------------------------------
/assets/capsule_encoder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/HEAD/assets/capsule_encoder.png
--------------------------------------------------------------------------------
/assets/coupling_coeff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/HEAD/assets/coupling_coeff.png
--------------------------------------------------------------------------------
/assets/dynamic_routing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/HEAD/assets/dynamic_routing.png
--------------------------------------------------------------------------------
/assets/complete_caps_net.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/HEAD/assets/complete_caps_net.png
--------------------------------------------------------------------------------
/assets/perturbed_reconstructions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cezannec/capsule_net_pytorch/HEAD/assets/perturbed_reconstructions.png
--------------------------------------------------------------------------------
/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | 
4 | def softmax(input_tensor, dim=1):
5 |     # transpose input
6 |     transposed_input = input_tensor.transpose(dim, len(input_tensor.size()) - 1)
7 |     # calculate softmax
8 |     softmaxed_output = F.softmax(transposed_input.contiguous().view(-1, transposed_input.size(-1)), dim=-1)
9 |     # un-transpose result
10 |     return softmaxed_output.view(*transposed_input.size()).transpose(dim, len(input_tensor.size()) - 1)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Cezanne Camacho
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Capsule Network
2 | 
3 | Readable implementation of a Capsule Network as described in "Dynamic Routing Between Capsules" [Hinton et al.]
4 | 
5 | In this notebook, I'll be building a simple Capsule Network that aims to classify MNIST images.
6 | This is an implementation in PyTorch, and this notebook assumes that you are already familiar with [convolutional and fully-connected layers](https://cezannec.github.io/Convolutional_Neural_Networks/).
7 | 
8 | ### What are Capsules?
9 | 
10 | Capsules are small groups of neurons that have a few key traits:
11 | * The neurons in a capsule represent various properties of a particular image part, such as the part's color or width.
12 | * Every capsule **outputs a vector**, which has some magnitude (that represents a part's **existence**) and orientation (that represents a part's generalized pose).
13 | * A capsule network is made of multiple layers of capsules; during training, this network aims to learn the spatial relationships between the parts and whole of an object (e.g., how the positions of the eyes and nose relate to the position of a whole face in an image).
14 | * Capsules represent relationships between parts of a whole object by using **dynamic routing** to weight the connections between one layer of capsules and the next, creating strong connections between spatially related object parts.
15 | 
16 | 
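The traits listed above can be made concrete with a small sketch. It is written in plain Python (no PyTorch) so that it runs without dependencies, and it is not the notebook's implementation: the helper names `vector_length` and `coupling_coefficients` and all numeric values are made up for illustration. It assumes, following the "Dynamic Routing Between Capsules" paper, that the routing weights from one capsule to the capsules in the next layer come from a softmax over routing logits, so they are positive and sum to 1.

```python
import math


def vector_length(v):
    """Euclidean length of a capsule's output vector (its 'existence' signal)."""
    return math.sqrt(sum(x * x for x in v))


def coupling_coefficients(logits):
    """Softmax over routing logits: weights from one capsule to each parent capsule."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(b - m) for b in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical 4-dimensional output of one capsule (values are illustrative only).
capsule_output = [0.1, -0.2, 0.05, 0.3]
existence = vector_length(capsule_output)

# Hypothetical routing logits from one capsule to three capsules in the next layer.
routing_logits = [0.0, 1.0, 2.0]
coeffs = coupling_coefficients(routing_logits)

print(f"vector length: {existence:.3f}")
print("coupling coefficients:", [round(c, 3) for c in coeffs])
```

A larger logit yields a larger coefficient, which is how routing can strengthen the connection to one parent capsule at the expense of the others. The `softmax` function in `helpers.py` computes the same normalization for tensors along a chosen dimension.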
17 |
18 |
37 |
38 |