├── LICENSE ├── README.md ├── data_loader.py ├── models ├── __init__.py ├── googlenet_imagenet.py ├── mobilenetv2.py ├── model_slicing.py ├── resnet_cifar.py ├── resnet_imagenet.py ├── vgg_cifar.py └── vgg_imagenet.py ├── requirements.txt ├── train.py └── utils ├── __init__.py ├── lr_scheduler.py ├── profiling.py └── utilities.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model Slicing 2 | 3 | ![version](https://img.shields.io/badge/version-v2.2-brightgreen) 4 | ![python](https://img.shields.io/badge/python-3.8.3-blue) 5 | ![pytorch](https://img.shields.io/badge/pytorch-1.6.0-blue) 6 | ![singa](https://img.shields.io/badge/singa-3.1.0-orange) 7 | 8 | This repository contains our PyTorch implementation of [Model Slicing for Supporting Complex Analytics with Elastic Inference Cost and Resource Constraints](https://arxiv.org/abs/1904.01831). 9 | Model Slicing is a general *dynamic width technique* that enables neural networks to support 10 | budgeted inference, namely producing predictions within a prescribed computational budget by 11 | dynamically trading off accuracy for efficiency at runtime. 12 | 13 | Budgeted inference is achieved by dividing each layer of the network into equal-sized *groups* 14 | of basic components (i.e., neurons in dense layers and channels in convolutional layers). 
15 | Technically, we use a single parameter called *slice rate **r*** to control the fraction of 16 | groups involved in computation for all layers at runtime, namely to control the width of 17 | the network in both training and inference. 18 | 19 | In particular, the groups involved in computation always start from the first group and 20 | extend contiguously to the dynamically determined last group indexed by the current slice rate. 21 | E.g., a slice rate of 0.5 will select the first two groups in a layer of 4 groups 22 | as illustrated below. 23 | 24 | 25 | 26 | 27 | ### This repo includes: 28 | 29 | 1. representative models ([/models](https://github.com/nusdbsystem/model-slicing/blob/main/models)) 30 | 2. code for model slicing training ([train.py](https://github.com/nusdbsystem/model-slicing/blob/main/train.py)) 31 | 3. code supporting *model slicing* functionality ([models/model_slicing.py](https://github.com/nusdbsystem/model-slicing/blob/main/models/model_slicing.py)) 32 | * upgrading a PyTorch model to support *model slicing* by calling one function ([models/model_slicing/upgrade_dynamic_layers](https://github.com/nusdbsystem/model-slicing/blob/main/models/model_slicing.py)) 33 | 34 | 35 | ### Training 36 | 1. Dependencies 37 | ``` 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | 2. Model Training 42 | 43 | ``` 44 | Example training code: 45 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --exp_name resnet_50 --net_type resnet --group 8 --depth 50 --sr_list 1.0 0.75 0.5 0.25 --sr_scheduler_type random_min_max --sr_rand_num 1 --epoch 100 --batch_size 256 --lr 0.1 --dataset imagenet --data_dir /data/ --log_freq 50 46 | 47 | Please check the help text of the argparse.ArgumentParser in train.py for configuration details. 48 | ``` 49 | 50 | 3. One line to support *Model Slicing* 51 | 52 | ``` 53 | model = upgrade_dynamic_layers(model, args.groups, args.sr_list) 54 | 55 | * groups: the number of groups for each layer, e.g. 8 56 | * sr_list: slice rate list, e.g. 
def data_loader(args):
    """Build the training and validation loaders for ``args.dataset``.

    Reads ``args.dataset``, ``args.data_dir``, ``args.batch_size`` and
    ``args.workers``; the CIFAR branch additionally reads ``args.augment``.

    Returns:
        A ``(train_loader, val_loader, class_num)`` tuple, where ``class_num``
        is 100 for ``cifar100``, 10 for ``cifar10`` and 1000 for ``imagenet``.

    Raises:
        Exception: if ``args.dataset`` is none of the names above.
    """
    dataset_name = args.dataset

    def make_loader(dataset, shuffle):
        # Train and validation loaders differ only in dataset and shuffling.
        return torch.utils.data.DataLoader(
            dataset, batch_size=args.batch_size, shuffle=shuffle,
            num_workers=args.workers, pin_memory=True)

    if dataset_name.startswith('cifar'):
        # Per-channel statistics are given on the 0-255 scale and divided by
        # 255 here, because ToTensor() produces values in [0, 1].
        channel_mean = [value / 255.0 for value in [125.3, 123.0, 113.9]]
        channel_std = [value / 255.0 for value in [63.0, 62.1, 66.7]]
        normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

        train_steps = [transforms.ToTensor(), normalize]
        if args.augment:
            # Augmentation (padded random crop + flip) runs before ToTensor.
            train_steps = [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
            ] + train_steps
        transform_train = transforms.Compose(train_steps)
        transform_test = transforms.Compose([transforms.ToTensor(), normalize])

        if dataset_name == 'cifar100':
            dataset_cls, class_num = datasets.CIFAR100, 100
        elif dataset_name == 'cifar10':
            dataset_cls, class_num = datasets.CIFAR10, 10
        else:
            raise Exception('unknown dataset: {}'.format(dataset_name))

        # download=False: the data must already be present under args.data_dir.
        train_loader = make_loader(
            dataset_cls(args.data_dir, train=True, download=False, transform=transform_train),
            shuffle=True)
        val_loader = make_loader(
            dataset_cls(args.data_dir, train=False, transform=transform_test),
            shuffle=False)
        return train_loader, val_loader, class_num

    if dataset_name != 'imagenet':
        raise Exception('unknown dataset: {}'.format(dataset_name))

    # ImageNet is read from <data_dir>/train and <data_dir>/val image folders.
    traindir = os.path.join(args.data_dir, 'train')
    valdir = os.path.join(args.data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_loader = make_loader(datasets.ImageFolder(traindir, train_transform), shuffle=True)
    val_loader = make_loader(datasets.ImageFolder(valdir, val_transform), shuffle=False)
    return train_loader, val_loader, 1000
class GoogLeNet(nn.Module):
    """GoogLeNet classifier assembled from ``BasicConv2d`` and ``Inception`` blocks.

    Args:
        num_classes: output size of the final ``fc`` layer and of both
            auxiliary heads.
        aux_logits: when true, two ``InceptionAux`` heads are built; their
            outputs are returned only while the module is in training mode.
        transform_input: when true, ``forward`` re-scales every input channel
            with fixed constants before the first convolution.
        init_weights: when true, ``_initialize_weights`` runs at construction.
    """

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False, init_weights=True):
        super().__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input

        # Stem.
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Stage 3.
        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        # Stage 4.
        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Stage 5.
        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        # Auxiliary classifiers attach to the outputs of inception4a (512
        # channels) and inception4d (528 channels).
        if aux_logits:
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        # Classification head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """Initialise conv/linear weights from a truncated normal and reset BN.

        Conv and linear weights are drawn from a normal distribution with
        scale 0.01 truncated to [-2, 2] standard deviations; batch-norm
        weights are set to 1 and biases to 0.
        """
        import scipy.stats as stats

        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                distribution = stats.truncnorm(-2, 2, scale=0.01)
                samples = distribution.rvs(layer.weight.numel())
                values = torch.as_tensor(samples, dtype=layer.weight.dtype)
                with torch.no_grad():
                    layer.weight.copy_(values.view(layer.weight.size()))
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the network on a batch of 3-channel images.

        Returns the logits tensor; in training mode with ``aux_logits``
        enabled, returns ``_GoogLeNetOutputs(logits, aux_logits2, aux_logits1)``.
        Shape comments below assume 224x224 inputs.
        """
        if self.transform_input:
            # (std, mean) per channel.
            # NOTE(review): presumably converts ImageNet mean/std normalised
            # input to the (p - 0.5) / 0.5 convention -- confirm with callers.
            channel_stats = ((0.229, 0.485), (0.224, 0.456), (0.225, 0.406))
            x = torch.cat([
                torch.unsqueeze(x[:, idx], 1) * (std / 0.5) + (mean - 0.5) / 0.5
                for idx, (std, mean) in enumerate(channel_stats)
            ], 1)

        # The auxiliary heads only run while training.
        with_aux = self.training and self.aux_logits

        # N x 3 x 224 x 224 -> N x 512 x 14 x 14
        for layer in (self.conv1, self.maxpool1, self.conv2, self.conv3, self.maxpool2,
                      self.inception3a, self.inception3b, self.maxpool3, self.inception4a):
            x = layer(x)
        aux1 = self.aux1(x) if with_aux else None

        # N x 512 x 14 x 14 -> N x 528 x 14 x 14
        for layer in (self.inception4b, self.inception4c, self.inception4d):
            x = layer(x)
        aux2 = self.aux2(x) if with_aux else None

        # N x 528 x 14 x 14 -> N x 1024 x 1 x 1
        for layer in (self.inception4e, self.maxpool4, self.inception5a,
                      self.inception5b, self.avgpool):
            x = layer(x)

        # N x 1024 -> N x num_classes
        x = torch.flatten(x, 1)
        x = self.fc(self.dropout(x))

        if with_aux:
            return _GoogLeNetOutputs(x, aux2, aux1)
        return x
class BasicConv2d(nn.Module):
    """Bias-free 2-D convolution followed by batch norm and ReLU.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels (also the batch-norm width).
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The convolution has no bias of its own; batch norm carries one.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``; the ReLU is applied in place."""
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)
class InvertedResidual(nn.Module):
    """Inverted residual block: optional 1x1 expansion, 3x3 depthwise, 1x1 projection.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: multiplier for the hidden width; when it is 1 the
            pointwise expansion stage is omitted.

    The input is added to the output only when ``stride == 1`` and
    ``inp == oup`` (stored in ``self.identity``).
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super().__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.identity = stride == 1 and inp == oup

        stages = []
        if expand_ratio != 1:
            # pw: expand to the hidden width
            stages.extend([
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ])
        stages.extend([
            # dw: one filter per channel (groups == channels)
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
            # pw-linear: project to the output width, no activation afterwards
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding the skip connection when ``self.identity``."""
        out = self.conv(x)
        if self.identity:
            return x + out
        return out
