├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── resnet_backbone_CIFAR10.json └── resnet_backbone_CIFAR100.json ├── imgs ├── concurrent.png ├── inverted_attention.png └── overall.png ├── main_capsule.py ├── src ├── __init__.py ├── capsule_model.py └── layers.py └── utils.py /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software have utilized the following copyrighted material, the use of which is hereby acknowledged. 3 | 4 | _____________________ 5 | PyTorch (https://pytorch.org) 6 | We use PyTorch as the training framework for our model. 7 | 8 | From PyTorch: 9 | 10 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 11 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 12 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 13 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 14 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 15 | Copyright (c) 2011-2013 NYU (Clement Farabet) 16 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 17 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 18 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 19 | 20 | From Caffe2: 21 | 22 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 23 | 24 | All contributions by Facebook: 25 | Copyright (c) 2016 Facebook Inc. 26 | 27 | All contributions by Google: 28 | Copyright (c) 2015 Google Inc. 29 | All rights reserved. 30 | 31 | All contributions by Yangqing Jia: 32 | Copyright (c) 2015 Yangqing Jia 33 | All rights reserved. 34 | 35 | All contributions from Caffe: 36 | Copyright(c) 2013, 2014, 2015, the respective contributors 37 | All rights reserved. 38 | 39 | All other contributions: 40 | Copyright(c) 2015, 2016 the respective contributors 41 | All rights reserved. 42 | 43 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 44 | copyright over their contributions to Caffe2. The project versioning records 45 | all such contribution and copyright details. If a contributor wants to further 46 | mark their specific copyright on a particular contribution, they should 47 | indicate their copyright solely in the commit message of the change when it is 48 | committed. 49 | 50 | All rights reserved. 51 | 52 | Redistribution and use in source and binary forms, with or without 53 | modification, are permitted provided that the following conditions are met: 54 | 55 | 1. Redistributions of source code must retain the above copyright 56 | notice, this list of conditions and the following disclaimer. 57 | 58 | 2. Redistributions in binary form must reproduce the above copyright 59 | notice, this list of conditions and the following disclaimer in the 60 | documentation and/or other materials provided with the distribution. 61 | 62 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 63 | and IDIAP Research Institute nor the names of its contributors may be 64 | used to endorse or promote products derived from this software without 65 | specific prior written permission. 
66 | 67 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 68 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 69 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 70 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 71 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 72 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 73 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 74 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 75 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 76 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 77 | POSSIBILITY OF SUCH DAMAGE. 78 | 79 | _____________________ 80 | Soumith Chintala (TorchVision: https://github.com/pytorch/vision/) 81 | TorchVision is used for handling data loading and IO utilities, which is distributed under BSD 3-Clause License. 82 | 83 | Copyright (c) Soumith Chintala 2016, 84 | All rights reserved. 85 | 86 | Redistribution and use in source and binary forms, with or without 87 | modification, are permitted provided that the following conditions are met: 88 | 89 | * Redistributions of source code must retain the above copyright notice, this 90 | list of conditions and the following disclaimer. 91 | 92 | * Redistributions in binary form must reproduce the above copyright notice, 93 | this list of conditions and the following disclaimer in the documentation 94 | and/or other materials provided with the distribution. 95 | 96 | * Neither the name of the copyright holder nor the names of its 97 | contributors may be used to endorse or promote products derived from 98 | this software without specific prior written permission. 99 | 100 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 101 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 102 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 103 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 104 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 105 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 106 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 107 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 108 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 109 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 110 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. 
This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | ## Before you get started 6 | 7 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2019 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED IN THIS REPOSITORY: 43 | 44 | This software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
46 | ------------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 2 | 3 | # Capsules with Inverted Dot-Product Attention Routing 4 | 5 | > Pytorch implementation for Capsules with Inverted Dot-Product Attention Routing. 6 | 7 | 8 | ## Paper 9 | [**Capsules with Inverted Dot-Product Attention Routing**](https://openreview.net/pdf?id=HJe6uANtwH)
10 | [Yao-Hung Hubert Tsai](https://yaohungt.github.io), Nitish Srivastava, Hanlin Goh, and [Ruslan Salakhutdinov](https://www.cs.cmu.edu/~rsalakhu/)
11 | International Conference on Learning Representations (ICLR), 2020. 12 | 13 | Please cite our paper if you find our work useful for your research: 14 | 15 | ```tex 16 | @inproceedings{tsai2020Capsules, 17 | title={Capsules with Inverted Dot-Product Attention Routing}, 18 | author={Tsai, Yao-Hung Hubert and Srivastava, Nitish and Goh, Hanlin and Salakhutdinov, Ruslan}, 19 | booktitle={International Conference on Learning Representations (ICLR)}, 20 | year={2020}, 21 | } 22 | ``` 23 | 24 | ## Overview 25 | 26 | ### Overall Architecture 27 |
![Overall architecture](imgs/overall.png)
28 | 29 | 30 | An example of our proposed architecture is shown above. The backbone is a standard feed-forward convolutional neural network. The features extracted from this network are fed through another convolutional layer. At each spatial location, groups of 16 channels are made to create capsules (we assume a 16-dimensional pose in a capsule). LayerNorm is then applied across the 16 channels to obtain the primary capsules. This is followed by two convolutional capsule layers, and then by two fully-connected capsule layers. In the last capsule layer, each capsule corresponds to a class. These capsules are then used to compute logits that feed into a softmax to compute the classification probabilities. Inference in this network requires a feed-forward pass up to the primary capsules. After this, our proposed routing mechanism (discussed later) takes over. 31 |
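For concreteness, the snippet below is a minimal sketch of how the primary capsules are built, mirroring `CapsModel.forward` in `src/capsule_model.py`; the sizes follow `configs/resnet_backbone_CIFAR10.json`, and the variable names are only illustrative.

```python
import torch
import torch.nn as nn

batch, in_channels, height, width = 4, 128, 16, 16   # backbone feature map (CIFAR-10 config)
num_caps, caps_dim = 32, 16                           # 32 primary capsules, 16-dimensional poses

features = torch.randn(batch, in_channels, height, width)

# One convolution emits num_caps * caps_dim channels; each group of 16 channels forms a capsule.
pc_layer = nn.Conv2d(in_channels, num_caps * caps_dim, kernel_size=1, bias=False)
layer_norm = nn.LayerNorm(caps_dim)

u = pc_layer(features)                               # (B, 512, H, W)
u = u.permute(0, 2, 3, 1)                            # (B, H, W, 512)
u = u.view(batch, height, width, num_caps, caps_dim) # split the channels into capsule groups
u = u.permute(0, 3, 1, 2, 4)                         # (B, num_caps, H, W, caps_dim)
primary_capsules = layer_norm(u)                     # LayerNorm over each 16-dimensional pose
```

32 | ### Inverted Dot-Product Attention Routing 33 |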
![Inverted dot-product attention routing](imgs/inverted_attention.png)
34 | 35 | 36 | In our method, the routing procedure resembles an inverted attention mechanism, where dot products are used to measure agreement. Specifically, the higher-level (parent) units compete for the attention of the lower-level (child) units, instead of the other way around, which is commonly used in attention models. Hence, the routing probability directly depends on the agreement between the parent’s pose (from the previous iteration step) and the child’s vote for the parent’s pose (in the current iteration step). We (1) use Layer Normalization (Ba et al., 2016) as normalization, and we (2) perform inference of the latent capsule states and routing probabilities jointly across multiple capsule layers (instead of doing it layer-wise). 37 |
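To make the update concrete, here is a simplified sketch of one routing step with vector poses, in the spirit of `CapsuleFC.forward` in `src/layers.py` (the actual layers additionally support matrix poses, dropout, and convolutional votes; the tensor names and toy sizes here are illustrative assumptions).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

b, n, a, m, d = 2, 8, 16, 4, 16               # batch, child capsules/dims, parent capsules/dims

child_poses = torch.randn(b, n, a)            # lower-level (child) capsules
weight = 0.1 * torch.randn(n, a, m, d)        # learned child-to-parent transformations
layer_norm = nn.LayerNorm(d)
scale = d ** -0.5

# Each child votes for the pose of every parent.
votes = torch.einsum('bna,namd->bnmd', child_poses, weight)

# Iteration 0: uniform routing, so parents start as an average of the votes.
parent_poses = layer_norm(votes.mean(dim=1))

# Later iterations: the agreement (dot product) between a vote and the current parent pose
# sets the routing probability; the softmax runs over parents, i.e. parents compete for children.
for _ in range(2):
    agreement = torch.einsum('bnmd,bmd->bnm', votes, parent_poses) * scale
    routing = F.softmax(agreement, dim=-1)
    parent_poses = layer_norm(torch.einsum('bnm,bnmd->bmd', routing, votes))
```

In the concurrent setting described next, this update is applied to every capsule layer at each iteration, rather than running all iterations of one layer before moving on to the next.

38 | ### Concurrent Routing 39 |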
![Concurrent routing](imgs/concurrent.png)
40 | 41 | 42 | The concurrent routing is a parallel-in-time routing procedure for all capsule layers. 43 | 44 | ## Usage 45 | 46 | ### Prerequisites 47 | - Python 3.6/3.7 48 | - [Pytorch (>=1.2.0) and torchvision](https://pytorch.org/) 49 | - CUDA 10.0 or above 50 | 51 | ### Datasets 52 | 53 | We use [CIFAR10 and CIFAR100](https://www.cs.toronto.edu/~kriz/cifar.html). 54 | 55 | ### Run the Code 56 | 57 | #### Arguments 58 | 59 | | Args | Default | Help | 60 | |:-----------:|:---------------------------------------------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| 61 | | debug | - | Enter debug mode: no models or results will be saved. | 62 | | num_routing | 1 | The number of routing iterations. Must be at least 1. | 63 | | dataset | CIFAR10 | Choice of the dataset. CIFAR10 or CIFAR100. | 64 | | backbone | resnet | Choice of the backbone. simple or resnet. | 65 | | config_path | ./configs/resnet_backbone_CIFAR10.json | Configurations for capsule layers. | 66 | 67 | #### Running CIFAR-10 68 | 69 | ```bash 70 | python main_capsule.py --num_routing 2 --dataset CIFAR10 --backbone resnet --config_path ./configs/resnet_backbone_CIFAR10.json 71 | ``` 72 | When ```num_routing``` is ```1```, the average performance we obtained is _94.73%_. 73 | 74 | When ```num_routing``` is ```2```, the average performance we obtained is _94.85%_ and the best model we obtained is _95.14%_. 75 | 76 | 77 | #### Running CIFAR-100 78 | 79 | ```bash 80 | python main_capsule.py --num_routing 2 --dataset CIFAR100 --backbone resnet --config_path ./configs/resnet_backbone_CIFAR100.json 81 | ``` 82 | 83 | When ```num_routing``` is ```1```, the average performance we obtained is _76.02%_. 84 | 85 | When ```num_routing``` is ```2```, the average performance we obtained is _76.27%_ and the best model we obtained is _78.02%_. 86 | 87 | 88 | ## License 89 | This code is released under the [LICENSE](LICENSE) terms. 90 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved.
4 | # 5 | -------------------------------------------------------------------------------- /configs/resnet_backbone_CIFAR10.json: -------------------------------------------------------------------------------- 1 | { 2 | "backbone": { 3 | "kernel_size": 3, 4 | "output_dim": 128, 5 | "input_dim": 3, 6 | "stride": 2, 7 | "padding": 1, 8 | "out_img_size": 16 9 | }, 10 | "primary_capsules": { 11 | "kernel_size": 1, 12 | "stride": 1, 13 | "input_dim": 128, 14 | "caps_dim": 16, 15 | "num_caps": 32, 16 | "padding": 0, 17 | "out_img_size": 16 18 | }, 19 | "capsules": [ 20 | { 21 | "type" : "CONV", 22 | "num_caps": 32, 23 | "caps_dim": 16, 24 | "kernel_size": 3, 25 | "stride": 2, 26 | "matrix_pose": true, 27 | "out_img_size": 7 28 | }, 29 | { 30 | "type": "CONV", 31 | "num_caps": 32, 32 | "caps_dim": 16, 33 | "kernel_size": 3, 34 | "stride": 1, 35 | "matrix_pose": true, 36 | "out_img_size": 5 37 | } 38 | ], 39 | "class_capsules": { 40 | "num_caps": 10, 41 | "caps_dim": 16, 42 | "matrix_pose": true 43 | } 44 | } -------------------------------------------------------------------------------- /configs/resnet_backbone_CIFAR100.json: -------------------------------------------------------------------------------- 1 | { 2 | "backbone": { 3 | "kernel_size": 3, 4 | "output_dim": 128, 5 | "input_dim": 3, 6 | "stride": 2, 7 | "padding": 1, 8 | "out_img_size": 16 9 | }, 10 | "primary_capsules": { 11 | "kernel_size": 1, 12 | "stride": 1, 13 | "input_dim": 128, 14 | "caps_dim": 36, 15 | "num_caps": 32, 16 | "padding": 0, 17 | "out_img_size": 16 18 | }, 19 | "capsules": [ 20 | { 21 | "type" : "CONV", 22 | "num_caps": 32, 23 | "caps_dim": 36, 24 | "kernel_size": 3, 25 | "stride": 2, 26 | "matrix_pose": true, 27 | "out_img_size": 7 28 | }, 29 | { 30 | "type": "CONV", 31 | "num_caps": 32, 32 | "caps_dim": 36, 33 | "kernel_size": 3, 34 | "stride": 1, 35 | "matrix_pose": true, 36 | "out_img_size": 5 37 | }, 38 | { 39 | "type": "FC", 40 | "num_caps": 20, 41 | "caps_dim": 36, 42 | "matrix_pose": true 43 | } 44 | ], 45 | "class_capsules": { 46 | "num_caps": 100, 47 | "caps_dim": 36, 48 | "matrix_pose": true 49 | } 50 | } -------------------------------------------------------------------------------- /imgs/concurrent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-capsules-inverted-attention-routing/53899ae5b75b576a693a1bd79a2173ac3f1cdf1d/imgs/concurrent.png -------------------------------------------------------------------------------- /imgs/inverted_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-capsules-inverted-attention-routing/53899ae5b75b576a693a1bd79a2173ac3f1cdf1d/imgs/inverted_attention.png -------------------------------------------------------------------------------- /imgs/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-capsules-inverted-attention-routing/53899ae5b75b576a693a1bd79a2173ac3f1cdf1d/imgs/overall.png -------------------------------------------------------------------------------- /main_capsule.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All rights reserved. 
4 | # 5 | '''Train CIFAR10 with PyTorch.''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | import torch.backends.cudnn as cudnn 11 | 12 | import torchvision 13 | import torchvision.transforms as transforms 14 | 15 | import os 16 | import argparse 17 | 18 | from src import capsule_model 19 | from utils import progress_bar 20 | import pickle 21 | import json 22 | 23 | from datetime import datetime 24 | 25 | # + 26 | parser = argparse.ArgumentParser(description='Training Capsules using Inverted Dot-Product Attention Routing') 27 | 28 | parser.add_argument('--resume_dir', '-r', default='', type=str, help='dir where we resume from checkpoint') 29 | parser.add_argument('--num_routing', default=1, type=int, help='number of routing. Recommended: 0,1,2,3.') 30 | parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset. CIFAR10 or CIFAR100.') 31 | parser.add_argument('--backbone', default='resnet', type=str, help='type of backbone. simple or resnet') 32 | parser.add_argument('--num_workers', default=2, type=int, help='number of workers. 0 or 2') 33 | parser.add_argument('--config_path', default='./configs/full_rank_2C1F_matrix_for_iterations.json', type=str, help='path of the config') 34 | parser.add_argument('--debug', action='store_true', 35 | help='use debug mode (without saving to a directory)') 36 | parser.add_argument('--sequential_routing', action='store_true', help='not using concurrent_routing') 37 | 38 | 39 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate. 0.1 for SGD') 40 | parser.add_argument('--dp', default=0.0, type=float, help='dropout rate') 41 | parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay') 42 | # - 43 | 44 | args = parser.parse_args() 45 | assert args.num_routing > 0 46 | 47 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 48 | best_acc = 0 # best test accuracy 49 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 50 | 51 | # Data 52 | print('==> Preparing data..') 53 | assert args.dataset == 'CIFAR10' or args.dataset == 'CIFAR100' 54 | transform_train = transforms.Compose([ 55 | transforms.RandomCrop(32, padding=4), 56 | transforms.RandomHorizontalFlip(), 57 | transforms.ToTensor(), 58 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 59 | ]) 60 | transform_test = transforms.Compose([ 61 | transforms.ToTensor(), 62 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 63 | ]) 64 | 65 | trainset = getattr(torchvision.datasets, args.dataset)(root='../data', train=True, download=True, transform=transform_train) 66 | testset = getattr(torchvision.datasets, args.dataset)(root='../data', train=False, download=True, transform=transform_test) 67 | num_class = int(args.dataset.split('CIFAR')[1]) 68 | 69 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=args.num_workers) 70 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=args.num_workers) 71 | 72 | print('==> Building model..') 73 | # Model parameters 74 | 75 | image_dim_size = 32 76 | 77 | with open(args.config_path, 'rb') as file: 78 | params = json.load(file) 79 | 80 | net = capsule_model.CapsModel(image_dim_size, 81 | params, 82 | args.backbone, 83 | args.dp, 84 | args.num_routing, 85 | sequential_routing=args.sequential_routing) 86 | 87 | # + 88 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, 
weight_decay=args.weight_decay) 89 | 90 | lr_decay = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1) 91 | 92 | 93 | # - 94 | 95 | def count_parameters(model): 96 | for name, param in model.named_parameters(): 97 | if param.requires_grad: 98 | print(name, param.numel()) 99 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 100 | 101 | print(net) 102 | total_params = count_parameters(net) 103 | print(total_params) 104 | 105 | if not os.path.isdir('results') and not args.debug: 106 | os.mkdir('results') 107 | if not args.debug: 108 | store_dir = os.path.join('results', datetime.today().strftime('%Y-%m-%d-%H-%M-%S')) 109 | os.mkdir(store_dir) 110 | 111 | net = net.to(device) 112 | if device == 'cuda': 113 | net = torch.nn.DataParallel(net) 114 | cudnn.benchmark = True 115 | 116 | loss_func = nn.CrossEntropyLoss() 117 | 118 | if args.resume_dir and not args.debug: 119 | # Load checkpoint. 120 | print('==> Resuming from checkpoint..') 121 | checkpoint = torch.load(os.path.join(args.resume_dir, 'ckpt.pth')) 122 | net.load_state_dict(checkpoint['net']) 123 | best_acc = checkpoint['acc'] 124 | start_epoch = checkpoint['epoch'] 125 | 126 | # Training 127 | def train(epoch): 128 | print('\nEpoch: %d' % epoch) 129 | net.train() 130 | train_loss = 0 131 | correct = 0 132 | total = 0 133 | for batch_idx, (inputs, targets) in enumerate(trainloader): 134 | inputs = inputs.to(device) 135 | 136 | targets = targets.to(device) 137 | 138 | optimizer.zero_grad() 139 | 140 | v = net(inputs) 141 | 142 | loss = loss_func(v, targets) 143 | 144 | loss.backward() 145 | optimizer.step() 146 | 147 | train_loss += loss.item() 148 | _, predicted = v.max(dim=1) 149 | 150 | total += targets.size(0) 151 | 152 | correct += predicted.eq(targets).sum().item() 153 | 154 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 155 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total)) 156 | return 100.*correct/total 157 | 158 | def test(epoch): 159 | global best_acc 160 | net.eval() 161 | test_loss = 0 162 | correct = 0 163 | total = 0 164 | with torch.no_grad(): 165 | for batch_idx, (inputs, targets) in enumerate(testloader): 166 | inputs = inputs.to(device) 167 | 168 | targets = targets.to(device) 169 | 170 | v = net(inputs) 171 | 172 | loss = loss_func(v, targets) 173 | 174 | test_loss += loss.item() 175 | 176 | _, predicted = v.max(dim=1) 177 | 178 | total += targets.size(0) 179 | 180 | correct += predicted.eq(targets).sum().item() 181 | 182 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' 183 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total)) 184 | 185 | # Save checkpoint. 
186 | acc = 100.*correct/total 187 | if acc > best_acc and not args.debug: 188 | print('Saving..') 189 | state = { 190 | 'net': net.state_dict(), 191 | 'acc': acc, 192 | 'epoch': epoch, 193 | } 194 | torch.save(state, os.path.join(store_dir, 'ckpt.pth')) 195 | best_acc = acc 196 | return 100.*correct/total 197 | 198 | # + 199 | results = { 200 | 'total_params': total_params, 201 | 'args': args, 202 | 'params': params, 203 | 'train_acc': [], 204 | 'test_acc': [], 205 | } 206 | 207 | total_epochs = 350 208 | 209 | for epoch in range(start_epoch, start_epoch+total_epochs): 210 | results['train_acc'].append(train(epoch)) 211 | 212 | lr_decay.step() 213 | results['test_acc'].append(test(epoch)) 214 | # - 215 | 216 | if not args.debug: 217 | store_file = os.path.join(store_dir, 'dataset_' + str(args.dataset) + '_num_routing_' + str(args.num_routing) + \ 218 | '_backbone_' + args.backbone + '.dct') 219 | 220 | pickle.dump(results, open(store_file, 'wb')) 221 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved. 4 | # 5 | -------------------------------------------------------------------------------- /src/capsule_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved. 4 | # 5 | from src import layers 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch 9 | 10 | # Capsule model 11 | class CapsModel(nn.Module): 12 | def __init__(self, 13 | image_dim_size, 14 | params, 15 | backbone, 16 | dp, 17 | num_routing, 18 | sequential_routing=True): 19 | 20 | super(CapsModel, self).__init__() 21 | #### Parameters 22 | self.sequential_routing = sequential_routing 23 | 24 | ## Primary Capsule Layer 25 | self.pc_num_caps = params['primary_capsules']['num_caps'] 26 | self.pc_caps_dim = params['primary_capsules']['caps_dim'] 27 | self.pc_output_dim = params['primary_capsules']['out_img_size'] 28 | ## General 29 | self.num_routing = num_routing # >3 may cause slow converging 30 | 31 | #### Building Networks 32 | ## Backbone (before capsule) 33 | if backbone == 'simple': 34 | self.pre_caps = layers.simple_backbone(params['backbone']['input_dim'], 35 | params['backbone']['output_dim'], 36 | params['backbone']['kernel_size'], 37 | params['backbone']['stride'], 38 | params['backbone']['padding']) 39 | elif backbone == 'resnet': 40 | self.pre_caps = layers.resnet_backbone(params['backbone']['input_dim'], 41 | params['backbone']['output_dim'], 42 | params['backbone']['stride']) 43 | 44 | ## Primary Capsule Layer (a single CNN) 45 | self.pc_layer = nn.Conv2d(in_channels=params['primary_capsules']['input_dim'], 46 | out_channels=params['primary_capsules']['num_caps'] *\ 47 | params['primary_capsules']['caps_dim'], 48 | kernel_size=params['primary_capsules']['kernel_size'], 49 | stride=params['primary_capsules']['stride'], 50 | padding=params['primary_capsules']['padding'], 51 | bias=False) 52 | 53 | #self.pc_layer = nn.Sequential() 54 | 55 | self.nonlinear_act = nn.LayerNorm(params['primary_capsules']['caps_dim']) 56 | 57 | ## Main Capsule Layers 58 | self.capsule_layers = nn.ModuleList([]) 59 | for i in range(len(params['capsules'])): 60 | if params['capsules'][i]['type'] == 'CONV': 61 | in_n_caps = 
params['primary_capsules']['num_caps'] if i==0 else \ 62 | params['capsules'][i-1]['num_caps'] 63 | in_d_caps = params['primary_capsules']['caps_dim'] if i==0 else \ 64 | params['capsules'][i-1]['caps_dim'] 65 | self.capsule_layers.append( 66 | layers.CapsuleCONV(in_n_capsules=in_n_caps, 67 | in_d_capsules=in_d_caps, 68 | out_n_capsules=params['capsules'][i]['num_caps'], 69 | out_d_capsules=params['capsules'][i]['caps_dim'], 70 | kernel_size=params['capsules'][i]['kernel_size'], 71 | stride=params['capsules'][i]['stride'], 72 | matrix_pose=params['capsules'][i]['matrix_pose'], 73 | dp=dp, 74 | coordinate_add=False 75 | ) 76 | ) 77 | elif params['capsules'][i]['type'] == 'FC': 78 | if i == 0: 79 | in_n_caps = params['primary_capsules']['num_caps'] * params['primary_capsules']['out_img_size'] *\ 80 | params['primary_capsules']['out_img_size'] 81 | in_d_caps = params['primary_capsules']['caps_dim'] 82 | elif params['capsules'][i-1]['type'] == 'FC': 83 | in_n_caps = params['capsules'][i-1]['num_caps'] 84 | in_d_caps = params['capsules'][i-1]['caps_dim'] 85 | elif params['capsules'][i-1]['type'] == 'CONV': 86 | in_n_caps = params['capsules'][i-1]['num_caps'] * params['capsules'][i-1]['out_img_size'] *\ 87 | params['capsules'][i-1]['out_img_size'] 88 | in_d_caps = params['capsules'][i-1]['caps_dim'] 89 | self.capsule_layers.append( 90 | layers.CapsuleFC(in_n_capsules=in_n_caps, 91 | in_d_capsules=in_d_caps, 92 | out_n_capsules=params['capsules'][i]['num_caps'], 93 | out_d_capsules=params['capsules'][i]['caps_dim'], 94 | matrix_pose=params['capsules'][i]['matrix_pose'], 95 | dp=dp 96 | ) 97 | ) 98 | 99 | ## Class Capsule Layer 100 | if not len(params['capsules'])==0: 101 | if params['capsules'][-1]['type'] == 'FC': 102 | in_n_caps = params['capsules'][-1]['num_caps'] 103 | in_d_caps = params['capsules'][-1]['caps_dim'] 104 | elif params['capsules'][-1]['type'] == 'CONV': 105 | in_n_caps = params['capsules'][-1]['num_caps'] * params['capsules'][-1]['out_img_size'] *\ 106 | params['capsules'][-1]['out_img_size'] 107 | in_d_caps = params['capsules'][-1]['caps_dim'] 108 | else: 109 | in_n_caps = params['primary_capsules']['num_caps'] * params['primary_capsules']['out_img_size'] *\ 110 | params['primary_capsules']['out_img_size'] 111 | in_d_caps = params['primary_capsules']['caps_dim'] 112 | self.capsule_layers.append( 113 | layers.CapsuleFC(in_n_capsules=in_n_caps, 114 | in_d_capsules=in_d_caps, 115 | out_n_capsules=params['class_capsules']['num_caps'], 116 | out_d_capsules=params['class_capsules']['caps_dim'], 117 | matrix_pose=params['class_capsules']['matrix_pose'], 118 | dp=dp 119 | ) 120 | ) 121 | 122 | ## After Capsule 123 | # fixed classifier for all class capsules 124 | self.final_fc = nn.Linear(params['class_capsules']['caps_dim'], 1) 125 | # different classifier for different capsules 126 | #self.final_fc = nn.Parameter(torch.randn(params['class_capsules']['num_caps'], params['class_capsules']['caps_dim'])) 127 | 128 | def forward(self, x, lbl_1=None, lbl_2=None): 129 | #### Forward Pass 130 | ## Backbone (before capsule) 131 | c = self.pre_caps(x) 132 | 133 | ## Primary Capsule Layer (a single CNN) 134 | u = self.pc_layer(c) # torch.Size([100, 512, 14, 14]) 135 | u = u.permute(0, 2, 3, 1) # 100, 14, 14, 512 136 | u = u.view(u.shape[0], self.pc_output_dim, self.pc_output_dim, self.pc_num_caps, self.pc_caps_dim) # 100, 14, 14, 32, 16 137 | u = u.permute(0, 3, 1, 2, 4) # 100, 32, 14, 14, 16 138 | init_capsule_value = self.nonlinear_act(u)#capsule_utils.squash(u) 139 | 140 | ## Main Capsule 
Layers 141 | # concurrent routing 142 | if not self.sequential_routing: 143 | # first iteration 144 | # perform initilialization for the capsule values as single forward passing 145 | capsule_values, _val = [init_capsule_value], init_capsule_value 146 | for i in range(len(self.capsule_layers)): 147 | _val = self.capsule_layers[i].forward(_val, 0) 148 | capsule_values.append(_val) # get the capsule value for next layer 149 | 150 | # second to t iterations 151 | # perform the routing between capsule layers 152 | for n in range(self.num_routing-1): 153 | _capsule_values = [init_capsule_value] 154 | for i in range(len(self.capsule_layers)): 155 | _val = self.capsule_layers[i].forward(capsule_values[i], n, 156 | capsule_values[i+1]) 157 | _capsule_values.append(_val) 158 | capsule_values = _capsule_values 159 | # sequential routing 160 | else: 161 | capsule_values, _val = [init_capsule_value], init_capsule_value 162 | for i in range(len(self.capsule_layers)): 163 | # first iteration 164 | __val = self.capsule_layers[i].forward(_val, 0) 165 | # second to t iterations 166 | # perform the routing between capsule layers 167 | for n in range(self.num_routing-1): 168 | __val = self.capsule_layers[i].forward(_val, n, __val) 169 | _val = __val 170 | capsule_values.append(_val) 171 | 172 | ## After Capsule 173 | out = capsule_values[-1] 174 | out = self.final_fc(out) # fixed classifier for all capsules 175 | out = out.squeeze() # fixed classifier for all capsules 176 | #out = torch.einsum('bnd, nd->bn', out, self.final_fc) # different classifiers for distinct capsules 177 | 178 | return out 179 | -------------------------------------------------------------------------------- /src/layers.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved. 
4 | # 5 | '''Capsule in PyTorch 6 | TBD 7 | ''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | import numpy as np 14 | 15 | #### Simple Backbone #### 16 | class simple_backbone(nn.Module): 17 | def __init__(self, cl_input_channels,cl_num_filters,cl_filter_size, 18 | cl_stride,cl_padding): 19 | super(simple_backbone, self).__init__() 20 | self.pre_caps = nn.Sequential( 21 | nn.Conv2d(in_channels=cl_input_channels, 22 | out_channels=cl_num_filters, 23 | kernel_size=cl_filter_size, 24 | stride=cl_stride, 25 | padding=cl_padding), 26 | nn.ReLU(), 27 | ) 28 | def forward(self, x): 29 | out = self.pre_caps(x) # x is an image 30 | return out 31 | 32 | 33 | #### ResNet Backbone #### 34 | class BasicBlock(nn.Module): 35 | expansion = 1 36 | 37 | def __init__(self, in_planes, planes, stride=1): 38 | super(BasicBlock, self).__init__() 39 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 40 | self.bn1 = nn.BatchNorm2d(planes) 41 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 42 | self.bn2 = nn.BatchNorm2d(planes) 43 | 44 | self.shortcut = nn.Sequential() 45 | if stride != 1 or in_planes != self.expansion*planes: 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 48 | nn.BatchNorm2d(self.expansion*planes) 49 | ) 50 | 51 | def forward(self, x): 52 | out = F.relu(self.bn1(self.conv1(x))) 53 | out = self.bn2(self.conv2(out)) 54 | out += self.shortcut(x) 55 | out = F.relu(out) 56 | return out 57 | 58 | class resnet_backbone(nn.Module): 59 | def __init__(self, cl_input_channels, cl_num_filters, 60 | cl_stride): 61 | super(resnet_backbone, self).__init__() 62 | self.in_planes = 64 63 | def _make_layer(block, planes, num_blocks, stride): 64 | strides = [stride] + [1]*(num_blocks-1) 65 | layers = [] 66 | for stride in strides: 67 | layers.append(block(self.in_planes, planes, stride)) 68 | self.in_planes = planes * block.expansion 69 | return nn.Sequential(*layers) 70 | 71 | self.pre_caps = nn.Sequential( 72 | nn.Conv2d(in_channels=cl_input_channels, 73 | out_channels=64, 74 | kernel_size=3, 75 | stride=1, 76 | padding=1, 77 | bias=False), 78 | nn.BatchNorm2d(64), 79 | nn.ReLU(), 80 | _make_layer(block=BasicBlock, planes=64, num_blocks=3, stride=1), # num_blocks=2 or 3 81 | _make_layer(block=BasicBlock, planes=cl_num_filters, num_blocks=4, stride=cl_stride), # num_blocks=2 or 4 82 | ) 83 | def forward(self, x): 84 | out = self.pre_caps(x) # x is an image 85 | return out 86 | 87 | #### Capsule Layer #### 88 | class CapsuleFC(nn.Module): 89 | r"""Applies as a capsule fully-connected layer. 
90 | TBD 91 | """ 92 | def __init__(self, in_n_capsules, in_d_capsules, out_n_capsules, out_d_capsules, matrix_pose, dp): 93 | super(CapsuleFC, self).__init__() 94 | self.in_n_capsules = in_n_capsules 95 | self.in_d_capsules = in_d_capsules 96 | self.out_n_capsules = out_n_capsules 97 | self.out_d_capsules = out_d_capsules 98 | self.matrix_pose = matrix_pose 99 | 100 | if matrix_pose: 101 | self.sqrt_d = int(np.sqrt(self.in_d_capsules)) 102 | self.weight_init_const = np.sqrt(out_n_capsules/(self.sqrt_d*in_n_capsules)) 103 | self.w = nn.Parameter(self.weight_init_const* \ 104 | torch.randn(in_n_capsules, self.sqrt_d, self.sqrt_d, out_n_capsules)) 105 | 106 | else: 107 | self.weight_init_const = np.sqrt(out_n_capsules/(in_d_capsules*in_n_capsules)) 108 | self.w = nn.Parameter(self.weight_init_const* \ 109 | torch.randn(in_n_capsules, in_d_capsules, out_n_capsules, out_d_capsules)) 110 | self.dropout_rate = dp 111 | self.nonlinear_act = nn.LayerNorm(out_d_capsules) 112 | self.drop = nn.Dropout(self.dropout_rate) 113 | self.scale = 1. / (out_d_capsules ** 0.5) 114 | 115 | def extra_repr(self): 116 | return 'in_n_capsules={}, in_d_capsules={}, out_n_capsules={}, out_d_capsules={}, matrix_pose={}, \ 117 | weight_init_const={}, dropout_rate={}'.format( 118 | self.in_n_capsules, self.in_d_capsules, self.out_n_capsules, self.out_d_capsules, self.matrix_pose, 119 | self.weight_init_const, self.dropout_rate 120 | ) 121 | def forward(self, input, num_iter, next_capsule_value=None): 122 | # b: batch size 123 | # n: num of capsules in current layer 124 | # a: dim of capsules in current layer 125 | # m: num of capsules in next layer 126 | # d: dim of capsules in next layer 127 | if len(input.shape) == 5: 128 | input = input.permute(0, 4, 1, 2, 3) 129 | input = input.contiguous().view(input.shape[0], input.shape[1], -1) 130 | input = input.permute(0,2,1) 131 | 132 | if self.matrix_pose: 133 | w = self.w # nxdm 134 | _input = input.view(input.shape[0], input.shape[1], self.sqrt_d, self.sqrt_d) # bnax 135 | else: 136 | w = self.w 137 | 138 | if next_capsule_value is None: 139 | query_key = torch.zeros(self.in_n_capsules, self.out_n_capsules).type_as(input) 140 | query_key = F.softmax(query_key, dim=1) 141 | 142 | if self.matrix_pose: 143 | next_capsule_value = torch.einsum('nm, bnax, nxdm->bmad', query_key, _input, w) 144 | else: 145 | next_capsule_value = torch.einsum('nm, bna, namd->bmd', query_key, input, w) 146 | else: 147 | if self.matrix_pose: 148 | next_capsule_value = next_capsule_value.view(next_capsule_value.shape[0], 149 | next_capsule_value.shape[1], self.sqrt_d, self.sqrt_d) 150 | _query_key = torch.einsum('bnax, nxdm, bmad->bnm', _input, w, next_capsule_value) 151 | else: 152 | _query_key = torch.einsum('bna, namd, bmd->bnm', input, w, next_capsule_value) 153 | _query_key.mul_(self.scale) 154 | query_key = F.softmax(_query_key, dim=2) 155 | query_key = query_key / (torch.sum(query_key, dim=2, keepdim=True) + 1e-10) 156 | 157 | if self.matrix_pose: 158 | next_capsule_value = torch.einsum('bnm, bnax, nxdm->bmad', query_key, _input, 159 | w) 160 | else: 161 | next_capsule_value = torch.einsum('bnm, bna, namd->bmd', query_key, input, 162 | w) 163 | 164 | next_capsule_value = self.drop(next_capsule_value) 165 | if not next_capsule_value.shape[-1] == 1: 166 | if self.matrix_pose: 167 | next_capsule_value = next_capsule_value.view(next_capsule_value.shape[0], 168 | next_capsule_value.shape[1], self.out_d_capsules) 169 | next_capsule_value = self.nonlinear_act(next_capsule_value) 170 | else: 171 | 
next_capsule_value = self.nonlinear_act(next_capsule_value) 172 | return next_capsule_value 173 | 174 | class CapsuleCONV(nn.Module): 175 | r"""Applies as a capsule convolutional layer. 176 | TBD 177 | """ 178 | def __init__(self, in_n_capsules, in_d_capsules, out_n_capsules, out_d_capsules, 179 | kernel_size, stride, matrix_pose, dp, coordinate_add=False): 180 | super(CapsuleCONV, self).__init__() 181 | self.in_n_capsules = in_n_capsules 182 | self.in_d_capsules = in_d_capsules 183 | self.out_n_capsules = out_n_capsules 184 | self.out_d_capsules = out_d_capsules 185 | self.kernel_size = kernel_size 186 | self.stride = stride 187 | self.matrix_pose = matrix_pose 188 | self.coordinate_add = coordinate_add 189 | 190 | if matrix_pose: 191 | self.sqrt_d = int(np.sqrt(self.in_d_capsules)) 192 | self.weight_init_const = np.sqrt(out_n_capsules/(self.sqrt_d*in_n_capsules*kernel_size*kernel_size)) 193 | self.w = nn.Parameter(self.weight_init_const*torch.randn(kernel_size, kernel_size, 194 | in_n_capsules, self.sqrt_d, self.sqrt_d, out_n_capsules)) 195 | else: 196 | self.weight_init_const = np.sqrt(out_n_capsules/(in_d_capsules*in_n_capsules*kernel_size*kernel_size)) 197 | self.w = nn.Parameter(self.weight_init_const*torch.randn(kernel_size, kernel_size, 198 | in_n_capsules, in_d_capsules, out_n_capsules, 199 | out_d_capsules)) 200 | self.nonlinear_act = nn.LayerNorm(out_d_capsules) 201 | self.dropout_rate = dp 202 | self.drop = nn.Dropout(self.dropout_rate) 203 | self.scale = 1. / (out_d_capsules ** 0.5) 204 | 205 | def extra_repr(self): 206 | return 'in_n_capsules={}, in_d_capsules={}, out_n_capsules={}, out_d_capsules={}, \ 207 | kernel_size={}, stride={}, coordinate_add={}, matrix_pose={}, weight_init_const={}, \ 208 | dropout_rate={}'.format( 209 | self.in_n_capsules, self.in_d_capsules, self.out_n_capsules, self.out_d_capsules, 210 | self.kernel_size, self.stride, self.coordinate_add, self.matrix_pose, self.weight_init_const, 211 | self.dropout_rate 212 | ) 213 | def input_expansion(self, input): 214 | # input has size [batch x num_of_capsule x height x width x x capsule_dimension] 215 | b, n, h, w, d = input.shape 216 | 217 | h_out = int((h-self.kernel_size)/self.stride+1) 218 | w_out = int((w-self.kernel_size)/self.stride+1) 219 | 220 | # this may be slow if the image size is large 221 | # TODO: kind of stupid implementation :'( 222 | inputs = torch.stack([input[:, :, self.stride * i:self.stride * i + self.kernel_size, 223 | self.stride * j:self.stride * j + self.kernel_size, :] for i in range(h_out) for j in range(w_out)], 224 | dim=-1) # b,n, kernel_size, kernel_size, h_out*w_out*d 225 | inputs = inputs.view(b,n,self.kernel_size, self.kernel_size, h_out, w_out, d) 226 | return inputs 227 | 228 | def forward(self, input, num_iter, next_capsule_value=None): 229 | # k,l: kernel size 230 | # h,w: output width and length 231 | inputs = self.input_expansion(input) 232 | 233 | if self.matrix_pose: 234 | w = self.w # klnxdm 235 | _inputs = inputs.view(inputs.shape[0], inputs.shape[1], inputs.shape[2], inputs.shape[3],\ 236 | inputs.shape[4], inputs.shape[5], self.sqrt_d, self.sqrt_d) # bnklmhax 237 | else: 238 | w = self.w 239 | 240 | if next_capsule_value is None: 241 | query_key = torch.zeros(self.in_n_capsules, self.kernel_size, self.kernel_size, 242 | self.out_n_capsules).type_as(inputs) 243 | query_key = F.softmax(query_key, dim=3) 244 | 245 | if self.matrix_pose: 246 | next_capsule_value = torch.einsum('nklm, bnklhwax, klnxdm->bmhwad', query_key, 247 | _inputs, w) 248 | else: 249 | 
next_capsule_value = torch.einsum('nklm, bnklhwa, klnamd->bmhwd', query_key, 250 | inputs, w) 251 | else: 252 | if self.matrix_pose: 253 | next_capsule_value = next_capsule_value.view(next_capsule_value.shape[0],\ 254 | next_capsule_value.shape[1], next_capsule_value.shape[2],\ 255 | next_capsule_value.shape[3], self.sqrt_d, self.sqrt_d) 256 | _query_key = torch.einsum('bnklhwax, klnxdm, bmhwad->bnklmhw', _inputs, w, 257 | next_capsule_value) 258 | else: 259 | _query_key = torch.einsum('bnklhwa, klnamd, bmhwd->bnklmhw', inputs, w, 260 | next_capsule_value) 261 | _query_key.mul_(self.scale) 262 | query_key = F.softmax(_query_key, dim=4) 263 | query_key = query_key / (torch.sum(query_key, dim=4, keepdim=True) + 1e-10) 264 | 265 | if self.matrix_pose: 266 | next_capsule_value = torch.einsum('bnklmhw, bnklhwax, klnxdm->bmhwad', query_key, 267 | _inputs, w) 268 | else: 269 | next_capsule_value = torch.einsum('bnklmhw, bnklhwa, klnamd->bmhwd', query_key, 270 | inputs, w) 271 | 272 | next_capsule_value = self.drop(next_capsule_value) 273 | if not next_capsule_value.shape[-1] == 1: 274 | if self.matrix_pose: 275 | next_capsule_value = next_capsule_value.view(next_capsule_value.shape[0],\ 276 | next_capsule_value.shape[1], next_capsule_value.shape[2],\ 277 | next_capsule_value.shape[3], self.out_d_capsules) 278 | next_capsule_value = self.nonlinear_act(next_capsule_value) 279 | else: 280 | next_capsule_value = self.nonlinear_act(next_capsule_value) 281 | 282 | return next_capsule_value 283 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2020 Apple Inc. All rights reserved. 4 | # 5 | '''Some helper functions for PyTorch, including: 6 | - get_mean_and_std: calculate the mean and std value of dataset. 7 | - msr_init: net parameter initialization. 8 | - progress_bar: progress bar mimic xlua.progress. 9 | copy from https://github.com/kuangliu/pytorch-cifar/blob/master/utils.py 10 | ''' 11 | import os 12 | import sys 13 | import time 14 | import math 15 | 16 | import torch.nn as nn 17 | import torch.nn.init as init 18 | 19 | 20 | def get_mean_and_std(dataset): 21 | '''Compute the mean and std value of dataset.''' 22 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2) 23 | mean = torch.zeros(3) 24 | std = torch.zeros(3) 25 | print('==> Computing mean and std..') 26 | for inputs, targets in dataloader: 27 | for i in range(3): 28 | mean[i] += inputs[:,i,:,:].mean() 29 | std[i] += inputs[:,i,:,:].std() 30 | mean.div_(len(dataset)) 31 | std.div_(len(dataset)) 32 | return mean, std 33 | 34 | def init_params(net): 35 | '''Init layer parameters.''' 36 | for m in net.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | init.kaiming_normal(m.weight, mode='fan_out') 39 | if m.bias: 40 | init.constant(m.bias, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant(m.weight, 1) 43 | init.constant(m.bias, 0) 44 | elif isinstance(m, nn.Linear): 45 | init.normal(m.weight, std=1e-3) 46 | if m.bias: 47 | init.constant(m.bias, 0) 48 | 49 | 50 | _, term_width = os.popen('stty size', 'r').read().split() 51 | term_width = int(term_width) 52 | 53 | TOTAL_BAR_LENGTH = 65. 54 | last_time = time.time() 55 | begin_time = last_time 56 | def progress_bar(current, total, msg=None): 57 | global last_time, begin_time 58 | if current == 0: 59 | begin_time = time.time() # Reset for new bar. 
60 | 61 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 62 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 63 | 64 | sys.stdout.write(' [') 65 | for i in range(cur_len): 66 | sys.stdout.write('=') 67 | sys.stdout.write('>') 68 | for i in range(rest_len): 69 | sys.stdout.write('.') 70 | sys.stdout.write(']') 71 | 72 | cur_time = time.time() 73 | step_time = cur_time - last_time 74 | last_time = cur_time 75 | tot_time = cur_time - begin_time 76 | 77 | L = [] 78 | L.append(' Step: %s' % format_time(step_time)) 79 | L.append(' | Tot: %s' % format_time(tot_time)) 80 | if msg: 81 | L.append(' | ' + msg) 82 | 83 | msg = ''.join(L) 84 | sys.stdout.write(msg) 85 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 86 | sys.stdout.write(' ') 87 | 88 | # Go back to the center of the bar. 89 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 90 | sys.stdout.write('\b') 91 | sys.stdout.write(' %d/%d ' % (current+1, total)) 92 | 93 | if current < total-1: 94 | sys.stdout.write('\r') 95 | else: 96 | sys.stdout.write('\n') 97 | sys.stdout.flush() 98 | 99 | def format_time(seconds): 100 | days = int(seconds / 3600/24) 101 | seconds = seconds - days*3600*24 102 | hours = int(seconds / 3600) 103 | seconds = seconds - hours*3600 104 | minutes = int(seconds / 60) 105 | seconds = seconds - minutes*60 106 | secondsf = int(seconds) 107 | seconds = seconds - secondsf 108 | millis = int(seconds*1000) 109 | 110 | f = '' 111 | i = 1 112 | if days > 0: 113 | f += str(days) + 'D' 114 | i += 1 115 | if hours > 0 and i <= 2: 116 | f += str(hours) + 'h' 117 | i += 1 118 | if minutes > 0 and i <= 2: 119 | f += str(minutes) + 'm' 120 | i += 1 121 | if secondsf > 0 and i <= 2: 122 | f += str(secondsf) + 's' 123 | i += 1 124 | if millis > 0 and i <= 2: 125 | f += str(millis) + 'ms' 126 | i += 1 127 | if f == '': 128 | f = '0ms' 129 | return f 130 | --------------------------------------------------------------------------------