├── ACKNOWLEDGEMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── configs
│   ├── __init__.py
│   ├── resnet_backbone_CIFAR10.json
│   └── resnet_backbone_CIFAR100.json
├── imgs
│   ├── concurrent.png
│   ├── inverted_attention.png
│   └── overall.png
├── main_capsule.py
├── src
│   ├── __init__.py
│   ├── capsule_model.py
│   └── layers.py
└── utils.py
/ACKNOWLEDGEMENTS:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this Software have utilized the following copyrighted material, the use of which is hereby acknowledged.
3 |
4 | _____________________
5 | PyTorch (https://pytorch.org)
6 | We use PyTorch as the training framework for our model.
7 |
8 | From PyTorch:
9 |
10 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
11 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
12 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
13 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
14 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
15 | Copyright (c) 2011-2013 NYU (Clement Farabet)
16 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
17 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
18 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
19 |
20 | From Caffe2:
21 |
22 | Copyright (c) 2016-present, Facebook Inc. All rights reserved.
23 |
24 | All contributions by Facebook:
25 | Copyright (c) 2016 Facebook Inc.
26 |
27 | All contributions by Google:
28 | Copyright (c) 2015 Google Inc.
29 | All rights reserved.
30 |
31 | All contributions by Yangqing Jia:
32 | Copyright (c) 2015 Yangqing Jia
33 | All rights reserved.
34 |
35 | All contributions from Caffe:
36 | Copyright(c) 2013, 2014, 2015, the respective contributors
37 | All rights reserved.
38 |
39 | All other contributions:
40 | Copyright(c) 2015, 2016 the respective contributors
41 | All rights reserved.
42 |
43 | Caffe2 uses a copyright model similar to Caffe: each contributor holds
44 | copyright over their contributions to Caffe2. The project versioning records
45 | all such contribution and copyright details. If a contributor wants to further
46 | mark their specific copyright on a particular contribution, they should
47 | indicate their copyright solely in the commit message of the change when it is
48 | committed.
49 |
50 | All rights reserved.
51 |
52 | Redistribution and use in source and binary forms, with or without
53 | modification, are permitted provided that the following conditions are met:
54 |
55 | 1. Redistributions of source code must retain the above copyright
56 | notice, this list of conditions and the following disclaimer.
57 |
58 | 2. Redistributions in binary form must reproduce the above copyright
59 | notice, this list of conditions and the following disclaimer in the
60 | documentation and/or other materials provided with the distribution.
61 |
62 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
63 | and IDIAP Research Institute nor the names of its contributors may be
64 | used to endorse or promote products derived from this software without
65 | specific prior written permission.
66 |
67 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
68 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
69 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
70 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
71 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
72 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
73 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
74 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
75 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
76 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
77 | POSSIBILITY OF SUCH DAMAGE.
78 |
79 | _____________________
80 | Soumith Chintala (TorchVision: https://github.com/pytorch/vision/)
81 | TorchVision is used for handling data loading and IO utilities, which is distributed under BSD 3-Clause License.
82 |
83 | Copyright (c) Soumith Chintala 2016,
84 | All rights reserved.
85 |
86 | Redistribution and use in source and binary forms, with or without
87 | modification, are permitted provided that the following conditions are met:
88 |
89 | * Redistributions of source code must retain the above copyright notice, this
90 | list of conditions and the following disclaimer.
91 |
92 | * Redistributions in binary form must reproduce the above copyright notice,
93 | this list of conditions and the following disclaimer in the documentation
94 | and/or other materials provided with the distribution.
95 |
96 | * Neither the name of the copyright holder nor the names of its
97 | contributors may be used to endorse or promote products derived from
98 | this software without specific prior written permission.
99 |
100 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
101 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
102 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
103 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
104 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
105 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
106 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
107 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
108 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
109 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
110 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | ## Before you get started
6 |
7 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2019 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
41 | -------------------------------------------------------------------------------
42 | SOFTWARE DISTRIBUTED IN THIS REPOSITORY:
43 |
44 | This software includes a number of subcomponents with separate
45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
46 | -------------------------------------------------------------------------------
47 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # Capsules with Inverted Dot-Product Attention Routing
4 |
 5 | > PyTorch implementation of Capsules with Inverted Dot-Product Attention Routing.
6 |
7 |
8 | ## Paper
9 | [**Capsules with Inverted Dot-Product Attention Routing**](https://openreview.net/pdf?id=HJe6uANtwH)
10 | [Yao-Hung Hubert Tsai](https://yaohungt.github.io), Nitish Srivastava, Hanlin Goh, and [Ruslan Salakhutdinov](https://www.cs.cmu.edu/~rsalakhu/)
11 | International Conference on Learning Representations (ICLR), 2020.
12 |
13 | Please cite our paper if you find our work useful for your research:
14 |
15 | ```tex
16 | @inproceedings{tsai2020Capsules,
17 | title={Capsules with Inverted Dot-Product Attention Routing},
18 | author={Tsai, Yao-Hung Hubert and Srivastava, Nitish and Goh, Hanlin and Salakhutdinov, Ruslan},
19 | booktitle={International Conference on Learning Representations (ICLR)},
20 | year={2020},
21 | }
22 | ```
23 |
24 | ## Overview
25 |
26 | ### Overall Architecture
27 |
28 |
29 |
 30 | An example of our proposed architecture is shown above. The backbone is a standard feed-forward convolutional neural network. The features extracted from this network are fed through another convolutional layer. At each spatial location, groups of 16 channels are formed into capsules (we assume a 16-dimensional pose in a capsule). LayerNorm is then applied across the 16 channels to obtain the primary capsules. This is followed by two convolutional capsule layers, and then by two fully-connected capsule layers. In the last capsule layer, each capsule corresponds to a class. These capsules are then used to compute logits that feed into a softmax to compute the classification probabilities. Inference in this network requires a feed-forward pass up to the primary capsules. After this, our proposed routing mechanism (discussed later) takes over.
31 |
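To make the tensor bookkeeping concrete, here is a minimal sketch of the primary-capsule construction, assuming the CIFAR-10 config (a 128-channel, 16x16 backbone feature map and 32 capsules with 16-dimensional poses); the actual implementation lives in `src/capsule_model.py`.

```python
import torch
import torch.nn as nn

B, C, H, W = 8, 128, 16, 16      # backbone output size in the CIFAR-10 config
num_caps, caps_dim = 32, 16      # 32 primary capsules, each with a 16-dimensional pose

feats = torch.randn(B, C, H, W)                      # features from the ResNet backbone
pc_conv = nn.Conv2d(C, num_caps * caps_dim, kernel_size=1, bias=False)
pose_norm = nn.LayerNorm(caps_dim)

u = pc_conv(feats)                                    # B x (num_caps*caps_dim) x H x W
u = u.permute(0, 2, 3, 1)                             # B x H x W x (num_caps*caps_dim)
u = u.view(B, H, W, num_caps, caps_dim)               # split channels into capsule groups
u = u.permute(0, 3, 1, 2, 4)                          # B x num_caps x H x W x caps_dim
primary_capsules = pose_norm(u)                       # LayerNorm over each 16-d pose
```
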
32 | ### Inverted Dot-Product Attention Routing
33 |
34 |
35 |
 36 | In our method, the routing procedure resembles an inverted attention mechanism, where dot products are used to measure agreement. Specifically, the higher-level (parent) units compete for the attention of the lower-level (child) units, instead of the other way around, as is common in attention models. Hence, the routing probability directly depends on the agreement between the parent’s pose (from the previous iteration step) and the child’s vote for the parent’s pose (in the current iteration step). We (1) use Layer Normalization (Ba et al., 2016) as normalization, and we (2) perform inference of the latent capsule states and routing probabilities jointly across multiple capsule layers (instead of doing it layer-wise).
37 |
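A rough sketch of one routing iteration for a fully-connected capsule layer (vector-pose case), mirroring the einsum formulation of `CapsuleFC` in `src/layers.py`; the variable names below are illustrative only, and the repository additionally supports matrix poses and applies LayerNorm to the updated parent poses.

```python
import torch
import torch.nn.functional as F

B, N, A = 4, 32, 16    # batch, number of child capsules, child pose dimension
M, D = 10, 16          # number of parent capsules, parent pose dimension

child_pose = torch.randn(B, N, A)
parent_pose = torch.randn(B, M, D)       # parent poses from the previous iteration
W = torch.randn(N, A, M, D)              # transformation producing each child's vote per parent

# Each child votes for every parent's pose.
votes = torch.einsum('bna, namd->bnmd', child_pose, W)
# Agreement = dot product between a child's vote and the parent's current pose.
agreement = torch.einsum('bnmd, bmd->bnm', votes, parent_pose) / (D ** 0.5)
# Inverted attention: each child distributes its routing probability over the parents.
routing = F.softmax(agreement, dim=2)
# Parents are updated as routing-weighted sums of the votes.
new_parent_pose = torch.einsum('bnm, bnmd->bmd', routing, votes)
```

On the first iteration, when no parent poses exist yet, the routing probabilities are initialized uniformly (a softmax over zeros) rather than computed from dot-product agreement.
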
38 | ### Concurrent Routing
39 |
40 |
41 |
 42 | Concurrent routing is a parallel-in-time routing procedure applied to all capsule layers.
43 |
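Below is a minimal runnable sketch of this schedule, following the structure of `CapsModel.forward` in `src/capsule_model.py`; `ToyCapsLayer` is a hypothetical stand-in for the repository's `CapsuleCONV`/`CapsuleFC` layers, which share the `(input, num_iter, next_capsule_value)` call signature.

```python
import torch
import torch.nn as nn

class ToyCapsLayer(nn.Module):
    """Placeholder capsule layer with the same call signature as CapsuleCONV/CapsuleFC."""
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x, num_iter, next_capsule_value=None):
        return torch.tanh(self.lin(x))

num_routing = 2
capsule_layers = nn.ModuleList([ToyCapsLayer(16) for _ in range(3)])
primary_capsules = torch.randn(4, 32, 16)

# Iteration 1: initialize every layer's capsules with a single bottom-up pass.
values, val = [primary_capsules], primary_capsules
for layer in capsule_layers:
    val = layer(val, 0)                          # no parent feedback yet
    values.append(val)

# Iterations 2..num_routing: all layers are refined in parallel from the previous iteration.
for n in range(num_routing - 1):
    new_values = [primary_capsules]
    for i, layer in enumerate(capsule_layers):
        new_values.append(layer(values[i], n, values[i + 1]))
    values = new_values

class_capsules = values[-1]
```
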
44 | ## Usage
45 |
46 | ### Prerequisites
47 | - Python 3.6/3.7
 48 | - [PyTorch (>=1.2.0) and torchvision](https://pytorch.org/)
49 | - CUDA 10.0 or above
50 |
51 | ### Datasets
52 |
53 | We use [CIFAR10 and CIFAR100](https://www.cs.toronto.edu/~kriz/cifar.html).
54 |
55 | ### Run the Code
56 |
57 | #### Arguments
58 |
59 | | Args | Value | help |
60 | |:-----------:|:---------------------------------------------------:|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
 61 | | debug | - | Enable debug mode (flag); no models or results will be saved. |
 62 | | num_routing | 1 | The number of routing iterations; must be greater than 0. |
63 | | dataset | CIFAR10 | Choice of the dataset. CIFAR10 or CIFAR100. |
64 | | backbone | resnet | Choice of the backbone. simple or resnet. |
65 | | config_path | ./configs/resnet_backbone_CIFAR10.json | Configurations for capsule layers. |
66 |
67 | #### Running CIFAR-10
68 |
69 | ```bash
70 | python main_capsule.py --num_routing 2 --dataset CIFAR10 --backbone resnet --config_path ./configs/resnet_backbone_CIFAR10.json
71 | ```
 72 | When ```num_routing``` is ```1```, the average test accuracy we obtained is _94.73%_.
73 |
 74 | When ```num_routing``` is ```2```, the average test accuracy we obtained is _94.85%_ and the best model we obtained reaches _95.14%_.
75 |
76 |
77 | #### Running CIFAR-100
78 |
79 | ```bash
80 | python main_capsule.py --num_routing 2 --dataset CIFAR100 --backbone resnet --config_path ./configs/resnet_backbone_CIFAR100.json
81 | ```
82 |
 83 | When ```num_routing``` is ```1```, the average test accuracy we obtained is _76.02%_.
84 |
 85 | When ```num_routing``` is ```2```, the average test accuracy we obtained is _76.27%_ and the best model we obtained reaches _78.02%_.
86 |
87 |
88 | ## License
89 | This code is released under the [LICENSE](LICENSE) terms.
90 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved.
4 | #
5 |
--------------------------------------------------------------------------------
/configs/resnet_backbone_CIFAR10.json:
--------------------------------------------------------------------------------
1 | {
2 | "backbone": {
3 | "kernel_size": 3,
4 | "output_dim": 128,
5 | "input_dim": 3,
6 | "stride": 2,
7 | "padding": 1,
8 | "out_img_size": 16
9 | },
10 | "primary_capsules": {
11 | "kernel_size": 1,
12 | "stride": 1,
13 | "input_dim": 128,
14 | "caps_dim": 16,
15 | "num_caps": 32,
16 | "padding": 0,
17 | "out_img_size": 16
18 | },
19 | "capsules": [
20 | {
21 | "type" : "CONV",
22 | "num_caps": 32,
23 | "caps_dim": 16,
24 | "kernel_size": 3,
25 | "stride": 2,
26 | "matrix_pose": true,
27 | "out_img_size": 7
28 | },
29 | {
30 | "type": "CONV",
31 | "num_caps": 32,
32 | "caps_dim": 16,
33 | "kernel_size": 3,
34 | "stride": 1,
35 | "matrix_pose": true,
36 | "out_img_size": 5
37 | }
38 | ],
39 | "class_capsules": {
40 | "num_caps": 10,
41 | "caps_dim": 16,
42 | "matrix_pose": true
43 | }
44 | }
--------------------------------------------------------------------------------
/configs/resnet_backbone_CIFAR100.json:
--------------------------------------------------------------------------------
1 | {
2 | "backbone": {
3 | "kernel_size": 3,
4 | "output_dim": 128,
5 | "input_dim": 3,
6 | "stride": 2,
7 | "padding": 1,
8 | "out_img_size": 16
9 | },
10 | "primary_capsules": {
11 | "kernel_size": 1,
12 | "stride": 1,
13 | "input_dim": 128,
14 | "caps_dim": 36,
15 | "num_caps": 32,
16 | "padding": 0,
17 | "out_img_size": 16
18 | },
19 | "capsules": [
20 | {
21 | "type" : "CONV",
22 | "num_caps": 32,
23 | "caps_dim": 36,
24 | "kernel_size": 3,
25 | "stride": 2,
26 | "matrix_pose": true,
27 | "out_img_size": 7
28 | },
29 | {
30 | "type": "CONV",
31 | "num_caps": 32,
32 | "caps_dim": 36,
33 | "kernel_size": 3,
34 | "stride": 1,
35 | "matrix_pose": true,
36 | "out_img_size": 5
37 | },
38 | {
39 | "type": "FC",
40 | "num_caps": 20,
41 | "caps_dim": 36,
42 | "matrix_pose": true
43 | }
44 | ],
45 | "class_capsules": {
46 | "num_caps": 100,
47 | "caps_dim": 36,
48 | "matrix_pose": true
49 | }
50 | }
--------------------------------------------------------------------------------
/imgs/concurrent.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-capsules-inverted-attention-routing/53899ae5b75b576a693a1bd79a2173ac3f1cdf1d/imgs/concurrent.png
--------------------------------------------------------------------------------
/imgs/inverted_attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-capsules-inverted-attention-routing/53899ae5b75b576a693a1bd79a2173ac3f1cdf1d/imgs/inverted_attention.png
--------------------------------------------------------------------------------
/imgs/overall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-capsules-inverted-attention-routing/53899ae5b75b576a693a1bd79a2173ac3f1cdf1d/imgs/overall.png
--------------------------------------------------------------------------------
/main_capsule.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2020 Apple Inc. All rights reserved.
4 | #
5 | '''Train CIFAR10 with PyTorch.'''
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | import torch.backends.cudnn as cudnn
11 |
12 | import torchvision
13 | import torchvision.transforms as transforms
14 |
15 | import os
16 | import argparse
17 |
18 | from src import capsule_model
19 | from utils import progress_bar
20 | import pickle
21 | import json
22 |
23 | from datetime import datetime
24 |
25 | # +
26 | parser = argparse.ArgumentParser(description='Training Capsules using Inverted Dot-Product Attention Routing')
27 |
28 | parser.add_argument('--resume_dir', '-r', default='', type=str, help='dir where we resume from checkpoint')
 29 | parser.add_argument('--num_routing', default=1, type=int, help='number of routing iterations. Recommended: 1, 2, or 3.')
30 | parser.add_argument('--dataset', default='CIFAR10', type=str, help='dataset. CIFAR10 or CIFAR100.')
31 | parser.add_argument('--backbone', default='resnet', type=str, help='type of backbone. simple or resnet')
32 | parser.add_argument('--num_workers', default=2, type=int, help='number of workers. 0 or 2')
33 | parser.add_argument('--config_path', default='./configs/full_rank_2C1F_matrix_for_iterations.json', type=str, help='path of the config')
34 | parser.add_argument('--debug', action='store_true',
35 | help='use debug mode (without saving to a directory)')
 36 | parser.add_argument('--sequential_routing', action='store_true', help='use sequential (layer-wise) routing instead of concurrent routing')
37 |
38 |
39 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate. 0.1 for SGD')
40 | parser.add_argument('--dp', default=0.0, type=float, help='dropout rate')
41 | parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay')
42 | # -
43 |
44 | args = parser.parse_args()
45 | assert args.num_routing > 0
46 |
47 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
48 | best_acc = 0 # best test accuracy
49 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch
50 |
51 | # Data
52 | print('==> Preparing data..')
53 | assert args.dataset == 'CIFAR10' or args.dataset == 'CIFAR100'
54 | transform_train = transforms.Compose([
55 | transforms.RandomCrop(32, padding=4),
56 | transforms.RandomHorizontalFlip(),
57 | transforms.ToTensor(),
58 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
59 | ])
60 | transform_test = transforms.Compose([
61 | transforms.ToTensor(),
62 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
63 | ])
64 |
65 | trainset = getattr(torchvision.datasets, args.dataset)(root='../data', train=True, download=True, transform=transform_train)
66 | testset = getattr(torchvision.datasets, args.dataset)(root='../data', train=False, download=True, transform=transform_test)
67 | num_class = int(args.dataset.split('CIFAR')[1])
68 |
69 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=args.num_workers)
70 | testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=args.num_workers)
71 |
72 | print('==> Building model..')
73 | # Model parameters
74 |
75 | image_dim_size = 32
76 |
77 | with open(args.config_path, 'rb') as file:
78 | params = json.load(file)
79 |
80 | net = capsule_model.CapsModel(image_dim_size,
81 | params,
82 | args.backbone,
83 | args.dp,
84 | args.num_routing,
85 | sequential_routing=args.sequential_routing)
86 |
87 | # +
88 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
89 |
90 | lr_decay = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1)
91 |
92 |
93 | # -
94 |
95 | def count_parameters(model):
96 | for name, param in model.named_parameters():
97 | if param.requires_grad:
98 | print(name, param.numel())
99 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
100 |
101 | print(net)
102 | total_params = count_parameters(net)
103 | print(total_params)
104 |
105 | if not os.path.isdir('results') and not args.debug:
106 | os.mkdir('results')
107 | if not args.debug:
108 | store_dir = os.path.join('results', datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
109 | os.mkdir(store_dir)
110 |
111 | net = net.to(device)
112 | if device == 'cuda':
113 | net = torch.nn.DataParallel(net)
114 | cudnn.benchmark = True
115 |
116 | loss_func = nn.CrossEntropyLoss()
117 |
118 | if args.resume_dir and not args.debug:
119 | # Load checkpoint.
120 | print('==> Resuming from checkpoint..')
121 | checkpoint = torch.load(os.path.join(args.resume_dir, 'ckpt.pth'))
122 | net.load_state_dict(checkpoint['net'])
123 | best_acc = checkpoint['acc']
124 | start_epoch = checkpoint['epoch']
125 |
126 | # Training
127 | def train(epoch):
128 | print('\nEpoch: %d' % epoch)
129 | net.train()
130 | train_loss = 0
131 | correct = 0
132 | total = 0
133 | for batch_idx, (inputs, targets) in enumerate(trainloader):
134 | inputs = inputs.to(device)
135 |
136 | targets = targets.to(device)
137 |
138 | optimizer.zero_grad()
139 |
140 | v = net(inputs)
141 |
142 | loss = loss_func(v, targets)
143 |
144 | loss.backward()
145 | optimizer.step()
146 |
147 | train_loss += loss.item()
148 | _, predicted = v.max(dim=1)
149 |
150 | total += targets.size(0)
151 |
152 | correct += predicted.eq(targets).sum().item()
153 |
154 | progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
155 | % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
156 | return 100.*correct/total
157 |
158 | def test(epoch):
159 | global best_acc
160 | net.eval()
161 | test_loss = 0
162 | correct = 0
163 | total = 0
164 | with torch.no_grad():
165 | for batch_idx, (inputs, targets) in enumerate(testloader):
166 | inputs = inputs.to(device)
167 |
168 | targets = targets.to(device)
169 |
170 | v = net(inputs)
171 |
172 | loss = loss_func(v, targets)
173 |
174 | test_loss += loss.item()
175 |
176 | _, predicted = v.max(dim=1)
177 |
178 | total += targets.size(0)
179 |
180 | correct += predicted.eq(targets).sum().item()
181 |
182 | progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
183 | % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
184 |
185 | # Save checkpoint.
186 | acc = 100.*correct/total
187 | if acc > best_acc and not args.debug:
188 | print('Saving..')
189 | state = {
190 | 'net': net.state_dict(),
191 | 'acc': acc,
192 | 'epoch': epoch,
193 | }
194 | torch.save(state, os.path.join(store_dir, 'ckpt.pth'))
195 | best_acc = acc
196 | return 100.*correct/total
197 |
198 | # +
199 | results = {
200 | 'total_params': total_params,
201 | 'args': args,
202 | 'params': params,
203 | 'train_acc': [],
204 | 'test_acc': [],
205 | }
206 |
207 | total_epochs = 350
208 |
209 | for epoch in range(start_epoch, start_epoch+total_epochs):
210 | results['train_acc'].append(train(epoch))
211 |
212 | lr_decay.step()
213 | results['test_acc'].append(test(epoch))
214 | # -
215 |
216 | if not args.debug:
217 | store_file = os.path.join(store_dir, 'dataset_' + str(args.dataset) + '_num_routing_' + str(args.num_routing) + \
218 | '_backbone_' + args.backbone + '.dct')
219 |
220 | pickle.dump(results, open(store_file, 'wb'))
221 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved.
4 | #
5 |
--------------------------------------------------------------------------------
/src/capsule_model.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved.
4 | #
5 | from src import layers
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch
9 |
10 | # Capsule model
11 | class CapsModel(nn.Module):
12 | def __init__(self,
13 | image_dim_size,
14 | params,
15 | backbone,
16 | dp,
17 | num_routing,
18 | sequential_routing=True):
19 |
20 | super(CapsModel, self).__init__()
21 | #### Parameters
22 | self.sequential_routing = sequential_routing
23 |
24 | ## Primary Capsule Layer
25 | self.pc_num_caps = params['primary_capsules']['num_caps']
26 | self.pc_caps_dim = params['primary_capsules']['caps_dim']
27 | self.pc_output_dim = params['primary_capsules']['out_img_size']
28 | ## General
 29 | self.num_routing = num_routing # >3 may cause slow convergence
30 |
31 | #### Building Networks
32 | ## Backbone (before capsule)
33 | if backbone == 'simple':
34 | self.pre_caps = layers.simple_backbone(params['backbone']['input_dim'],
35 | params['backbone']['output_dim'],
36 | params['backbone']['kernel_size'],
37 | params['backbone']['stride'],
38 | params['backbone']['padding'])
39 | elif backbone == 'resnet':
40 | self.pre_caps = layers.resnet_backbone(params['backbone']['input_dim'],
41 | params['backbone']['output_dim'],
42 | params['backbone']['stride'])
43 |
44 | ## Primary Capsule Layer (a single CNN)
45 | self.pc_layer = nn.Conv2d(in_channels=params['primary_capsules']['input_dim'],
46 | out_channels=params['primary_capsules']['num_caps'] *\
47 | params['primary_capsules']['caps_dim'],
48 | kernel_size=params['primary_capsules']['kernel_size'],
49 | stride=params['primary_capsules']['stride'],
50 | padding=params['primary_capsules']['padding'],
51 | bias=False)
52 |
53 | #self.pc_layer = nn.Sequential()
54 |
55 | self.nonlinear_act = nn.LayerNorm(params['primary_capsules']['caps_dim'])
56 |
57 | ## Main Capsule Layers
58 | self.capsule_layers = nn.ModuleList([])
59 | for i in range(len(params['capsules'])):
60 | if params['capsules'][i]['type'] == 'CONV':
61 | in_n_caps = params['primary_capsules']['num_caps'] if i==0 else \
62 | params['capsules'][i-1]['num_caps']
63 | in_d_caps = params['primary_capsules']['caps_dim'] if i==0 else \
64 | params['capsules'][i-1]['caps_dim']
65 | self.capsule_layers.append(
66 | layers.CapsuleCONV(in_n_capsules=in_n_caps,
67 | in_d_capsules=in_d_caps,
68 | out_n_capsules=params['capsules'][i]['num_caps'],
69 | out_d_capsules=params['capsules'][i]['caps_dim'],
70 | kernel_size=params['capsules'][i]['kernel_size'],
71 | stride=params['capsules'][i]['stride'],
72 | matrix_pose=params['capsules'][i]['matrix_pose'],
73 | dp=dp,
74 | coordinate_add=False
75 | )
76 | )
77 | elif params['capsules'][i]['type'] == 'FC':
78 | if i == 0:
79 | in_n_caps = params['primary_capsules']['num_caps'] * params['primary_capsules']['out_img_size'] *\
80 | params['primary_capsules']['out_img_size']
81 | in_d_caps = params['primary_capsules']['caps_dim']
82 | elif params['capsules'][i-1]['type'] == 'FC':
83 | in_n_caps = params['capsules'][i-1]['num_caps']
84 | in_d_caps = params['capsules'][i-1]['caps_dim']
85 | elif params['capsules'][i-1]['type'] == 'CONV':
86 | in_n_caps = params['capsules'][i-1]['num_caps'] * params['capsules'][i-1]['out_img_size'] *\
87 | params['capsules'][i-1]['out_img_size']
88 | in_d_caps = params['capsules'][i-1]['caps_dim']
89 | self.capsule_layers.append(
90 | layers.CapsuleFC(in_n_capsules=in_n_caps,
91 | in_d_capsules=in_d_caps,
92 | out_n_capsules=params['capsules'][i]['num_caps'],
93 | out_d_capsules=params['capsules'][i]['caps_dim'],
94 | matrix_pose=params['capsules'][i]['matrix_pose'],
95 | dp=dp
96 | )
97 | )
98 |
99 | ## Class Capsule Layer
100 | if not len(params['capsules'])==0:
101 | if params['capsules'][-1]['type'] == 'FC':
102 | in_n_caps = params['capsules'][-1]['num_caps']
103 | in_d_caps = params['capsules'][-1]['caps_dim']
104 | elif params['capsules'][-1]['type'] == 'CONV':
105 | in_n_caps = params['capsules'][-1]['num_caps'] * params['capsules'][-1]['out_img_size'] *\
106 | params['capsules'][-1]['out_img_size']
107 | in_d_caps = params['capsules'][-1]['caps_dim']
108 | else:
109 | in_n_caps = params['primary_capsules']['num_caps'] * params['primary_capsules']['out_img_size'] *\
110 | params['primary_capsules']['out_img_size']
111 | in_d_caps = params['primary_capsules']['caps_dim']
112 | self.capsule_layers.append(
113 | layers.CapsuleFC(in_n_capsules=in_n_caps,
114 | in_d_capsules=in_d_caps,
115 | out_n_capsules=params['class_capsules']['num_caps'],
116 | out_d_capsules=params['class_capsules']['caps_dim'],
117 | matrix_pose=params['class_capsules']['matrix_pose'],
118 | dp=dp
119 | )
120 | )
121 |
122 | ## After Capsule
123 | # fixed classifier for all class capsules
124 | self.final_fc = nn.Linear(params['class_capsules']['caps_dim'], 1)
125 | # different classifier for different capsules
126 | #self.final_fc = nn.Parameter(torch.randn(params['class_capsules']['num_caps'], params['class_capsules']['caps_dim']))
127 |
128 | def forward(self, x, lbl_1=None, lbl_2=None):
129 | #### Forward Pass
130 | ## Backbone (before capsule)
131 | c = self.pre_caps(x)
132 |
133 | ## Primary Capsule Layer (a single CNN)
134 | u = self.pc_layer(c) # torch.Size([100, 512, 14, 14])
135 | u = u.permute(0, 2, 3, 1) # 100, 14, 14, 512
136 | u = u.view(u.shape[0], self.pc_output_dim, self.pc_output_dim, self.pc_num_caps, self.pc_caps_dim) # 100, 14, 14, 32, 16
137 | u = u.permute(0, 3, 1, 2, 4) # 100, 32, 14, 14, 16
138 | init_capsule_value = self.nonlinear_act(u)#capsule_utils.squash(u)
139 |
140 | ## Main Capsule Layers
141 | # concurrent routing
142 | if not self.sequential_routing:
143 | # first iteration
144 | # perform initilialization for the capsule values as single forward passing
145 | capsule_values, _val = [init_capsule_value], init_capsule_value
146 | for i in range(len(self.capsule_layers)):
147 | _val = self.capsule_layers[i].forward(_val, 0)
148 | capsule_values.append(_val) # get the capsule value for next layer
149 |
150 | # second to t iterations
151 | # perform the routing between capsule layers
152 | for n in range(self.num_routing-1):
153 | _capsule_values = [init_capsule_value]
154 | for i in range(len(self.capsule_layers)):
155 | _val = self.capsule_layers[i].forward(capsule_values[i], n,
156 | capsule_values[i+1])
157 | _capsule_values.append(_val)
158 | capsule_values = _capsule_values
159 | # sequential routing
160 | else:
161 | capsule_values, _val = [init_capsule_value], init_capsule_value
162 | for i in range(len(self.capsule_layers)):
163 | # first iteration
164 | __val = self.capsule_layers[i].forward(_val, 0)
165 | # second to t iterations
166 | # perform the routing between capsule layers
167 | for n in range(self.num_routing-1):
168 | __val = self.capsule_layers[i].forward(_val, n, __val)
169 | _val = __val
170 | capsule_values.append(_val)
171 |
172 | ## After Capsule
173 | out = capsule_values[-1]
174 | out = self.final_fc(out) # fixed classifier for all capsules
175 | out = out.squeeze() # fixed classifier for all capsules
176 | #out = torch.einsum('bnd, nd->bn', out, self.final_fc) # different classifiers for distinct capsules
177 |
178 | return out
179 |
--------------------------------------------------------------------------------
/src/layers.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2019 Apple Inc. All Rights Reserved.
4 | #
5 | '''Capsule in PyTorch
6 | TBD
7 | '''
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | import numpy as np
14 |
15 | #### Simple Backbone ####
16 | class simple_backbone(nn.Module):
17 | def __init__(self, cl_input_channels,cl_num_filters,cl_filter_size,
18 | cl_stride,cl_padding):
19 | super(simple_backbone, self).__init__()
20 | self.pre_caps = nn.Sequential(
21 | nn.Conv2d(in_channels=cl_input_channels,
22 | out_channels=cl_num_filters,
23 | kernel_size=cl_filter_size,
24 | stride=cl_stride,
25 | padding=cl_padding),
26 | nn.ReLU(),
27 | )
28 | def forward(self, x):
29 | out = self.pre_caps(x) # x is an image
30 | return out
31 |
32 |
33 | #### ResNet Backbone ####
34 | class BasicBlock(nn.Module):
35 | expansion = 1
36 |
37 | def __init__(self, in_planes, planes, stride=1):
38 | super(BasicBlock, self).__init__()
39 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
40 | self.bn1 = nn.BatchNorm2d(planes)
41 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
42 | self.bn2 = nn.BatchNorm2d(planes)
43 |
44 | self.shortcut = nn.Sequential()
45 | if stride != 1 or in_planes != self.expansion*planes:
46 | self.shortcut = nn.Sequential(
47 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
48 | nn.BatchNorm2d(self.expansion*planes)
49 | )
50 |
51 | def forward(self, x):
52 | out = F.relu(self.bn1(self.conv1(x)))
53 | out = self.bn2(self.conv2(out))
54 | out += self.shortcut(x)
55 | out = F.relu(out)
56 | return out
57 |
58 | class resnet_backbone(nn.Module):
59 | def __init__(self, cl_input_channels, cl_num_filters,
60 | cl_stride):
61 | super(resnet_backbone, self).__init__()
62 | self.in_planes = 64
63 | def _make_layer(block, planes, num_blocks, stride):
64 | strides = [stride] + [1]*(num_blocks-1)
65 | layers = []
66 | for stride in strides:
67 | layers.append(block(self.in_planes, planes, stride))
68 | self.in_planes = planes * block.expansion
69 | return nn.Sequential(*layers)
70 |
71 | self.pre_caps = nn.Sequential(
72 | nn.Conv2d(in_channels=cl_input_channels,
73 | out_channels=64,
74 | kernel_size=3,
75 | stride=1,
76 | padding=1,
77 | bias=False),
78 | nn.BatchNorm2d(64),
79 | nn.ReLU(),
80 | _make_layer(block=BasicBlock, planes=64, num_blocks=3, stride=1), # num_blocks=2 or 3
81 | _make_layer(block=BasicBlock, planes=cl_num_filters, num_blocks=4, stride=cl_stride), # num_blocks=2 or 4
82 | )
83 | def forward(self, x):
84 | out = self.pre_caps(x) # x is an image
85 | return out
86 |
87 | #### Capsule Layer ####
88 | class CapsuleFC(nn.Module):
89 | r"""Applies as a capsule fully-connected layer.
90 | TBD
91 | """
92 | def __init__(self, in_n_capsules, in_d_capsules, out_n_capsules, out_d_capsules, matrix_pose, dp):
93 | super(CapsuleFC, self).__init__()
94 | self.in_n_capsules = in_n_capsules
95 | self.in_d_capsules = in_d_capsules
96 | self.out_n_capsules = out_n_capsules
97 | self.out_d_capsules = out_d_capsules
98 | self.matrix_pose = matrix_pose
99 |
100 | if matrix_pose:
101 | self.sqrt_d = int(np.sqrt(self.in_d_capsules))
102 | self.weight_init_const = np.sqrt(out_n_capsules/(self.sqrt_d*in_n_capsules))
103 | self.w = nn.Parameter(self.weight_init_const* \
104 | torch.randn(in_n_capsules, self.sqrt_d, self.sqrt_d, out_n_capsules))
105 |
106 | else:
107 | self.weight_init_const = np.sqrt(out_n_capsules/(in_d_capsules*in_n_capsules))
108 | self.w = nn.Parameter(self.weight_init_const* \
109 | torch.randn(in_n_capsules, in_d_capsules, out_n_capsules, out_d_capsules))
110 | self.dropout_rate = dp
111 | self.nonlinear_act = nn.LayerNorm(out_d_capsules)
112 | self.drop = nn.Dropout(self.dropout_rate)
113 | self.scale = 1. / (out_d_capsules ** 0.5)
114 |
115 | def extra_repr(self):
116 | return 'in_n_capsules={}, in_d_capsules={}, out_n_capsules={}, out_d_capsules={}, matrix_pose={}, \
117 | weight_init_const={}, dropout_rate={}'.format(
118 | self.in_n_capsules, self.in_d_capsules, self.out_n_capsules, self.out_d_capsules, self.matrix_pose,
119 | self.weight_init_const, self.dropout_rate
120 | )
121 | def forward(self, input, num_iter, next_capsule_value=None):
122 | # b: batch size
123 | # n: num of capsules in current layer
124 | # a: dim of capsules in current layer
125 | # m: num of capsules in next layer
126 | # d: dim of capsules in next layer
127 | if len(input.shape) == 5:
128 | input = input.permute(0, 4, 1, 2, 3)
129 | input = input.contiguous().view(input.shape[0], input.shape[1], -1)
130 | input = input.permute(0,2,1)
131 |
132 | if self.matrix_pose:
133 | w = self.w # nxdm
134 | _input = input.view(input.shape[0], input.shape[1], self.sqrt_d, self.sqrt_d) # bnax
135 | else:
136 | w = self.w
137 |
138 | if next_capsule_value is None:
139 | query_key = torch.zeros(self.in_n_capsules, self.out_n_capsules).type_as(input)
140 | query_key = F.softmax(query_key, dim=1)
141 |
142 | if self.matrix_pose:
143 | next_capsule_value = torch.einsum('nm, bnax, nxdm->bmad', query_key, _input, w)
144 | else:
145 | next_capsule_value = torch.einsum('nm, bna, namd->bmd', query_key, input, w)
146 | else:
147 | if self.matrix_pose:
148 | next_capsule_value = next_capsule_value.view(next_capsule_value.shape[0],
149 | next_capsule_value.shape[1], self.sqrt_d, self.sqrt_d)
150 | _query_key = torch.einsum('bnax, nxdm, bmad->bnm', _input, w, next_capsule_value)
151 | else:
152 | _query_key = torch.einsum('bna, namd, bmd->bnm', input, w, next_capsule_value)
153 | _query_key.mul_(self.scale)
154 | query_key = F.softmax(_query_key, dim=2)
155 | query_key = query_key / (torch.sum(query_key, dim=2, keepdim=True) + 1e-10)
156 |
157 | if self.matrix_pose:
158 | next_capsule_value = torch.einsum('bnm, bnax, nxdm->bmad', query_key, _input,
159 | w)
160 | else:
161 | next_capsule_value = torch.einsum('bnm, bna, namd->bmd', query_key, input,
162 | w)
163 |
164 | next_capsule_value = self.drop(next_capsule_value)
165 | if not next_capsule_value.shape[-1] == 1:
166 | if self.matrix_pose:
167 | next_capsule_value = next_capsule_value.view(next_capsule_value.shape[0],
168 | next_capsule_value.shape[1], self.out_d_capsules)
169 | next_capsule_value = self.nonlinear_act(next_capsule_value)
170 | else:
171 | next_capsule_value = self.nonlinear_act(next_capsule_value)
172 | return next_capsule_value
173 |
174 | class CapsuleCONV(nn.Module):
175 | r"""Applies as a capsule convolutional layer.
176 | TBD
177 | """
178 | def __init__(self, in_n_capsules, in_d_capsules, out_n_capsules, out_d_capsules,
179 | kernel_size, stride, matrix_pose, dp, coordinate_add=False):
180 | super(CapsuleCONV, self).__init__()
181 | self.in_n_capsules = in_n_capsules
182 | self.in_d_capsules = in_d_capsules
183 | self.out_n_capsules = out_n_capsules
184 | self.out_d_capsules = out_d_capsules
185 | self.kernel_size = kernel_size
186 | self.stride = stride
187 | self.matrix_pose = matrix_pose
188 | self.coordinate_add = coordinate_add
189 |
190 | if matrix_pose:
191 | self.sqrt_d = int(np.sqrt(self.in_d_capsules))
192 | self.weight_init_const = np.sqrt(out_n_capsules/(self.sqrt_d*in_n_capsules*kernel_size*kernel_size))
193 | self.w = nn.Parameter(self.weight_init_const*torch.randn(kernel_size, kernel_size,
194 | in_n_capsules, self.sqrt_d, self.sqrt_d, out_n_capsules))
195 | else:
196 | self.weight_init_const = np.sqrt(out_n_capsules/(in_d_capsules*in_n_capsules*kernel_size*kernel_size))
197 | self.w = nn.Parameter(self.weight_init_const*torch.randn(kernel_size, kernel_size,
198 | in_n_capsules, in_d_capsules, out_n_capsules,
199 | out_d_capsules))
200 | self.nonlinear_act = nn.LayerNorm(out_d_capsules)
201 | self.dropout_rate = dp
202 | self.drop = nn.Dropout(self.dropout_rate)
203 | self.scale = 1. / (out_d_capsules ** 0.5)
204 |
205 | def extra_repr(self):
206 | return 'in_n_capsules={}, in_d_capsules={}, out_n_capsules={}, out_d_capsules={}, \
207 | kernel_size={}, stride={}, coordinate_add={}, matrix_pose={}, weight_init_const={}, \
208 | dropout_rate={}'.format(
209 | self.in_n_capsules, self.in_d_capsules, self.out_n_capsules, self.out_d_capsules,
210 | self.kernel_size, self.stride, self.coordinate_add, self.matrix_pose, self.weight_init_const,
211 | self.dropout_rate
212 | )
213 | def input_expansion(self, input):
 214 | # input has size [batch x num_of_capsules x height x width x capsule_dimension]
215 | b, n, h, w, d = input.shape
216 |
217 | h_out = int((h-self.kernel_size)/self.stride+1)
218 | w_out = int((w-self.kernel_size)/self.stride+1)
219 |
220 | # this may be slow if the image size is large
221 | # TODO: kind of stupid implementation :'(
222 | inputs = torch.stack([input[:, :, self.stride * i:self.stride * i + self.kernel_size,
223 | self.stride * j:self.stride * j + self.kernel_size, :] for i in range(h_out) for j in range(w_out)],
224 | dim=-1) # b,n, kernel_size, kernel_size, h_out*w_out*d
225 | inputs = inputs.view(b,n,self.kernel_size, self.kernel_size, h_out, w_out, d)
226 | return inputs
227 |
228 | def forward(self, input, num_iter, next_capsule_value=None):
229 | # k,l: kernel size
230 | # h,w: output width and length
231 | inputs = self.input_expansion(input)
232 |
233 | if self.matrix_pose:
234 | w = self.w # klnxdm
235 | _inputs = inputs.view(inputs.shape[0], inputs.shape[1], inputs.shape[2], inputs.shape[3],\
236 | inputs.shape[4], inputs.shape[5], self.sqrt_d, self.sqrt_d) # bnklmhax
237 | else:
238 | w = self.w
239 |
240 | if next_capsule_value is None:
241 | query_key = torch.zeros(self.in_n_capsules, self.kernel_size, self.kernel_size,
242 | self.out_n_capsules).type_as(inputs)
243 | query_key = F.softmax(query_key, dim=3)
244 |
245 | if self.matrix_pose:
246 | next_capsule_value = torch.einsum('nklm, bnklhwax, klnxdm->bmhwad', query_key,
247 | _inputs, w)
248 | else:
249 | next_capsule_value = torch.einsum('nklm, bnklhwa, klnamd->bmhwd', query_key,
250 | inputs, w)
251 | else:
252 | if self.matrix_pose:
253 | next_capsule_value = next_capsule_value.view(next_capsule_value.shape[0],\
254 | next_capsule_value.shape[1], next_capsule_value.shape[2],\
255 | next_capsule_value.shape[3], self.sqrt_d, self.sqrt_d)
256 | _query_key = torch.einsum('bnklhwax, klnxdm, bmhwad->bnklmhw', _inputs, w,
257 | next_capsule_value)
258 | else:
259 | _query_key = torch.einsum('bnklhwa, klnamd, bmhwd->bnklmhw', inputs, w,
260 | next_capsule_value)
261 | _query_key.mul_(self.scale)
262 | query_key = F.softmax(_query_key, dim=4)
263 | query_key = query_key / (torch.sum(query_key, dim=4, keepdim=True) + 1e-10)
264 |
265 | if self.matrix_pose:
266 | next_capsule_value = torch.einsum('bnklmhw, bnklhwax, klnxdm->bmhwad', query_key,
267 | _inputs, w)
268 | else:
269 | next_capsule_value = torch.einsum('bnklmhw, bnklhwa, klnamd->bmhwd', query_key,
270 | inputs, w)
271 |
272 | next_capsule_value = self.drop(next_capsule_value)
273 | if not next_capsule_value.shape[-1] == 1:
274 | if self.matrix_pose:
275 | next_capsule_value = next_capsule_value.view(next_capsule_value.shape[0],\
276 | next_capsule_value.shape[1], next_capsule_value.shape[2],\
277 | next_capsule_value.shape[3], self.out_d_capsules)
278 | next_capsule_value = self.nonlinear_act(next_capsule_value)
279 | else:
280 | next_capsule_value = self.nonlinear_act(next_capsule_value)
281 |
282 | return next_capsule_value
283 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2020 Apple Inc. All rights reserved.
4 | #
5 | '''Some helper functions for PyTorch, including:
6 | - get_mean_and_std: calculate the mean and std value of dataset.
7 | - msr_init: net parameter initialization.
8 | - progress_bar: progress bar mimic xlua.progress.
9 | copy from https://github.com/kuangliu/pytorch-cifar/blob/master/utils.py
10 | '''
11 | import os
12 | import sys
13 | import time
14 | import math
15 |
 16 | import torch
 17 | import torch.nn as nn
 18 | import torch.nn.init as init
19 |
20 | def get_mean_and_std(dataset):
21 | '''Compute the mean and std value of dataset.'''
22 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
23 | mean = torch.zeros(3)
24 | std = torch.zeros(3)
25 | print('==> Computing mean and std..')
26 | for inputs, targets in dataloader:
27 | for i in range(3):
28 | mean[i] += inputs[:,i,:,:].mean()
29 | std[i] += inputs[:,i,:,:].std()
30 | mean.div_(len(dataset))
31 | std.div_(len(dataset))
32 | return mean, std
33 |
34 | def init_params(net):
35 | '''Init layer parameters.'''
36 | for m in net.modules():
37 | if isinstance(m, nn.Conv2d):
 38 | init.kaiming_normal_(m.weight, mode='fan_out')
 39 | if m.bias is not None:
 40 | init.constant_(m.bias, 0)
 41 | elif isinstance(m, nn.BatchNorm2d):
 42 | init.constant_(m.weight, 1)
 43 | init.constant_(m.bias, 0)
 44 | elif isinstance(m, nn.Linear):
 45 | init.normal_(m.weight, std=1e-3)
 46 | if m.bias is not None:
 47 | init.constant_(m.bias, 0)
48 |
49 |
50 | _, term_width = os.popen('stty size', 'r').read().split()
51 | term_width = int(term_width)
52 |
53 | TOTAL_BAR_LENGTH = 65.
54 | last_time = time.time()
55 | begin_time = last_time
56 | def progress_bar(current, total, msg=None):
57 | global last_time, begin_time
58 | if current == 0:
59 | begin_time = time.time() # Reset for new bar.
60 |
61 | cur_len = int(TOTAL_BAR_LENGTH*current/total)
62 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
63 |
64 | sys.stdout.write(' [')
65 | for i in range(cur_len):
66 | sys.stdout.write('=')
67 | sys.stdout.write('>')
68 | for i in range(rest_len):
69 | sys.stdout.write('.')
70 | sys.stdout.write(']')
71 |
72 | cur_time = time.time()
73 | step_time = cur_time - last_time
74 | last_time = cur_time
75 | tot_time = cur_time - begin_time
76 |
77 | L = []
78 | L.append(' Step: %s' % format_time(step_time))
79 | L.append(' | Tot: %s' % format_time(tot_time))
80 | if msg:
81 | L.append(' | ' + msg)
82 |
83 | msg = ''.join(L)
84 | sys.stdout.write(msg)
85 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
86 | sys.stdout.write(' ')
87 |
88 | # Go back to the center of the bar.
89 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
90 | sys.stdout.write('\b')
91 | sys.stdout.write(' %d/%d ' % (current+1, total))
92 |
93 | if current < total-1:
94 | sys.stdout.write('\r')
95 | else:
96 | sys.stdout.write('\n')
97 | sys.stdout.flush()
98 |
99 | def format_time(seconds):
100 | days = int(seconds / 3600/24)
101 | seconds = seconds - days*3600*24
102 | hours = int(seconds / 3600)
103 | seconds = seconds - hours*3600
104 | minutes = int(seconds / 60)
105 | seconds = seconds - minutes*60
106 | secondsf = int(seconds)
107 | seconds = seconds - secondsf
108 | millis = int(seconds*1000)
109 |
110 | f = ''
111 | i = 1
112 | if days > 0:
113 | f += str(days) + 'D'
114 | i += 1
115 | if hours > 0 and i <= 2:
116 | f += str(hours) + 'h'
117 | i += 1
118 | if minutes > 0 and i <= 2:
119 | f += str(minutes) + 'm'
120 | i += 1
121 | if secondsf > 0 and i <= 2:
122 | f += str(secondsf) + 's'
123 | i += 1
124 | if millis > 0 and i <= 2:
125 | f += str(millis) + 'ms'
126 | i += 1
127 | if f == '':
128 | f = '0ms'
129 | return f
130 |
--------------------------------------------------------------------------------