| input_channel = _make_divisible(32 * width_mult, 4 if width_mult == 0.1 else 8) 104 | layers = [conv_3x3_bn(3, input_channel, 2)] 105 | # building inverted residual blocks 106 | block = InvertedResidual 107 | for t, c, n, s in self.cfgs: 108 | output_channel = _make_divisible(c * width_mult, 4 if width_mult == 0.1 else 8) 109 | for i in range(n): 110 | layers.append(block(input_channel, output_channel, s if i == 0 else 1, t)) 111 | input_channel = output_channel 112 | self.features = nn.Sequential(*layers) 113 | # building last several layers 114 | output_channel = _make_divisible(1280 * width_mult, 4 if width_mult == 0.1 else 8) if width_mult > 1.0 else 1280 115 | self.conv = conv_1x1_bn(input_channel, output_channel) 116 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 117 | self.classifier = nn.Linear(output_channel, num_classes) 118 | 119 | self._initialize_weights() 120 | 121 | def forward(self, x): 122 | x = self.features(x) 123 | x = self.conv(x) 124 | x = self.avgpool(x) 125 | x = x.view(x.size(0), -1) 126 | x = self.classifier(x) 127 | return x 128 | 129 | def _initialize_weights(self): 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 133 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 134 | if m.bias is not None: 135 | m.bias.data.zero_() 136 | elif isinstance(m, nn.BatchNorm2d): 137 | m.weight.data.fill_(1) 138 | m.bias.data.zero_() 139 | elif isinstance(m, nn.Linear): 140 | m.weight.data.normal_(0, 0.01) 141 | m.bias.data.zero_() 142 | 143 | def imagenet_mobilenetv2(args): 144 | """ 145 | Constructs a MobileNet V2 model 146 | """ 147 | return MobileNetV2() 148 | 149 | -------------------------------------------------------------------------------- /models/model_slicing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | __all__ = ['DynamicConv2d', 'DynamicGN', 'DynamicLinear', 'DynamicBN', 'DYNAMIC_LAYERS', 6 | 'update_sr_idx', 'bind_update_sr_idx', 'upgrade_dynamic_layers', 7 | 'create_sr_scheduler'] 8 | 9 | class DynamicConv2d(nn.Conv2d): 10 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 11 | padding=0, dilation=1, groups=1, bias=True, sr_in_list=(1.,), sr_out_list=None): 12 | self.sr_idx, self.sr_in_list = 0, sorted(set(sr_in_list), reverse=True) 13 | if sr_out_list is not None: self.sr_out_list = sorted(set(sr_out_list), reverse=True) 14 | else: self.sr_out_list = self.sr_in_list 15 | super(DynamicConv2d, self).__init__(in_channels, out_channels, kernel_size, 16 | stride, padding, dilation, groups=groups, bias=bias) 17 | 18 | def forward(self, input): 19 | in_channels = round(self.in_channels*self.sr_in_list[self.sr_idx]) 20 | out_channels = round(self.out_channels*self.sr_out_list[self.sr_idx]) 21 | weight, bias = self.weight[:out_channels, :in_channels, :, :], None 22 | if self.bias is not None: bias = self.bias[:out_channels] 23 | return F.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, 24 | round(self.groups*self.sr_in_list[self.sr_idx]) if self.groups>1 else 1) 25 | 26 | class DynamicGN(nn.GroupNorm): 27 | def __init__(self, num_groups, num_channels, 
eps=1e-5, sr_in_list=(1.,)): 28 | self.sr_idx, self.sr_in_list = 0, sorted(set(sr_in_list), reverse=True) 29 | super(DynamicGN, self).__init__(num_groups, num_channels, eps) 30 | 31 | def forward(self, input): 32 | num_channels = int(self.num_channels*self.sr_in_list[self.sr_idx]) 33 | weight, bias = self.weight[:num_channels], self.bias[:num_channels] 34 | return F.group_norm(input, round(num_channels*self.num_groups/float(self.num_channels)), 35 | weight, bias, self.eps) 36 | 37 | class DynamicBN(nn.Module): 38 | def __init__(self, num_features, affine=True, track_running_stats=True, sr_in_list=(1.,)): 39 | super(DynamicBN, self).__init__() 40 | self.sr_idx, self.sr_in_list = 0, sorted(set(sr_in_list), reverse=True) 41 | self.bn_list = nn.Sequential(*[nn.BatchNorm2d(int(num_features * sr), 42 | affine=affine, track_running_stats=track_running_stats) for sr in self.sr_in_list]) 43 | 44 | def forward(self, input): 45 | return self.bn_list[self.sr_idx](input) 46 | 47 | class DynamicLinear(nn.Linear): 48 | def __init__(self, in_features, out_features, bias=True, sr_in_list=(1.0,), sr_out_list=None): 49 | self.sr_idx, self.sr_in_list = 0, sorted(set(sr_in_list), reverse=True) 50 | if sr_out_list is not None: self.sr_out_list = sorted(set(sr_out_list), reverse=True) 51 | else: self.sr_out_list = self.sr_in_list 52 | super(DynamicLinear, self).__init__(in_features, out_features, bias) 53 | 54 | def forward(self, input): 55 | in_features = round(self.in_features*self.sr_in_list[self.sr_idx]) 56 | out_features = round(self.out_features*self.sr_out_list[self.sr_idx]) 57 | weight, bias = self.weight[:out_features, :in_features], None 58 | if self.bias is not None: bias = self.bias[:out_features] 59 | return F.linear(input, weight, bias) 60 | 61 | DYNAMIC_LAYERS = (DynamicConv2d, DynamicLinear, DynamicGN, DynamicBN) 62 | 63 | def update_sr_idx(model, idx): 64 | model.apply(lambda module: (setattr(module, 'sr_idx', idx) 65 | if hasattr(module, 'sr_idx') else None)) 66 | 67 | 
def bind_update_sr_idx(model):
    """Expose ``update_sr_idx`` as a bound method ``model.update_sr_idx(idx)``."""
    model.update_sr_idx = update_sr_idx.__get__(model)


def upgrade_dynamic_layers(model, num_groups=8, sr_in_list=(1.,)):
    """Replace the conv / linear / batch-norm layers of *model* with dynamic ones.

    The replacement layers are constructed from the original layer's
    hyper-parameters only; no weights are copied here.

    :param model: module whose children are replaced in place (recursively)
    :param num_groups: when positive, ``BatchNorm2d`` becomes ``DynamicGN`` with
        this many groups; otherwise it becomes ``DynamicBN``
    :param sr_in_list: slice rates; de-duplicated and sorted in descending order
    :return: the same *model*, to support chained calls
    """
    sr_in_list = sorted(set(sr_in_list), reverse=True)

    def update(parent):
        for name, module in parent.named_children():
            if isinstance(module, nn.Conv2d):
                setattr(parent, name, DynamicConv2d(
                    module.in_channels, module.out_channels, module.kernel_size,
                    module.stride, module.padding, module.dilation,
                    module.groups, module.bias is not None, sr_in_list))
            elif isinstance(module, nn.Linear):
                setattr(parent, name, DynamicLinear(
                    module.in_features, module.out_features,
                    module.bias is not None, sr_in_list))
            elif isinstance(module, nn.BatchNorm2d):
                if num_groups > 0:
                    setattr(parent, name, DynamicGN(
                        num_groups, module.num_features, module.eps, sr_in_list))
                else:
                    setattr(parent, name, DynamicBN(
                        module.num_features, module.affine,
                        module.track_running_stats, sr_in_list))

            # Descend into the original child to reach nested layers.
            update(module)

    # replace all conv/bn/linear layers with dynamic counterparts
    update(model)
    # bind all dynamic layers with function update_sr_idx
    bind_update_sr_idx(model)

    # The first dynamic layer always sees the full input and the last one
    # always emits the full output, so pin those rates to 1.
    modules = [module for module in model.modules() if isinstance(module, DYNAMIC_LAYERS)]
    if modules:  # a model without any dynamic layer has nothing to pin
        modules[0].sr_in_list = [1. for _ in sr_in_list]
        modules[-1].sr_out_list = [1. for _ in sr_in_list]

    # return self to support chain operations (optional return)
    return model


def create_sr_scheduler(scheduler_type, sr_list, sr_rand_num=1, sr_prob=None):
    '''
    Create an endless generator of slice-rate index lists, one list per training batch.

    :param scheduler_type: 'round_robin', or a name starting with 'random'; a name
        containing 'max' and/or 'min' pins index 0 and/or the last index into every
        yielded list and removes it from the randomly sampled candidates
    :param sr_list: slice rate list (only its length is used)
    :param sr_rand_num: number of slice rates sampled without replacement from the
        candidates; clamped to the number of candidates
    :param sr_prob: sampling weights (normalised internally); None or empty for
        uniform sampling. Either one weight per candidate, or one weight per entry
        of sr_list, in which case the weights of pinned indices are dropped
    :return: generator yielding a sorted list of slice rate indices for the
        current training batch
    :raises ValueError: on first iteration, when scheduler_type is unknown
    '''
    idx_num = len(sr_list)
    candidate_idxs = list(range(idx_num))
    min_max_idxs = []

    pinned = []
    if 'max' in scheduler_type:
        pinned.append(0)
    if 'min' in scheduler_type:
        pinned.append(idx_num - 1)
    for idx in pinned:
        # With a single slice rate, min and max are the same index: pin it once.
        if idx in candidate_idxs:
            candidate_idxs.remove(idx)
            min_max_idxs.append(idx)

    probs = None
    if sr_prob is not None and len(sr_prob) > 0 and candidate_idxs:
        probs = np.asarray(sr_prob, dtype=float)
        # Weights given per slice rate: keep only those of the remaining
        # candidates so that len(p) matches the population passed to choice().
        if len(probs) == idx_num and len(candidate_idxs) != idx_num:
            probs = probs[candidate_idxs]
        probs = probs / probs.sum()

    is_random = scheduler_type.startswith('random')
    if not is_random and scheduler_type != 'round_robin':
        raise ValueError('unknown scheduler type: {}'.format(scheduler_type))

    rand_num = min(sr_rand_num, len(candidate_idxs))
    while True:
        if is_random:
            rand_idxs = []
            if rand_num > 0:
                rand_idxs = np.random.choice(
                    candidate_idxs, size=rand_num, p=probs, replace=False).tolist()
            yield sorted(rand_idxs + min_max_idxs)
        else:
            # Fresh list per batch so callers cannot mutate scheduler state.
            yield list(candidate_idxs)
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Pre-activation residual block made of two 3x3 convolutions."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, preact='no_preact'):
        super(BasicBlock, self).__init__()

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.stride = stride
        self.preact = preact

    def forward(self, x):
        activated = self.relu(self.bn1(x))

        # The shortcut is the raw input unless a projection is configured;
        # with 'preact' the projection is fed the activated tensor instead.
        if self.downsample is None:
            shortcut = x
        elif self.preact == 'preact':
            shortcut = self.downsample(activated)
        else:
            shortcut = self.downsample(x)

        out = self.conv1(activated)
        out = self.conv2(self.relu(self.bn2(out)))

        out += shortcut
        return out


class Bottleneck(nn.Module):
    """Pre-activation bottleneck block: 1x1, 3x3, then 1x1 with 4x expansion."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, preact='no_preact'):
        super(Bottleneck, self).__init__()

        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion,
                               kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        self.downsample = downsample
        self.stride = stride
        self.preact = preact

    def forward(self, x):
        activated = self.relu(self.bn1(x))

        if self.downsample is None:
            shortcut = x
        elif self.preact == 'preact':
            shortcut = self.downsample(activated)
        else:
            shortcut = self.downsample(x)

        out = self.conv1(activated)
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))

        out += shortcut
        return out


class CifarResNet(nn.Module):
    """Three-stage pre-activation ResNet built from the blocks above.

    ``depth`` fixes the number of blocks per stage, ``widen_factor`` scales
    the stage widths (16/32/64) and ``bottleneck`` picks the block type.
    """

    def __init__(self, depth, num_classes=10, widen_factor=1., bottleneck=True):
        super(CifarResNet, self).__init__()
        self.inplanes = int(16 * widen_factor)
        # Equality (not truthiness) is kept on purpose to match existing callers.
        if bottleneck == True:
            block, layers_per_block = Bottleneck, 9
        else:
            block, layers_per_block = BasicBlock, 6
        n = int((depth - 2) / layers_per_block)
        final_width = int(64 * widen_factor) * block.expansion

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.layer1 = self._make_layer(block, int(16 * widen_factor), n)
        self.layer2 = self._make_layer(block, int(32 * widen_factor), n, stride=2)
        self.layer3 = self._make_layer(block, int(64 * widen_factor), n, stride=2)
        self.bn1 = nn.BatchNorm2d(final_width)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(final_width, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1, preact='preact'):
        """Stack ``blocks`` blocks; only the first gets stride/projection/preact."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample, preact)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)

        x = self.avgpool(self.relu(self.bn1(x)))
        x = x.view(x.size(0), -1)
        return self.fc(x)


def cifar_resnet(args):
    """Build a bottleneck ``CifarResNet`` from ``args.depth``, ``args.class_num`` and ``args.arg1``."""
    return CifarResNet(args.depth, args.class_num, args.arg1, True)
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'imagenet_resnet']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv/BN pairs with a post-addition ReLU."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Project the input only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1, 3x3, then 1x1 to ``planes * 4`` channels."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """Four-stage ResNet assembled from ``block`` with ``layers[i]`` blocks per stage."""

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Re-initialise convolutions with N(0, sqrt(2 / fan_out)) and reset
        # every batch norm to the identity transform.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
                n = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` blocks; only the first one gets the stride/projection."""
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


def _build_resnet(arch, block, layers, pretrained, kwargs):
    """Construct a ResNet and optionally load the weights listed under ``arch``."""
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls[arch]))
    return model


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, kwargs)


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, kwargs)


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, kwargs)


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, kwargs)


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, kwargs)


def imagenet_resnet(args):
    """Return the randomly initialised ResNet whose depth equals ``args.depth``."""
    supported = ((18, resnet18), (34, resnet34), (50, resnet50),
                 (101, resnet101), (152, resnet152))
    for depth, factory in supported:
        if args.depth == depth:
            return factory()
    raise Exception('not support depth: {} for resnet'.format(args.depth))
cfg = {
    'VGG11': [64, 128, 'M', 256, 256, 'M', 512, 512, 512, 512],
    'VGG13': [64, 64, 128, 128, 'M', 256, 256, 'M', 512, 512, 512, 512],
    'VGG16': [64, 64, 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 512, 512, 512],
    'VGG19': [64, 64, 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 512, 512, 512, 512],
}


class CifarVGG(nn.Module):
    """VGG-style network: conv/BN/ReLU feature stack plus one linear classifier.

    ``depth`` selects the ``'VGG<depth>'`` entry of ``cfg`` and
    ``widen_factor`` scales every convolution width in it.
    """

    def __init__(self, depth, num_classes=10, widen_factor=1.0):
        super(CifarVGG, self).__init__()
        self.widen_factor = widen_factor
        layout = cfg['VGG{0}'.format(depth)]
        self.features = self._make_layers(layout)
        self.classifier = nn.Linear(int(512 * widen_factor), num_classes)

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _make_layers(self, cfg):
        """Translate a layout list into layers: 'M' is a 2x2 max-pool, a number a conv width."""
        stack = []
        in_channels = 3
        for entry in cfg:
            if entry == 'M':
                stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            width = int(entry * self.widen_factor)
            stack.extend([
                nn.Conv2d(in_channels, width, kernel_size=3, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(inplace=True),
            ])
            in_channels = width
        stack.append(nn.AvgPool2d(kernel_size=8, stride=1))
        return nn.Sequential(*stack)


def cifar_vgg(args):
    """Build a ``CifarVGG`` from ``args.depth``, ``args.class_num`` and ``args.arg1``."""
    return CifarVGG(args.depth, args.class_num, args.arg1)
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}


class VGG(nn.Module):
    """VGG network: a caller-supplied feature stack followed by a 3-layer classifier."""

    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        head = [
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        ]
        self.classifier = nn.Sequential(*head)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """Re-initialise conv, batch-norm and linear layers in module order."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size[0], module.kernel_size[1]
                n = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / n))
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                module.weight.data.normal_(0, 0.01)
                module.bias.data.zero_()


def make_layers(cfg, batch_norm=False):
    """Build the feature stack: 'M' is a 2x2 max-pool, a number a 3x3 conv width."""
    stack = []
    in_channels = 3
    for entry in cfg:
        if entry == 'M':
            stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        unit = [nn.Conv2d(in_channels, entry, kernel_size=3, padding=1)]
        if batch_norm:
            unit.append(nn.BatchNorm2d(entry))
        unit.append(nn.ReLU(inplace=True))
        stack.extend(unit)
        in_channels = entry
    return nn.Sequential(*stack)


cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def _build_vgg(arch, config_key, batch_norm, pretrained, kwargs):
    """Construct a VGG for ``cfg[config_key]`` and optionally load ``model_urls[arch]``."""
    if pretrained:
        # Downloaded weights overwrite everything, so skip the random init.
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg[config_key], batch_norm=batch_norm), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls[arch]))
    return model


def vgg11(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_vgg('vgg11', 'A', False, pretrained, kwargs)


def vgg11_bn(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_vgg('vgg11_bn', 'A', True, pretrained, kwargs)


def vgg13(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_vgg('vgg13', 'B', False, pretrained, kwargs)


def vgg13_bn(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_vgg('vgg13_bn', 'B', True, pretrained, kwargs)


def vgg16(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_vgg('vgg16', 'D', False, pretrained, kwargs)


def vgg16_bn(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_vgg('vgg16_bn', 'D', True, pretrained, kwargs)


def vgg19(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_vgg('vgg19', 'E', False, pretrained, kwargs)


def vgg19_bn(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration 'E') with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _build_vgg('vgg19_bn', 'E', True, pretrained, kwargs)


def imagenet_vgg(args):
    """Return the randomly initialised batch-norm VGG whose depth equals ``args.depth``."""
    supported = ((11, vgg11_bn), (13, vgg13_bn), (16, vgg16_bn), (19, vgg19_bn))
    for depth, factory in supported:
        if args.depth == depth:
            return factory()
    raise Exception('not support depth: {} for vgg'.format(args.depth))
# Command-line interface of the training script.
parser = argparse.ArgumentParser(description='CIFAR-10, CIFAR-100 and ImageNet-1k Model Slicing Training')
parser.add_argument('--exp_name', default='', type=str, help='optional exp name used to store log and checkpoint (default: none)')
parser.add_argument('--net_type', default='resnet', type=str, help='network type: vgg, resnet, and so on')
parser.add_argument('--groups', default=8, type=int, help='group num for Group Normalization (default 8, set to 0 for MultiBN)')
parser.add_argument('--depth', default=50, type=int, help='depth of the network')
parser.add_argument('--arg1', default=1.0, type=float, metavar='M', help='additional model arg, k for ResNet')

parser.add_argument('--sr_list', nargs='+', help='the slice rate list in descending order', required=True)
parser.add_argument('--sr_train_prob', nargs='+', help='the prob of picking subnet corresponding to sr_list')
parser.add_argument('--sr_scheduler_type', default='random', type=str, help='slice rate scheduler, support random, random[_min][_max], round_robin')
parser.add_argument('--sr_rand_num', default=1, type=int, metavar='N', help='the number of random sampled slice rate except min/max (default: 1)')

parser.add_argument('--epoch', default=300, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('--batch_size', '-b', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--lr', default=0.1, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('--cosine', dest='cosine', action='store_true', help='cosine LR scheduler')
parser.add_argument('--warmup', dest='warmup', action='store_true', help='gradual warmup LR scheduler')
parser.add_argument('--lr_multiplier', default=10., type=float, metavar='LR', help='LR warm up multiplier')
parser.add_argument('--warmup_epoch', default=5, type=int, metavar='N', help='LR warm up epochs')

parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')
parser.add_argument('--resume_best', dest='resume_best', action='store_true', help='whether to resume the best_checkpoint (default: False)')
parser.add_argument('--checkpoint_dir', default='~/checkpoint/', type=str, metavar='PATH', help='path to checkpoint')

parser.add_argument('--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
parser.add_argument('--data_dir', default='./data/', type=str, metavar='PATH', help='path to dataset')
parser.add_argument('--log_dir', default='./log/', type=str, metavar='PATH', help='path to log')
parser.add_argument('--dataset', dest='dataset', default='cifar10', type=str, help='dataset (options: cifar10, cifar100, and imagenet)')
parser.add_argument('--no_augment', dest='augment', action='store_false', help='whether to use standard augmentation for the datasets (default: True)')

parser.add_argument('--log_freq', default=10, type=int, metavar='N', help='log frequency')

parser.set_defaults(cosine=False)
parser.set_defaults(warmup=False)
parser.set_defaults(resume_best=False)
parser.set_defaults(augment=True)

# initialize all global variables
args = parser.parse_args()
# Paths are joined (not concatenated) so directory arguments work with or
# without a trailing separator.
args.data_dir = os.path.join(args.data_dir, args.dataset)
args.sr_list = list(map(float, args.sr_list))
if args.sr_train_prob: args.sr_train_prob = list(map(float, args.sr_train_prob))
if not args.exp_name: args.exp_name = '{0}_{1}_{2}'.format(args.net_type, args.depth, args.dataset)
# The trailing '' keeps a final separator: the checkpoint helpers append the
# file name directly to this string.
args.checkpoint_dir = os.path.join(os.path.expanduser(args.checkpoint_dir), args.exp_name, '')
args.log_path = os.path.join(args.log_dir, args.exp_name, 'log.txt')
best_err1, best_err5 = 100., 100.

# create log dir (honours --log_dir and creates missing parents)
os.makedirs(os.path.dirname(args.log_path), exist_ok=True)

# load dataset
train_loader, val_loader, args.class_num = data_loader(args)
def main():
    """Train the sliceable model and evaluate every slice rate after each epoch.

    Reads the module-level ``args``, ``train_loader`` and ``val_loader``; updates
    the module-level ``best_err1`` / ``best_err5`` and writes a checkpoint per epoch.
    """
    global args, best_err1, best_err5
    print_logger = logger(args.log_path, True, True)
    print_logger.info(vars(args))

    # create model and upgrade model to support model slicing
    model = create_model(args, print_logger)
    model = upgrade_dynamic_layers(model, args.groups, args.sr_list)
    model = torch.nn.DataParallel(model).cuda()

    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=True)
    scheduler = create_lr_scheduler(args, optimizer)

    if args.resume:
        checkpoint = load_checkpoint(print_logger)
        # Look entries up by key rather than relying on the dict's value order.
        epoch = checkpoint['epoch']
        best_err1, best_err5 = checkpoint['best_err1'], checkpoint['best_err5']
        args.start_epoch = epoch + 1
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # The scheduler must be restored from its own state, not the optimizer's.
        scheduler.load_state_dict(checkpoint['scheduler_state'])
        print_logger.info("==> finish loading checkpoint '{}' (epoch {})".format(args.resume, epoch))

    cudnn.benchmark = True

    # start training
    sr_scheduler = create_sr_scheduler(args.sr_scheduler_type, args.sr_list,
                                       args.sr_rand_num, args.sr_train_prob)
    for epoch in range(args.start_epoch, args.epoch):
        # Report the LR the optimizer will actually use; scheduler.get_lr() is
        # not meant to be called outside of step().
        print_logger.info('Epoch: [{0}/{1}]\tLR: {LR:.6f}'.format(
            epoch, args.epoch, LR=optimizer.param_groups[0]['lr']))

        # train one epoch
        run(epoch, model, train_loader, criterion, print_logger, sr_scheduler, optimizer)
        scheduler.step()

        # evaluate on all the sr_idxs, from the smallest subnet to the largest
        for sr_idx in reversed(range(len(args.sr_list))):
            args.sr_idx = sr_idx
            model.module.update_sr_idx(sr_idx)
            err1, err5 = run(epoch, model, val_loader, criterion, print_logger)

        # err1/err5 now hold the result of the last evaluation (sr_idx 0, the
        # largest subnet): record the best prec@1 for it and save checkpoint
        is_best = err1 <= best_err1
        best_err1 = min(err1, best_err1)
        if is_best: best_err5 = err5
        print_logger.info('Current best accuracy:\ttop1 = {top1:.4f} | top5 = {top5:.4f}'.
                          format(top1=best_err1, top5=best_err5))
        save_checkpoint(OrderedDict([
            ('epoch', epoch), ('best_err1', best_err1), ('best_err5', best_err5),
            ('model_state', model.state_dict()), ('optimizer_state', optimizer.state_dict()),
            ('scheduler_state', scheduler.state_dict())]), is_best, args.checkpoint_dir)


def create_model(args, print_logger):
    """Instantiate ``models.<cifar|imagenet>_<net_type>(args)``.

    :raises ValueError: when ``args.dataset`` is neither a cifar variant nor 'imagenet'
    """
    print_logger.info("==> creating model '{}'".format(args.net_type))
    if args.dataset.startswith('cifar'):
        family = 'cifar'
    elif args.dataset == 'imagenet':
        family = 'imagenet'
    else:
        # Previously this fell through and failed later on an unbound ``model``.
        raise ValueError('unknown dataset: {}'.format(args.dataset))
    models = importlib.import_module('models')
    model = getattr(models, '{0}_{1}'.format(family, args.net_type))(args)
    print_logger.info('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    return model


def create_lr_scheduler(args, optimizer):
    """Pick the LR schedule from ``args.cosine``, ``args.dataset`` and ``args.warmup``.

    :raises Exception: when no schedule is defined for ``args.dataset``
    """
    if args.cosine:
        return lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)
    if args.dataset.startswith('cifar'):
        return lr_scheduler.MultiStepLR(
            optimizer, [int(args.epoch*0.5), int(args.epoch*0.75)], gamma=0.1)
    if args.dataset == 'imagenet':
        if args.warmup:
            # NOTE(review): the cosine phase is sized for a 100-epoch run
            # regardless of args.epoch; confirm this is intended.
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, 100-args.warmup_epoch)
            return GradualWarmupScheduler(optimizer, multiplier=args.lr_multiplier,
                                          warmup_epoch=args.warmup_epoch, scheduler=scheduler)
        return lr_scheduler.MultiStepLR(
            optimizer, [int(args.epoch*0.3), int(args.epoch*0.6), int(args.epoch*0.9)], gamma=0.1)
    raise Exception('unknown scheduler for dataset: {}'.format(args.dataset))
def load_checkpoint(print_logger):
    """Load the checkpoint designated by the module-level ``args.resume``.

    ``args.resume`` is either the path of an existing checkpoint file or the
    literal string ``'checkpoint'``. In the latter case the file is taken from
    ``args.checkpoint_dir``: ``best_checkpoint.ckpt`` when
    ``args.resume_best`` is set, ``checkpoint.ckpt`` otherwise.

    Args:
        print_logger: logger used to report which checkpoint is loaded.

    Returns:
        Whatever object ``torch.load`` restores from the checkpoint file.

    Raises:
        Exception: if ``args.resume`` is neither an existing file nor
            ``'checkpoint'``.
    """
    print_logger.info("==> loading checkpoint '{}'".format(args.resume))

    if os.path.isfile(args.resume):
        ckpt_path = args.resume
    elif args.resume == 'checkpoint':
        ckpt_file = 'best_checkpoint.ckpt' if args.resume_best else 'checkpoint.ckpt'
        # Plain concatenation mirrors save_checkpoint(): checkpoint_dir is
        # expected to end with a path separator.
        ckpt_path = '{0}{1}'.format(args.checkpoint_dir, ckpt_file)
    else:
        raise Exception("=> no checkpoint found at '{}'".format(args.resume))

    # NOTE(security): torch.load unpickles the file; only resume from
    # checkpoints that come from a trusted source.
    return torch.load(ckpt_path)


def save_checkpoint(checkpoint, is_best, checkpoint_dir, checkpoint_name='checkpoint.ckpt'):
    """Serialize ``checkpoint`` to disk and optionally keep a "best" copy.

    Args:
        checkpoint: object passed to ``torch.save``.
        is_best: when true, the saved file is also copied to
            ``best_<checkpoint_name>`` in the same directory.
        checkpoint_dir: target directory prefix. The file path is built by
            plain concatenation, so it should end with a path separator. An
            empty string saves into the current working directory.
        checkpoint_name: file name of the checkpoint.
    """
    # exist_ok avoids the check-then-create race of os.path.exists() followed
    # by os.makedirs(); the emptiness guard avoids os.makedirs('') raising.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    ckpt_name = "{0}{1}".format(checkpoint_dir, checkpoint_name)
    torch.save(checkpoint, ckpt_name)
    if is_best:
        shutil.copyfile(ckpt_name, "{0}{1}".format(checkpoint_dir, 'best_' + checkpoint_name))


def run(epoch, model, data_loader, criterion, print_logger, sr_scheduler=None, optimizer=None):
    """Run one epoch of training or evaluation and return the error rates.

    The mode is selected by ``optimizer``: when one is given the model is
    trained, otherwise it is evaluated under ``torch.no_grad()``.

    In training mode every batch is forwarded and back-propagated once per
    slice-rate index produced by ``next(sr_scheduler)``; gradients accumulate
    across those passes and a single ``optimizer.step()`` follows. The loss
    and error statistics of a batch are those of the last slice rate in the
    group.

    Args:
        epoch: current epoch number (used for logging only).
        model: the network; in training mode ``model.module.update_sr_idx``
            is called, i.e. the model is expected to be wrapped
            (the original comment indicates ``DataParallel``).
        data_loader: iterable yielding ``(input, target)`` batches.
        criterion: loss function called as ``criterion(output, target)``.
        print_logger: logger receiving progress and summary lines.
        sr_scheduler: iterator yielding, per batch, the slice-rate indices to
            train; required in training mode.
        optimizer: optimizer to step; ``None`` selects evaluation mode.

    Returns:
        ``(top1_err, top5_err)``: running averages over the epoch.
    """
    global args
    is_train = optimizer is not None
    if is_train:
        model.train()
    else:
        model.eval()

    batch_time_avg = AverageMeter()
    loss_avg, top1_avg, top5_avg = AverageMeter(), AverageMeter(), AverageMeter()

    timestamp = time.time()
    for idx, (inputs, target) in enumerate(data_loader):
        if torch.cuda.is_available():
            inputs = inputs.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        if is_train:
            optimizer.zero_grad()

            # args.sr_idx is deliberately the loop target: the progress log
            # below reads it back to report the active slice rate.
            for args.sr_idx in next(sr_scheduler):
                # update slice rate idx
                model.module.update_sr_idx(args.sr_idx)  # DataParallel .module

                output = model(inputs)
                loss = criterion(output, target)
                loss.backward()

            optimizer.step()
        else:
            # NOTE(review): evaluation relies on args.sr_idx (and the model's
            # slice rate) having been set by the caller - confirm.
            with torch.no_grad():
                output = model(inputs)
                loss = criterion(output, target)

        err1, err5 = accuracy(output, target, topk=(1, 5))
        batch_size = inputs.size()[0]
        loss_avg.update(loss.item(), batch_size)
        top1_avg.update(err1, batch_size)
        top5_avg.update(err5, batch_size)

        now = time.time()
        batch_time_avg.update(now - timestamp)
        timestamp = now

        if idx % args.log_freq == 0:
            print_logger.info('Epoch: [{0}/{1}][{2}/{3}][SR-{4}]\t'
                              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\tLoss {loss.val:.4f} ({loss.avg:.4f})\t'
                              'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\tTop 5-err {top5.val:.4f} ({top5.avg:.4f})'.format(
                                  epoch, args.epoch, idx, len(data_loader), args.sr_list[args.sr_idx],
                                  batch_time=batch_time_avg, loss=loss_avg, top1=top1_avg, top5=top5_avg))

    print_logger.info('* Epoch: [{0}/{1}]{2:>8s} Total Time: {3}\tTop 1-err {top1.avg:.4f} '
                      'Top 5-err {top5.avg:.4f}\tTest Loss {loss.avg:.4f}'.format(
                          epoch, args.epoch, ('[train]' if is_train else '[val]'),
                          timeSince(s=batch_time_avg.sum),
                          top1=top1_avg, top5=top5_avg, loss=loss_avg))
    return top1_avg.avg, top5_avg.avg


if __name__ == '__main__':
    main()
# The generic scheduler base is required here: StepLR demands a ``step_size``
# argument, so subclassing it made ``super().__init__(optimizer)`` fail.
from torch.optim.lr_scheduler import _LRScheduler


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warm up (increase) the learning rate of an optimizer.

    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    During the first ``warmup_epoch`` epochs the learning rate grows linearly
    from ``base_lr`` to ``base_lr * multiplier``. Afterwards the optional
    ``scheduler`` takes over, with its base learning rates rescaled to
    ``base_lr * multiplier``; without one the rate stays at the target value.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        warmup_epoch: epoch at which the target learning rate is reached.
        multiplier: target learning rate = base lr * multiplier; must be >= 1
            (1.0 keeps the learning rate constant during warm-up).
        scheduler: scheduler used after ``warmup_epoch``. It must expose
            ``base_lrs``, ``get_lr()`` and ``step(epoch)``
            (e.g. CosineAnnealingLR).
    """

    def __init__(self, optimizer, warmup_epoch, multiplier=1.0, scheduler=None):
        # ">=" rather than ">": the default multiplier of 1.0 must be accepted.
        assert multiplier >= 1., 'multiplier should be greater than or equal to 1.'
        self.multiplier = multiplier
        self.warmup_epoch = warmup_epoch
        self.scheduler = scheduler
        self.finish_warmup = False
        # The base constructor performs an initial step(), which calls
        # get_lr(); every attribute above must therefore exist beforehand.
        super().__init__(optimizer)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``last_epoch``."""
        # warmup_epoch <= 0 means "no warm-up"; treating it as finished also
        # avoids dividing by zero in the linear ramp below.
        if self.warmup_epoch <= 0 or self.last_epoch > self.warmup_epoch:
            if self.scheduler:
                if not self.finish_warmup:
                    # Hand over once: the follow-up scheduler continues from
                    # the warmed-up learning rate.
                    self.scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finish_warmup = True
                return self.scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        # Linear ramp: factor 1 at epoch 0, ``multiplier`` at ``warmup_epoch``.
        return [base_lr*((self.multiplier-1.)*self.last_epoch/self.warmup_epoch+1.)
                for base_lr in self.base_lrs]

    def step(self, epoch=None, metrics=None):
        """Advance the schedule by one epoch.

        Once warm-up has finished and a follow-up scheduler exists, stepping
        is delegated to it, with ``epoch`` shifted so that the follow-up
        scheduler starts counting from zero. ``metrics`` is accepted for
        call-site compatibility and is not used.
        """
        if self.finish_warmup and self.scheduler:
            if epoch is None:
                self.scheduler.step(None)
            else:
                self.scheduler.step(epoch - self.warmup_epoch)
        else:
            return super(GradualWarmupScheduler, self).step(epoch)
import torch
from torch import nn
from functools import reduce
import operator
import time

MULTI_ADD = 1
# global variables
total_flops = 0
total_params = 0
total_time = 0
module_cnt = 0
verbose = False
# output format control
name_space = 95
param_space = 18
flop_space = 18
time_space = 18
# forward times
forward_num = 10


class Timer(object):
    """Context manager measuring wall-clock time with ``time.time()``.

    After the ``with`` block, ``start`` and ``end`` hold the timestamps and
    ``time`` the elapsed time in seconds.
    """

    def __init__(self, verbose=False):
        self.verbose = verbose
        self.start = None
        self.end = None
        self.time = None

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.time = self.end - self.start
        # self.time is in seconds; convert so the printed unit is correct.
        if self.verbose:
            print('Elapsed time: %f ms.' % (self.time * 1000.))


def calc_layer_param(model):
    """Return the total number of elements in ``model.parameters()``."""
    return sum(reduce(operator.mul, p.size(), 1) for p in model.parameters())


def profile_forward(layer, input):
    """Return the mean time in nanoseconds of one un-patched forward pass.

    ``layer.old_forward`` (the original forward saved by ``profiling``) is
    called ``forward_num`` times on ``input``.
    """
    with Timer() as t:
        for _ in range(forward_num):
            layer.old_forward(input)
        # CUDA kernels run asynchronously; wait for them before stopping.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return int(t.time * 1e9 / forward_num)


def _pair(value):
    """Return ``value`` as a 2-tuple; a scalar is used for both dimensions."""
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)


def _out_size(in_size, kernel, stride, padding, dilation=1):
    """Output length of a convolution/pooling window along one dimension."""
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def profile_layer(layer, x):
    """Accumulate parameters, operation count and forward time of one layer.

    The results are added to the module-level ``total_params``,
    ``total_flops`` and ``total_time``; a line per layer is printed when the
    module-level ``verbose`` flag is set.

    Args:
        layer: a leaf module whose original forward is stored in
            ``layer.old_forward``.
        x: the input the layer is about to receive. Convolution, pooling and
            normalization layers index dimensions 2 and 3 of it.

    Raises:
        TypeError: if the layer type is not supported.
    """
    global total_flops, total_params, total_time, module_cnt
    delta_ops, delta_params, delta_time = 0, 0, 0.

    # Conv2d
    if isinstance(layer, nn.Conv2d):
        out_h = _out_size(x.size(2), layer.kernel_size[0], layer.stride[0],
                          layer.padding[0], layer.dilation[0])
        out_w = _out_size(x.size(3), layer.kernel_size[1], layer.stride[1],
                          layer.padding[1], layer.dilation[1])
        delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
            layer.kernel_size[1] * out_h * out_w / layer.groups * MULTI_ADD
        delta_params = calc_layer_param(layer)
        delta_time = profile_forward(layer, x)
        module_cnt += 1

    # Linear
    elif isinstance(layer, nn.Linear):
        weight_ops = layer.weight.numel() * MULTI_ADD
        # Linear(..., bias=False) has no bias tensor.
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        delta_ops = (weight_ops + bias_ops)
        delta_params = calc_layer_param(layer)
        delta_time = profile_forward(layer, x)
        module_cnt += 1

    # ReLU can be omitted
    elif isinstance(layer, nn.ReLU):
        delta_ops = reduce(operator.mul, x.size()[1:], 1)
        delta_time = profile_forward(layer, x)
        module_cnt += 1

    # Pool2d
    elif isinstance(layer, (nn.AvgPool2d, nn.MaxPool2d)):
        # Pooling hyper-parameters may be ints or (h, w) tuples.
        kernel_h, kernel_w = _pair(layer.kernel_size)
        stride_h, stride_w = _pair(layer.stride)
        pad_h, pad_w = _pair(layer.padding)
        dil_h, dil_w = _pair(getattr(layer, 'dilation', 1))
        kernel_ops = kernel_h * kernel_w
        out_h = _out_size(x.size(2), kernel_h, stride_h, pad_h, dil_h)
        out_w = _out_size(x.size(3), kernel_w, stride_w, pad_w, dil_w)
        delta_ops = x.size(1) * out_w * out_h * kernel_ops
        delta_params = calc_layer_param(layer)
        delta_time = profile_forward(layer, x)
        module_cnt += 1

    # AdaptiveAvgPool2d
    elif isinstance(layer, nn.AdaptiveAvgPool2d):
        delta_ops = x.size(1) * x.size(2) * x.size(3)
        delta_params = calc_layer_param(layer)
        delta_time = profile_forward(layer, x)

    # BatchNorm2d, GroupNorm
    elif isinstance(layer, (nn.BatchNorm2d, nn.GroupNorm)):
        delta_ops = x.size(1) * x.size(2) * x.size(3)
        delta_params = calc_layer_param(layer)
        delta_time = profile_forward(layer, x)

    # layers whose flops are ignored
    elif isinstance(layer, (nn.Dropout2d, nn.Dropout)):
        delta_params = calc_layer_param(layer)
        delta_time = profile_forward(layer, x)

    # ignored layer type
    elif isinstance(layer, nn.Sequential):
        return
    else:
        raise TypeError('unknown layer type: %s' % str(layer))

    total_flops += delta_ops
    total_params += delta_params
    total_time += delta_time
    if verbose:
        print(str(layer).ljust(name_space, ' ') +
              '{:,}'.format(delta_params).rjust(param_space, ' ') +
              '{:,}'.format(delta_ops).rjust(flop_space, ' ') +
              '{:,}'.format(delta_time).rjust(time_space, ' '))
    return


def profiling(model, H, W, C=3, B=1, debug=False):
    """Profile ``model`` on a zero input of shape ``(B, C, H, W)``.

    The forward of every leaf descendant is temporarily wrapped so that
    ``profile_layer`` sees each layer together with its input; the original
    forwards are put back afterwards, also when the forward pass fails. Only
    descendants are wrapped, so a model without children runs unprofiled.

    Args:
        model: the network to profile.
        H, W: spatial size of the dummy input.
        C: number of input channels.
        B: batch size.
        debug: print one line per layer and a totals line.

    Returns:
        ``(total_params, total_flops, total_time)`` with the time in
        nanoseconds, as the sum of the per-layer mean forward times.

    Raises:
        TypeError: if the model contains an unsupported leaf layer.
    """
    global total_flops, total_params, total_time, module_cnt, verbose
    # total_time must be reset as well, otherwise it grows across calls.
    total_flops, total_params, total_time, module_cnt, verbose = 0, 0, 0, 0, debug

    data = torch.zeros(B, C, H, W)
    # Feed the model on the device holding its weights (CPU when it has none).
    first_param = next(model.parameters(), None)
    if first_param is not None:
        data = data.to(first_param.device)

    def is_leaf(model):
        ''' measure all leaf nodes '''
        return len(list(model.children())) == 0

    def modify_forward(model):
        for child in model.children():
            if is_leaf(child):
                # Factory function: binds ``m`` per layer instead of sharing
                # the loop variable between closures.
                def new_forward(m):
                    def lambda_forward(x):
                        profile_layer(m, x)
                        return m.old_forward(x)
                    return lambda_forward
                child.old_forward = child.forward
                child.forward = new_forward(child)
            else:
                modify_forward(child)

    def restore_forward(model):
        for child in model.children():
            # leaf node
            if is_leaf(child) and hasattr(child, 'old_forward'):
                child.forward = child.old_forward
                child.old_forward = None
            else:
                restore_forward(child)

    def line_breaker():
        print(''.center(name_space + param_space +
                        flop_space + time_space, '-'))

    print('Item'.ljust(name_space, ' ') +
          'params'.rjust(param_space, ' ') +
          'flops'.rjust(flop_space, ' ') +
          'nanosecs'.rjust(time_space, ' '))
    if verbose:
        line_breaker()
    modify_forward(model)
    try:
        model.forward(data)
    finally:
        # Never leave the model with patched forwards behind.
        restore_forward(model)
    if verbose:
        line_breaker()
        print('Total'.ljust(name_space, ' ') +
              '{:,}'.format(total_params).rjust(param_space, ' ') +
              '{:,}'.format(total_flops).rjust(flop_space, ' ') +
              '{:,}'.format(total_time).rjust(time_space, ' '))

    return total_params, total_flops, total_time
import os
import time
import math
import logging
import sys


# setup logger
def logger(log_dir, need_time=True, need_stdout=False):
    """Return this module's logger, writing to a file and optionally stdout.

    The same logger object is returned on every call. A handler is only
    attached when no equivalent one exists yet, so calling this function
    repeatedly does not duplicate log records; handlers that already exist
    keep the formatter they were created with.

    Args:
        log_dir: path of the log file (despite the name, a file path that is
            passed to ``logging.FileHandler``).
        need_time: prefix every record with a timestamp.
        need_stdout: additionally echo records of level INFO and above to
            ``sys.stdout``.

    Returns:
        The configured ``logging.Logger``.
    """
    log = logging.getLogger(__name__)
    log.setLevel(logging.DEBUG)
    formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y-%I:%M:%S')

    if need_stdout:
        # FileHandler subclasses StreamHandler, hence the explicit exclusion.
        has_stdout = any(
            isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
            and getattr(h, 'stream', None) is sys.stdout
            for h in log.handlers)
        if not has_stdout:
            ch = logging.StreamHandler(sys.stdout)
            ch.setLevel(logging.INFO)
            if need_time:
                ch.setFormatter(formatter)
            log.addHandler(ch)

    # FileHandler stores the absolute path of its target in baseFilename.
    log_path = os.path.abspath(log_dir)
    has_file = any(isinstance(h, logging.FileHandler) and h.baseFilename == log_path
                   for h in log.handlers)
    if not has_file:
        fh = logging.FileHandler(log_dir)
        fh.setLevel(logging.DEBUG)
        if need_time:
            fh.setFormatter(formatter)
        log.addHandler(fh)
    return log


def timeSince(since=None, s=None):
    """Format an elapsed time as ``'<h>h <m>m <s>s'``.

    Args:
        since: start timestamp as returned by ``time.time()``; used only when
            ``s`` is not given.
        s: elapsed seconds; takes precedence over ``since``.
    """
    if s is None:
        s = int(time.time() - since)
    m, s = divmod(s, 60)
    h, m = divmod(m, 60)
    return '%dh %dm %ds' % (h, m, s)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute the top-k error rate (in percent) for each k in ``topk``.

    Args:
        output: score tensor of shape ``(batch, classes)``.
        target: tensor of ``batch`` class indices.
        topk: the values of k to evaluate.

    Returns:
        A list of Python floats, one per k: the percentage of samples whose
        target is NOT among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` derives from a transposed tensor, so the slice may be
        # non-contiguous; reshape() copes with that where view() raises.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        wrong_k = batch_size - correct_k
        res.append(wrong_k.mul_(100.0 / batch_size).item())

    return res