├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── layers.py ├── main.py ├── models ├── __init__.py ├── densenet.py ├── dynamic_densenet.py ├── dynamic_resnet.py └── resnet.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # result files 104 | results/* 105 | *run*.sh 106 | 107 | 
.DS_Store 108 | *.swp 109 | results* 110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Zhuo Su 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dynamic Group Convolution 2 | 3 | This repository contains the PyTorch implementation for 4 | "Dynamic Group Convolution for Accelerating Convolutional Neural Networks" 5 | by 6 | [Zhuo Su](https://zhuogege1943.com/homepage/)\*, 7 | [Linpu Fang](https://dblp.org/pers/hd/f/Fang:Linpu)\*, 8 | [Wenxiong Kang](http://www.scholat.com/auwxkang.en), 9 | [Dewen Hu](https://dblp.org/pers/h/Hu:Dewen.html), 10 | [Matti Pietikäinen](https://en.wikipedia.org/wiki/Matti_Pietik%C3%A4inen_(academic)) and 11 | [Li Liu](http://www.ee.oulu.fi/~lili/LiLiuHomepage.html) 12 | (\* Authors have equal contributions). \[[arXiv](https://arxiv.org/abs/2007.04242)\] 13 | 14 | The code is based on [CondenseNet](https://github.com/ShichenLiu/CondenseNet). 15 | 16 | 17 | ### Citation 18 | 19 | If you find our project useful in your research, please consider citing: 20 | 21 | ``` 22 | @inproceedings{su2020dgc, 23 | title={Dynamic Group Convolution for Accelerating Convolutional Neural Networks}, 24 | author={Su, Zhuo and Fang, Linpu and Kang, Wenxiong and Hu, Dewen and Pietik{\"a}inen, Matti and Liu, Li}, 25 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 26 | year={2020} 27 | } 28 | ``` 29 | 30 | ## Introduction 31 | 32 | 33 | Dynamic Group Convolution (DGC) can adaptively select which part 34 | of input channels to be connected within each group for individual 35 | samples on the fly. Specifically, we equip each group with a small feature 36 | selector to automatically select the most important input channels 37 | conditioned on the input images. Multiple groups can adaptively capture 38 | abundant and complementary visual/semantic features for each input 39 | image. 
The DGC preserves the original network structure and simultaneously has 40 | computational efficiency similar to that of conventional group convolutions. 41 | Extensive experiments on multiple image classification 42 | benchmarks including CIFAR-10, CIFAR-100 and ImageNet demonstrate its 43 | superiority over the existing group convolution techniques and dynamic execution methods. 44 | 45 |
46 |
47 | Figure 1: Overview of a DGC layer. 48 |
49 | 50 | The DGC network can be trained from scratch in an 51 | end-to-end manner, without the need for model pre-training. During backward 52 | propagation in a DGC layer, gradients are calculated 53 | only for weights connected to selected channels during the forward pass, and 54 | safely set to 0 for the others thanks to the unbiased gating strategy (refer to the paper). 55 | To avoid abrupt changes in training loss while pruning, 56 | we gradually deactivate input channels along the training process 57 | with a cosine-shaped learning rate. 58 | 59 |
60 |
61 | Figure 2: Training pipeline. 62 |
63 | 64 | 65 | ## Training and Evaluation 66 | 67 | Unfortunately, the training process on ImageNet is currently somewhat slow and memory-consuming because this is still a coarse implementation. For the first bash script of condensenet on ImageNet, the model was trained on two V100 GPUs with 32 GB of GPU memory each. 68 | 69 | Remove `--evaluate xxx.tar` to train; keep it to evaluate (the trained models can be downloaded through the following links or [baidunetdisk](https://pan.baidu.com/s/17BqJ4slwwNxRydj9RBT8yQ) (code: 9dtn)) 70 | 71 | (condensenet with dgc on ImageNet, pruning rate=0.75, heads=4, ***top1=25.4, top5=7.8***) 72 | 73 | Links for `imagenet_dydensenet_h4.tar` (92.3M): 74 | [google drive](https://drive.google.com/file/d/1gKrugAFGLea7kjTa_nmhwVAsinoxze8T/view?usp=sharing), 75 | [onedrive](https://unioulu-my.sharepoint.com/:u:/g/personal/zsu18_univ_yo_oulu_fi/EeU7Lpe2AUBPsONNZYBVv5kBNAy0sdOlj94iuqCdRRkneQ?e=NaZpQF) 76 | ```bash 77 | python main.py --model dydensenet -b 256 -j 4 --data imagenet --datadir /path/to/imagenet \ 78 | --epochs 120 --lr-type cosine --stages 4-6-8-10-8 --growth 8-16-32-64-128 --bottleneck 4 \ 79 | --heads 4 --group-3x3 4 --gate-factor 0.25 --squeeze-rate 16 --resume --gpu 0 --savedir results/exp \ 80 | --evaluate /path/to/imagenet_dydensenet_h4.tar 81 | ``` 82 | 83 | 84 | (resnet18 with dgc on ImageNet, pruning rate=0.55, heads=4, ***top1=31.22, top5=11.38***) 85 | 86 | Links for `imagenet_dyresnet18_h4.tar` (47.2M): 87 | [google drive](https://drive.google.com/file/d/1rtSU3iUKlA0NhgnUJz-QksW5aL2Lt2Cg/view?usp=sharing), 88 | [onedrive](https://unioulu-my.sharepoint.com/:u:/g/personal/zsu18_univ_yo_oulu_fi/EaiXCgT7H7NBmBWObq1lOukBUYaQb5J6DOcD3RHFA4PLLQ?e=myQHRN) 89 | ```bash 90 | python main.py --model dyresnet18 -b 256 -j 4 --data imagenet --datadir /path/to/imagenet \ 91 | --epochs 120 --lr-type cosine --heads 4 --gate-factor 0.45 --squeeze-rate 16 --resume \ 92 | --gpu 0 --savedir results/exp --evaluate 
/path/to/imagenet_dyresnet18_h4.tar 93 | ``` 94 | 95 | (densenet86 with dgc on CIFAR-10, pruning rate=0.75, heads=4, ***top1=4.77***) 96 | 97 | Links for `cifar10_dydensenet86_h4.tar` (16.7M): 98 | [google drive](https://drive.google.com/file/d/1o1cVxqa7juDgNRK53dKpfTKEbfMhPSdG/view?usp=sharing), 99 | [onedrive](https://unioulu-my.sharepoint.com/:u:/g/personal/zsu18_univ_yo_oulu_fi/EZ6cmeLZGHdLtIJeFiM-FzYBVPDoaj70wZ1r4yT8X48Ivw?e=YocnXs) 100 | ```bash 101 | python main.py --model dydensenet -b 64 -j 4 --data cifar10 --datadir ../data --epochs 300 \ 102 | --lr-type cosine --stages 14-14-14 --growth 8-16-32 --bottleneck 4 --heads 4 --group-3x3 4 \ 103 | --gate-factor 0.25 --squeeze-rate 16 --resume --gpu 0 --savedir results/exp \ 104 | --evaluate /path/to/cifar10_dydensenet86_h4.tar 105 | ``` 106 | 107 | 108 | (densenet86 with dgc on CIFAR-100, pruning rate=0.75, heads=4, ***top1=23.41***) 109 | 110 | Links for `cifar100_dydensenet86_h4.tar` (17.0M): 111 | [google drive](https://drive.google.com/file/d/1Wne46Znto-uivTV-Evc5RHywUEe7Emyn/view?usp=sharing), 112 | [onedrive](https://unioulu-my.sharepoint.com/:u:/g/personal/zsu18_univ_yo_oulu_fi/EXci72YYC_1CiA7GwOybIw0BJK9rUg48ZXaapPQvHq0Viw?e=ZKVXk9) 113 | ```bash 114 | python main.py --model dydensenet -b 64 -j 4 --data cifar100 --datadir ../data --epochs 300 \ 115 | --lr-type cosine --stages 14-14-14 --growth 8-16-32 --bottleneck 4 --heads 4 --group-3x3 4 \ 116 | --gate-factor 0.25 --squeeze-rate 16 --resume --gpu 0 --savedir results/exp \ 117 | --evaluate /path/to/cifar100_dydensenet86_h4.tar 118 | ``` 119 | 120 | ## Other notes 121 | 122 | Any discussions or concerns are welcome in the [Issues](https://github.com/zhuogege1943/dgc/issues)! 
123 | 124 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellozhuo/dgc/86befbd7f7b685ab3bbfafcd027ca3551dda48e9/__init__.py -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class DynamicMultiHeadConv(nn.Module): 11 | global_progress = 0.0 12 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 13 | padding=0, dilation=1, heads=4, squeeze_rate=16, gate_factor=0.25): 14 | super(DynamicMultiHeadConv, self).__init__() 15 | self.norm = nn.BatchNorm2d(in_channels) 16 | self.relu = nn.ReLU(inplace=True) 17 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 18 | self.in_channels = in_channels 19 | self.out_channels = out_channels 20 | self.heads = heads 21 | self.squeeze_rate = squeeze_rate 22 | self.gate_factor = gate_factor 23 | self.stride = stride 24 | self.padding = padding 25 | self.dilation = dilation 26 | self.is_pruned = True 27 | self.register_buffer('_inactive_channels', torch.zeros(1)) 28 | 29 | ### Check if arguments are valid 30 | assert self.in_channels % self.heads == 0, \ 31 | "head number can not be divided by input channels" 32 | assert self.out_channels % self.heads == 0, \ 33 | "head number can not be divided by output channels" 34 | assert self.gate_factor <= 1.0, "gate factor is greater than 1" 35 | 36 | for i in range(self.heads): 37 | self.__setattr__('headconv_%1d' % i, 38 | HeadConv(in_channels, out_channels // self.heads, squeeze_rate, 39 | kernel_size, stride, padding, dilation, 1, gate_factor)) 40 
| 41 | def forward(self, x): 42 | """ 43 | The code here is just a coarse implementation. 44 | The forward process can be quite slow and memory consuming, need to be optimized. 45 | """ 46 | if self.training: 47 | progress = DynamicMultiHeadConv.global_progress 48 | # gradually deactivate input channels 49 | if progress < 3.0 / 4 and progress > 1.0 / 12: 50 | self.inactive_channels = round(self.in_channels * (1 - self.gate_factor) * 3.0 / 2 * (progress - 1.0 / 12)) 51 | elif progress >= 3.0 / 4: 52 | self.inactive_channels = round(self.in_channels * (1 - self.gate_factor)) 53 | 54 | _lasso_loss = 0.0 55 | 56 | x = self.norm(x) 57 | x = self.relu(x) 58 | 59 | x_averaged = self.avg_pool(x) 60 | x_mask = [] 61 | weight = [] 62 | for i in range(self.heads): 63 | i_x, i_lasso_loss= self.__getattr__('headconv_%1d' % i)(x, x_averaged, self.inactive_channels) 64 | x_mask.append(i_x) 65 | weight.append(self.__getattr__('headconv_%1d' % i).conv.weight) 66 | _lasso_loss = _lasso_loss + i_lasso_loss 67 | 68 | x_mask = torch.cat(x_mask, dim=1) # batch_size, 4 x C_in, H, W 69 | weight = torch.cat(weight, dim=0) # 4 x C_out, C_in, k, k 70 | 71 | out = F.conv2d(x_mask, weight, None, self.stride, 72 | self.padding, self.dilation, self.heads) 73 | b, c, h, w = out.size() 74 | out = out.view(b, self.heads, c // self.heads, h, w) 75 | out = out.transpose(1, 2).contiguous().view(b, c, h, w) 76 | return [out, _lasso_loss] 77 | 78 | @property 79 | def inactive_channels(self): 80 | return int(self._inactive_channels[0]) 81 | 82 | @inactive_channels.setter 83 | def inactive_channels(self, val): 84 | self._inactive_channels.fill_(val) 85 | 86 | class HeadConv(nn.Module): 87 | def __init__(self, in_channels, out_channels, squeeze_rate, kernel_size, stride=1, 88 | padding=0, dilation=1, groups=1, gate_factor=0.25): 89 | super(HeadConv, self).__init__() 90 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, 91 | padding, dilation, groups=1, bias=False) 92 | 
self.target_pruning_rate = gate_factor 93 | if in_channels < 80: 94 | squeeze_rate = squeeze_rate // 2 95 | self.fc1 = nn.Linear(in_channels, in_channels // squeeze_rate, bias=False) 96 | self.relu_fc1 = nn.ReLU(inplace=True) 97 | self.fc2 = nn.Linear(in_channels // squeeze_rate, in_channels, bias=True) 98 | self.relu_fc2 = nn.ReLU(inplace=True) 99 | 100 | nn.init.kaiming_normal_(self.fc1.weight) 101 | nn.init.kaiming_normal_(self.fc2.weight) 102 | nn.init.constant_(self.fc2.bias, 1.0) 103 | 104 | def forward(self, x, x_averaged, inactive_channels): 105 | b, c, _, _ = x.size() 106 | x_averaged = x_averaged.view(b, c) 107 | y = self.fc1(x_averaged) 108 | y = self.relu_fc1(y) 109 | y = self.fc2(y) 110 | 111 | 112 | mask = self.relu_fc2(y) # b, c 113 | _lasso_loss = mask.mean() 114 | 115 | mask_d = mask.detach() 116 | mask_c = mask 117 | 118 | if inactive_channels > 0: 119 | mask_c = mask.clone() 120 | topk_maxmum, _ = mask_d.topk(inactive_channels, dim=1, largest=False, sorted=False) 121 | clamp_max, _ = topk_maxmum.max(dim=1, keepdim=True) 122 | mask_index = mask_d.le(clamp_max) 123 | mask_c[mask_index] = 0 124 | 125 | mask_c = mask_c.view(b, c, 1, 1) 126 | x = x * mask_c.expand_as(x) 127 | return x, _lasso_loss 128 | 129 | 130 | class Conv(nn.Sequential): 131 | def __init__(self, in_channels, out_channels, kernel_size, 132 | stride=1, padding=0, groups=1): 133 | super(Conv, self).__init__() 134 | self.add_module('norm', nn.BatchNorm2d(in_channels)) 135 | self.add_module('relu', nn.ReLU(inplace=True)) 136 | self.add_module('conv', nn.Conv2d(in_channels, out_channels, 137 | kernel_size=kernel_size, 138 | stride=stride, 139 | padding=padding, bias=False, 140 | groups=groups)) 141 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dynamic Group Convolution 3 | date: July 5th, 2020 4 | authors: Zhuo Su, Linpu Fang 5 | paper: Dynamic 
Group Convolution for Accelerating Convolutional Neural Networks, ECCV 2020. 6 | 7 | Code forked from "https://github.com/ShichenLiu/CondenseNet" 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import unicode_literals 12 | from __future__ import print_function 13 | from __future__ import division 14 | 15 | import argparse 16 | import os 17 | import time 18 | import models 19 | from utils import * 20 | 21 | import torch 22 | import torch.nn as nn 23 | import torch.backends.cudnn as cudnn 24 | import torchvision.transforms as transforms 25 | import torchvision.datasets as datasets 26 | 27 | parser = argparse.ArgumentParser(description='PyTorch main code for Dynamic Group Convolution') 28 | parser.add_argument('--data', type=str, default='imagenet', 29 | help='name of dataset', 30 | choices=['cifar10', 'cifar100', 'imagenet']) 31 | parser.add_argument('--datadir', type=str, default='../data', 32 | help='dir to the dataset') 33 | parser.add_argument('--savedir', type=str, default='results/exp', 34 | help='path to save result and checkpoint') 35 | 36 | parser.add_argument('--model', type=str, default='dydensenet', 37 | help='model to train the dataset') 38 | parser.add_argument('-j', '--workers', type=int, default=8, 39 | help='number of data loading workers') 40 | parser.add_argument('--epochs', type=int, default=120, 41 | help='number of total epochs to run') 42 | parser.add_argument('-b', '--batch-size', type=int, default=256, 43 | help='mini-batch size') 44 | parser.add_argument('--lr', '--learning-rate', type=float, default=0.1, 45 | help='initial learning rate') 46 | parser.add_argument('--lr-type', type=str, default='cosine', 47 | help='learning rate strategy', 48 | choices=['cosine', 'multistep']) 49 | parser.add_argument('--group-lasso-lambda', type=float, default=1e-5, 50 | help='group lasso loss weight') 51 | parser.add_argument('--momentum', type=float, default=0.9, 52 | help='momentum for sgd') 53 | 
parser.add_argument('--weight-decay', '--wd', type=float, default=1e-4, 54 | help='weight decay') 55 | parser.add_argument('--seed', type=int, default=None, 56 | help='manual seed') 57 | parser.add_argument('--gpu', type=str, default='', 58 | help='gpu available') 59 | 60 | parser.add_argument('--stages', type=str, 61 | help='per layer depth') 62 | parser.add_argument('--squeeze-rate', type=int, default=16, 63 | help='squeeze rate in SE head') 64 | parser.add_argument('--heads', type=int, default=4, 65 | help='number of heads for 1x1 convolution') 66 | parser.add_argument('--group-3x3', type=int, default=4, 67 | help='3x3 group convolution') 68 | parser.add_argument('--gate-factor', type=float, default=0.25, 69 | help='gate factor') 70 | parser.add_argument('--growth', type=str, 71 | help='per layer growth') 72 | parser.add_argument('--bottleneck', type=int, default=4, 73 | help='bottleneck in densenet') 74 | 75 | parser.add_argument('--print-freq', type=int, default=10, 76 | help='print frequency') 77 | parser.add_argument('--save-freq', type=int, default=10, 78 | help='save frequency') 79 | parser.add_argument('--resume', action='store_true', 80 | help='use latest checkpoint if have any') 81 | parser.add_argument('--evaluate', type=str, default=None, 82 | help="full path to checkpoint to be evaluated") 83 | 84 | args = parser.parse_args() 85 | 86 | 87 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 88 | 89 | best_prec1 = 0 90 | 91 | def main(): 92 | global args, best_prec1 93 | 94 | if args.seed is None: 95 | args.seed = int(time.time()) 96 | torch.manual_seed(args.seed) 97 | torch.cuda.manual_seed_all(args.seed) 98 | 99 | R = 32 100 | if args.data == 'cifar10': 101 | args.num_classes = 10 102 | elif args.data == 'cifar100': 103 | args.num_classes = 100 104 | else: 105 | args.num_classes = 1000 106 | R = 224 107 | 108 | if 'densenet' in args.model: 109 | args.stages = list(map(int, args.stages.split('-'))) 110 | args.growth = list(map(int, 
args.growth.split('-'))) 111 | 112 | 113 | ### Calculate FLOPs & Param 114 | model = getattr(models, args.model)(args) 115 | n_flops, n_params = measure_model(model, R, R) 116 | print('FLOPs: %.2fM, Params: %.2fM' % (n_flops / 1e6, n_params / 1e6)) 117 | 118 | os.makedirs(args.savedir, exist_ok=True) 119 | log_file = os.path.join(args.savedir, "%s_%d_%d.txt" % \ 120 | (args.model, int(n_params), int(n_flops))) 121 | del(model) 122 | 123 | ### Create model 124 | model = getattr(models, args.model)(args) 125 | model = torch.nn.DataParallel(model).cuda() 126 | 127 | ### Define loss function (criterion) and optimizer 128 | criterion = nn.CrossEntropyLoss().cuda() 129 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 130 | momentum=args.momentum, 131 | weight_decay=args.weight_decay, 132 | nesterov=True) 133 | 134 | cudnn.benchmark = True 135 | 136 | ### Data loading 137 | if args.data == "cifar10": 138 | normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467], 139 | std=[0.2471, 0.2435, 0.2616]) 140 | train_set = datasets.CIFAR10(args.datadir, train=True, download=True, 141 | transform=transforms.Compose([ 142 | transforms.RandomCrop(32, padding=4), 143 | transforms.RandomHorizontalFlip(), 144 | transforms.ToTensor(), 145 | normalize, 146 | ])) 147 | val_set = datasets.CIFAR10(args.datadir, train=False, 148 | transform=transforms.Compose([ 149 | transforms.ToTensor(), 150 | normalize, 151 | ])) 152 | elif args.data == "cifar100": 153 | normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], 154 | std=[0.2675, 0.2565, 0.2761]) 155 | train_set = datasets.CIFAR100(args.datadir, train=True, download=True, 156 | transform=transforms.Compose([ 157 | transforms.RandomCrop(32, padding=4), 158 | transforms.RandomHorizontalFlip(), 159 | transforms.ToTensor(), 160 | normalize, 161 | ])) 162 | val_set = datasets.CIFAR100(args.datadir, train=False, 163 | transform=transforms.Compose([ 164 | transforms.ToTensor(), 165 | normalize, 166 | ])) 167 | else: 
#imagenet 168 | traindir = os.path.join(args.datadir, 'train') 169 | valdir = os.path.join(args.datadir, 'val') 170 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 171 | std=[0.229, 0.224, 0.225]) 172 | train_set = datasets.ImageFolder(traindir, transforms.Compose([ 173 | transforms.RandomResizedCrop(224), 174 | transforms.RandomHorizontalFlip(), 175 | transforms.ToTensor(), 176 | normalize, 177 | ])) 178 | 179 | val_set = datasets.ImageFolder(valdir, transforms.Compose([ 180 | transforms.Resize(256), 181 | transforms.CenterCrop(224), 182 | transforms.ToTensor(), 183 | normalize, 184 | ])) 185 | 186 | train_loader = torch.utils.data.DataLoader( 187 | train_set, 188 | batch_size=args.batch_size, shuffle=True, 189 | num_workers=args.workers, pin_memory=True) 190 | 191 | val_loader = torch.utils.data.DataLoader( 192 | val_set, 193 | batch_size=args.batch_size, shuffle=False, 194 | num_workers=args.workers, pin_memory=True) 195 | 196 | ### Optionally resume from a checkpoint 197 | args.start_epoch = 0 198 | if args.resume or (args.evaluate is not None): 199 | checkpoint = load_checkpoint(args) 200 | if checkpoint is not None: 201 | model.load_state_dict(checkpoint['state_dict']) 202 | try: 203 | args.start_epoch = checkpoint['epoch'] + 1 204 | best_prec1 = checkpoint['best_prec1'] 205 | optimizer.load_state_dict(checkpoint['optimizer']) 206 | except KeyError: 207 | pass 208 | 209 | ### Evaluate directly if required 210 | print(args) 211 | if args.evaluate is not None: 212 | validate(val_loader, model, criterion, args) 213 | return 214 | 215 | saveID = None 216 | for epoch in range(args.start_epoch, args.epochs): 217 | ### Train for one epoch 218 | tr_prec1, tr_prec5, loss, lr = \ 219 | train(train_loader, model, criterion, optimizer, epoch, args) 220 | 221 | ### Evaluate on validation set 222 | val_prec1, val_prec5 = validate(val_loader, model, criterion, args) 223 | 224 | ### Remember best prec@1 and save checkpoint 225 | is_best = val_prec1 >= 
best_prec1 226 | best_prec1 = max(val_prec1, best_prec1) 227 | 228 | log = ("Epoch %03d/%03d: top1 %.4f | top5 %.4f" + \ 229 | " | train-top1 %.4f | train-top5 %.4f | loss %.4f | lr %.5f | Time %s\n") \ 230 | % (epoch, args.epochs, val_prec1, val_prec5, tr_prec1, \ 231 | tr_prec5, loss, lr, time.strftime('%Y-%m-%d %H:%M:%S')) 232 | with open(log_file, 'a') as f: 233 | f.write(log) 234 | 235 | saveID = save_checkpoint({ 236 | 'epoch': epoch, 237 | 'state_dict': model.state_dict(), 238 | 'best_prec1': best_prec1, 239 | 'optimizer': optimizer.state_dict(), 240 | }, epoch, args.savedir, is_best, 241 | saveID, keep_freq=args.save_freq) 242 | 243 | return 244 | 245 | 246 | def train(train_loader, model, criterion, optimizer, epoch, args): 247 | batch_time = AverageMeter() 248 | data_time = AverageMeter() 249 | losses = AverageMeter() 250 | lasso_losses = AverageMeter() 251 | top1 = AverageMeter() 252 | top5 = AverageMeter() 253 | 254 | ### Switch to train mode 255 | model.train() 256 | wD = len(str(len(train_loader))) 257 | wE = len(str(args.epochs)) 258 | 259 | end = time.time() 260 | for i, (input, target) in enumerate(train_loader): 261 | 262 | progress = float(epoch * len(train_loader) + i) / \ 263 | (args.epochs * len(train_loader)) 264 | ## Adjust learning rate 265 | lr = adjust_learning_rate(optimizer, epoch, args, batch=i, 266 | nBatch=len(train_loader), method=args.lr_type) 267 | 268 | ## Measure data loading time 269 | data_time.update(time.time() - end) 270 | 271 | input = input.cuda(non_blocking=True) 272 | target = target.cuda(non_blocking=True) 273 | 274 | ## Compute output 275 | output, _lasso_list = model(input, progress) 276 | loss = criterion(output, target) 277 | 278 | ## Add group lasso loss 279 | lasso_loss = 0 280 | if args.group_lasso_lambda > 0: 281 | for lasso_m in _lasso_list: 282 | lasso_loss = lasso_loss + lasso_m.mean() 283 | loss = loss + args.group_lasso_lambda * lasso_loss 284 | 285 | ## Measure accuracy and record loss 286 | prec1, prec5 
= accuracy(output.data, target, topk=(1, 5)) 287 | losses.update(loss.item(), input.size(0)) 288 | lasso_losses.update(lasso_loss.item()) 289 | top1.update(prec1.item(), input.size(0)) 290 | top5.update(prec5.item(), input.size(0)) 291 | 292 | ## Compute gradient and do SGD step 293 | optimizer.zero_grad() 294 | loss.backward() 295 | optimizer.step() 296 | 297 | ## Measure elapsed time 298 | batch_time.update(time.time() - end) 299 | end = time.time() 300 | 301 | ## Record 302 | if i % args.print_freq == 0: 303 | print(('Epoch: [{0}/{1}][{2}/{3}]\t' + \ 304 | 'Time {batch_time.val:.3f}\t' + \ 305 | 'Data {data_time.val:.3f}\t' + \ 306 | 'Loss (lasso_loss) {loss.val:.4f} ({lasso_loss.val:.4f})\t' + \ 307 | 'Prec@1 {top1.val:.3f}\t' + \ 308 | 'Prec@5 {top5.val:.3f}\t' + \ 309 | 'lr {lr: .5f}\t').format( 310 | epoch, args.epochs, i, len(train_loader), batch_time=batch_time, 311 | data_time=data_time, loss=losses, lasso_loss=lasso_losses, 312 | top1=top1, top5=top5, lr=lr)) 313 | 314 | return top1.avg, top5.avg, losses.avg, lr 315 | 316 | 317 | def validate(val_loader, model, criterion, args): 318 | batch_time = AverageMeter() 319 | losses = AverageMeter() 320 | top1 = AverageMeter() 321 | top5 = AverageMeter() 322 | 323 | ## Switch to evaluate mode 324 | model.eval() 325 | 326 | end = time.time() 327 | for i, (input, target) in enumerate(val_loader): 328 | ## Compute output 329 | with torch.no_grad(): 330 | target = target.cuda(non_blocking=True) 331 | input = input.cuda(non_blocking=True) 332 | 333 | output, _ = model(input) 334 | loss = criterion(output, target) 335 | 336 | ## Measure accuracy and record loss 337 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 338 | losses.update(loss.data.item(), input.size(0)) 339 | top1.update(prec1.item(), input.size(0)) 340 | top5.update(prec5.item(), input.size(0)) 341 | 342 | ## Measure elapsed time 343 | batch_time.update(time.time() - end) 344 | end = time.time() 345 | 346 | if i % args.print_freq == 0: 347 | 
print('Test: [{0}/{1}]\t' 348 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 349 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 350 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 351 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 352 | i, len(val_loader), batch_time=batch_time, loss=losses, 353 | top1=top1, top5=top5)) 354 | 355 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}' 356 | .format(top1=top1, top5=top5)) 357 | 358 | return top1.avg, top5.avg 359 | 360 | 361 | 362 | if __name__ == '__main__': 363 | main() 364 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .densenet import DenseNet as densenet 2 | from .resnet import resnet18 3 | 4 | 5 | from .dynamic_densenet import DydenseNet as dydensenet 6 | from .dynamic_resnet import dyresnet18 7 | -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | import math 11 | from layers import Conv 12 | 13 | __all__ = ['DenseNet'] 14 | 15 | 16 | def make_divisible(x, y): 17 | return int((x // y + 1) * y) if x % y else int(x) 18 | 19 | 20 | class _DenseLayer(nn.Module): 21 | def __init__(self, in_channels, growth_rate, args): 22 | super(_DenseLayer, self).__init__() 23 | self.group_1x1 = args.group_1x1 24 | self.group_3x3 = args.group_3x3 25 | ### 1x1 conv i --> b*k 26 | self.conv_1 = Conv(in_channels, args.bottleneck * growth_rate, 27 | kernel_size=1, groups=self.group_1x1) 28 | ### 3x3 conv b*k --> k 29 | self.conv_2 = Conv(args.bottleneck * 
growth_rate, growth_rate, 30 | kernel_size=3, padding=1, groups=self.group_3x3) 31 | 32 | def forward(self, x): 33 | x_ = x 34 | x = self.conv_1(x) 35 | x = self.conv_2(x) 36 | return torch.cat([x_, x], 1) 37 | 38 | 39 | class _DenseBlock(nn.Sequential): 40 | def __init__(self, num_layers, in_channels, growth_rate, args): 41 | super(_DenseBlock, self).__init__() 42 | for i in range(num_layers): 43 | layer = _DenseLayer(in_channels + i * growth_rate, growth_rate, args) 44 | self.add_module('denselayer_%d' % (i + 1), layer) 45 | 46 | 47 | class _Transition(nn.Module): 48 | def __init__(self, in_channels, out_channels, args): 49 | super(_Transition, self).__init__() 50 | #self.conv = Conv(in_channels, out_channels, 51 | # kernel_size=1, groups=args.group_1x1) 52 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2) 53 | 54 | def forward(self, x): 55 | #x = self.conv(x) 56 | x = self.pool(x) 57 | return x 58 | 59 | 60 | class DenseNet(nn.Module): 61 | def __init__(self, args): 62 | 63 | super(DenseNet, self).__init__() 64 | 65 | self.stages = args.stages 66 | self.growth = args.growth 67 | self.reduction = args.reduction 68 | assert len(self.stages) == len(self.growth) 69 | self.args = args 70 | self.progress = 0.0 71 | if args.data in ['cifar10', 'cifar100']: 72 | self.init_stride = 1 73 | self.pool_size = 8 74 | else: 75 | self.init_stride = 2 76 | self.pool_size = 7 77 | 78 | self.features = nn.Sequential() 79 | ### Set initial width to 2 x growth_rate[0] 80 | self.num_features = 2 * self.growth[0] 81 | ### Dense-block 1 (224x224) 82 | self.features.add_module('init_conv', nn.Conv2d(3, self.num_features, 83 | kernel_size=3, 84 | stride=self.init_stride, 85 | padding=1, 86 | bias=False)) 87 | for i in range(len(self.stages)): 88 | ### Dense-block i 89 | self.add_block(i) 90 | ### Linear layer 91 | self.classifier = nn.Linear(self.num_features, args.num_classes) 92 | 93 | ### initialize 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | n = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels 97 | m.weight.data.normal_(0, math.sqrt(2. / n)) 98 | elif isinstance(m, nn.BatchNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | elif isinstance(m, nn.Linear): 102 | m.bias.data.zero_() 103 | 104 | def add_block(self, i): 105 | ### Check if ith is the last one 106 | last = (i == len(self.stages) - 1) 107 | block = _DenseBlock( 108 | num_layers=self.stages[i], 109 | in_channels=self.num_features, 110 | growth_rate=self.growth[i], 111 | args=self.args 112 | ) 113 | self.features.add_module('denseblock_%d' % (i + 1), block) 114 | self.num_features += self.stages[i] * self.growth[i] 115 | if not last: 116 | out_features = make_divisible(math.ceil(self.num_features * self.reduction), 117 | self.args.group_1x1) 118 | trans = _Transition(in_channels=self.num_features, 119 | out_channels=out_features, 120 | args=self.args) 121 | self.features.add_module('transition_%d' % (i + 1), trans) 122 | #self.num_features = out_features 123 | else: 124 | self.features.add_module('norm_last', 125 | nn.BatchNorm2d(self.num_features)) 126 | self.features.add_module('relu_last', 127 | nn.ReLU(inplace=True)) 128 | ### Use adaptive ave pool as global pool 129 | self.features.add_module('pool_last', 130 | nn.AvgPool2d(self.pool_size)) 131 | 132 | def forward(self, x, progress=None): 133 | features = self.features(x) 134 | out = features.view(features.size(0), -1) 135 | out = self.classifier(out) 136 | return out 137 | -------------------------------------------------------------------------------- /models/dynamic_densenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | import math 10 | from layers import Conv, DynamicMultiHeadConv 11 | 12 | __all__ = ['Dydensenet'] 
13 | 14 | 15 | class _DenseLayer(nn.Module): 16 | def __init__(self, in_channels, growth_rate, args): 17 | super(_DenseLayer, self).__init__() 18 | ### 1x1 conv: i --> bottleneck * k 19 | self.conv_1 = DynamicMultiHeadConv( 20 | in_channels, args.bottleneck * growth_rate, 21 | kernel_size=1, heads=args.heads, squeeze_rate=args.squeeze_rate, 22 | gate_factor=args.gate_factor) 23 | 24 | ### 3x3 conv: bottleneck * k --> k 25 | self.conv_2 = Conv(args.bottleneck * growth_rate, growth_rate, 26 | kernel_size=3, padding=1, groups=args.group_3x3) 27 | 28 | def forward(self, x): 29 | _lasso_loss = x[1] 30 | x_ = x[0] 31 | x, lasso_loss = self.conv_1(x[0]) 32 | x = self.conv_2(x) 33 | x = torch.cat([x_, x], 1) 34 | _lasso_loss.append(lasso_loss) 35 | return [x, _lasso_loss] 36 | 37 | 38 | class _DenseBlock(nn.Sequential): 39 | def __init__(self, num_layers, in_channels, growth_rate, args): 40 | super(_DenseBlock, self).__init__() 41 | for i in range(num_layers): 42 | layer = _DenseLayer(in_channels + i * growth_rate, growth_rate, args) 43 | self.add_module('denselayer_%d' % (i + 1), layer) 44 | 45 | 46 | class _Transition(nn.Module): 47 | def __init__(self, in_channels, args): 48 | super(_Transition, self).__init__() 49 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2) 50 | 51 | def forward(self, x): 52 | _lasso_loss = x[1] 53 | x_ = x[0] 54 | x = self.pool(x_) 55 | return [x, _lasso_loss] 56 | 57 | class Conv2d_lasso(nn.Conv2d): 58 | def forward(self, x): 59 | x = super(Conv2d_lasso, self).forward(x) 60 | return [x, []] 61 | 62 | class DydenseNet(nn.Module): 63 | def __init__(self, args): 64 | 65 | super(DydenseNet, self).__init__() 66 | 67 | self.stages = args.stages 68 | self.growth = args.growth 69 | assert len(self.stages) == len(self.growth) 70 | self.args = args 71 | self.progress = 0.0 72 | if args.data in ['cifar10', 'cifar100']: 73 | self.init_stride = 1 74 | self.pool_size = 8 75 | else: 76 | self.init_stride = 2 77 | self.pool_size = 7 78 | 79 | self.features = 
nn.Sequential() 80 | ### Initial nChannels should be 3 81 | self.num_features = 2 * self.growth[0] 82 | ### Dense-block 1 (224x224) 83 | self.features.add_module('init_conv', Conv2d_lasso(3, self.num_features, 84 | kernel_size=3, 85 | stride=self.init_stride, 86 | padding=1, 87 | bias=False)) 88 | for i in range(len(self.stages)): 89 | ### Dense-block i 90 | self.add_block(i) 91 | 92 | ### Linear layer 93 | self.bn_last = nn.BatchNorm2d(self.num_features) 94 | self.relu_last = nn.ReLU(inplace=True) 95 | self.pool_last = nn.AvgPool2d(self.pool_size) 96 | self.classifier = nn.Linear(self.num_features, args.num_classes) 97 | self.classifier.bias.data.zero_() 98 | 99 | ### initialize 100 | for m in self.modules(): 101 | if isinstance(m, nn.Conv2d): 102 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 103 | m.weight.data.normal_(0, math.sqrt(2. / n)) 104 | elif isinstance(m, nn.BatchNorm2d): 105 | m.weight.data.fill_(1) 106 | m.bias.data.zero_() 107 | return 108 | 109 | def add_block(self, i): 110 | ### Check if ith is the last one 111 | last = (i == len(self.stages) - 1) 112 | block = _DenseBlock( 113 | num_layers=self.stages[i], 114 | in_channels=self.num_features, 115 | growth_rate=self.growth[i], 116 | args=self.args, 117 | ) 118 | self.features.add_module('denseblock_%d' % (i + 1), block) 119 | self.num_features += self.stages[i] * self.growth[i] 120 | if not last: 121 | trans = _Transition(in_channels=self.num_features, 122 | args=self.args) 123 | self.features.add_module('transition_%d' % (i + 1), trans) 124 | 125 | def forward(self, x, progress=None, threshold=None): 126 | if progress: 127 | DynamicMultiHeadConv.global_progress = progress 128 | features, _lasso_loss = self.features(x) 129 | features = self.bn_last(features) 130 | features = self.relu_last(features) 131 | features = self.pool_last(features) 132 | out = features.view(features.size(0), -1) 133 | out = self.classifier(out) 134 | return out, _lasso_loss 135 | 
-------------------------------------------------------------------------------- /models/dynamic_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified based on Official Pytorch repository 3 | """ 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | from layers import DynamicMultiHeadConv 9 | 10 | __all__ = ['dyresnet18'] 11 | 12 | 13 | 14 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=dilation, groups=groups, bias=False, dilation=dilation) 18 | 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, args, inplanes, planes, stride=1, downsample=None, groups=1, 29 | base_width=64, dilation=1, norm_layer=None): 30 | super(BasicBlock, self).__init__() 31 | if norm_layer is None: 32 | norm_layer = nn.BatchNorm2d 33 | if groups != 1 or base_width != 64: 34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 35 | if dilation > 1: 36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 37 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 38 | self.conv1 = DynamicMultiHeadConv(inplanes, planes, kernel_size=3, stride=stride, 39 | padding=1, heads=args.heads, squeeze_rate=args.squeeze_rate, 40 | gate_factor=args.gate_factor) 41 | self.bn1 = norm_layer(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = DynamicMultiHeadConv(planes, planes, kernel_size=3, stride=1, 44 | padding=1, heads=args.heads, squeeze_rate=args.squeeze_rate, 45 | gate_factor=args.gate_factor) 46 | self.bn2 = norm_layer(planes) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def 
forward(self, x): 51 | _lasso_loss = x[1] 52 | identity = x[0] 53 | 54 | out = self.conv1(x[0]) 55 | _lasso_loss.append(out[1]) 56 | out = self.bn1(out[0]) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | _lasso_loss.append(out[1]) 61 | out = self.bn2(out[0]) 62 | 63 | if self.downsample is not None: 64 | x_down = self.downsample(x[0]) 65 | identity = x_down[0] 66 | _lasso_loss.append(x_down[1]) 67 | 68 | out += identity 69 | out = self.relu(out) 70 | 71 | return [out, _lasso_loss] 72 | 73 | class Norm_after_downsample(nn.Module): 74 | 75 | def __init__(self, norm_layer, planes): 76 | super(Norm_after_downsample, self).__init__() 77 | self.norm = norm_layer(planes) 78 | 79 | def forward(self, x): 80 | _lasso_loss = x[1] 81 | out = self.norm(x[0]) 82 | return [out, _lasso_loss] 83 | 84 | 85 | class ResNet(nn.Module): 86 | 87 | def __init__(self, args, block, layers, num_classes=1000, zero_init_residual=False, 88 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 89 | norm_layer=None): 90 | super(ResNet, self).__init__() 91 | if norm_layer is None: 92 | norm_layer = nn.BatchNorm2d 93 | self._norm_layer = norm_layer 94 | self.args = args 95 | 96 | self.inplanes = 64 97 | self.dilation = 1 98 | if replace_stride_with_dilation is None: 99 | # each element in the tuple indicates if we should replace 100 | # the 2x2 stride with a dilated convolution instead 101 | replace_stride_with_dilation = [False, False, False] 102 | if len(replace_stride_with_dilation) != 3: 103 | raise ValueError("replace_stride_with_dilation should be None " 104 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 105 | self.groups = groups 106 | self.base_width = width_per_group 107 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 108 | bias=False) 109 | self.bn1 = norm_layer(self.inplanes) 110 | self.relu = nn.ReLU(inplace=True) 111 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 112 | self.layer1 = 
self._make_layer(block, 64, layers[0]) 113 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 114 | dilate=replace_stride_with_dilation[0]) 115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 116 | dilate=replace_stride_with_dilation[1]) 117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 118 | dilate=replace_stride_with_dilation[2]) 119 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 120 | self.fc = nn.Linear(512 * block.expansion, num_classes) 121 | 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 125 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 126 | nn.init.constant_(m.weight, 1) 127 | nn.init.constant_(m.bias, 0) 128 | 129 | # Zero-initialize the last BN in each residual branch, 130 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 131 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 132 | if zero_init_residual: 133 | for m in self.modules(): 134 | if isinstance(m, Bottleneck): 135 | nn.init.constant_(m.bn3.weight, 0) 136 | elif isinstance(m, BasicBlock): 137 | nn.init.constant_(m.bn2.weight, 0) 138 | 139 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 140 | norm_layer = self._norm_layer 141 | downsample = None 142 | previous_dilation = self.dilation 143 | if dilate: 144 | self.dilation *= stride 145 | stride = 1 146 | if stride != 1 or self.inplanes != planes * block.expansion: 147 | downsample = nn.Sequential( 148 | DynamicMultiHeadConv(self.inplanes, planes * block.expansion, 149 | kernel_size=1, stride=stride, padding=0, heads=self.args.heads, 150 | squeeze_rate=self.args.squeeze_rate, gate_factor=self.args.gate_factor 151 | ), 152 | Norm_after_downsample(norm_layer, planes * block.expansion), 153 | ) 154 | 155 | layers = [] 156 | layers.append(block(self.args, self.inplanes, planes, stride, 
downsample, self.groups, 157 | self.base_width, previous_dilation, norm_layer)) 158 | self.inplanes = planes * block.expansion 159 | for _ in range(1, blocks): 160 | layers.append(block(self.args, self.inplanes, planes, groups=self.groups, 161 | base_width=self.base_width, dilation=self.dilation, 162 | norm_layer=norm_layer)) 163 | 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, x, progress=None, threshold=None): 167 | if progress: 168 | DynamicMultiHeadConv.global_progress = progress 169 | x = self.conv1(x) 170 | x = self.bn1(x) 171 | x = self.relu(x) 172 | x = self.maxpool(x) 173 | 174 | x = self.layer1([x,[]]) 175 | x = self.layer2(x) 176 | x = self.layer3(x) 177 | x = self.layer4(x) 178 | _lasso_loss = x[1] 179 | 180 | x = self.avgpool(x[0]) 181 | x = torch.flatten(x, 1) 182 | x = self.fc(x) 183 | 184 | return x, _lasso_loss 185 | 186 | def _resnet(args, arch, block, layers, pretrained, progress, **kwargs): 187 | model = ResNet(args, block, layers, **kwargs) 188 | return model 189 | 190 | 191 | def dyresnet18(args): 192 | r"""ResNet-18 model from 193 | `"Deep Residual Learning for Image Recognition" `_ 194 | 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | progress (bool): If True, displays a progress bar of the download to stderr 198 | """ 199 | return _resnet(args, 'resnet18', BasicBlock, [2, 2, 2, 2], pretrained=False, progress=True) 200 | 201 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy from official Pytorch repository 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | #__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | # 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | # 'wide_resnet50_2', 'wide_resnet101_2'] 11 | 12 | __all__ = ['ResNet', 'resnet18'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 
'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 25 | } 26 | 27 | 28 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 29 | """3x3 convolution with padding""" 30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 31 | padding=dilation, groups=groups, bias=False, dilation=dilation) 32 | 33 | 34 | def conv1x1(in_planes, out_planes, stride=1): 35 | """1x1 convolution""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 37 | 38 | 39 | class BasicBlock(nn.Module): 40 | expansion = 1 41 | 42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 43 | base_width=64, dilation=1, norm_layer=None): 44 | super(BasicBlock, self).__init__() 45 | if norm_layer is None: 46 | norm_layer = nn.BatchNorm2d 47 | if groups != 1 or base_width != 64: 48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 49 | if dilation > 1: 50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 52 | self.conv1 = conv3x3(inplanes, planes, stride) 53 | self.bn1 = norm_layer(planes) 54 | self.relu = nn.ReLU(inplace=True) 55 | self.conv2 = conv3x3(planes, planes) 56 | self.bn2 = 
norm_layer(planes) 57 | self.downsample = downsample 58 | self.stride = stride 59 | 60 | def forward(self, x): 61 | identity = x 62 | 63 | out = self.conv1(x) 64 | out = self.bn1(out) 65 | out = self.relu(out) 66 | 67 | out = self.conv2(out) 68 | out = self.bn2(out) 69 | 70 | if self.downsample is not None: 71 | identity = self.downsample(x) 72 | 73 | out += identity 74 | out = self.relu(out) 75 | 76 | return out 77 | 78 | 79 | class Bottleneck(nn.Module): 80 | expansion = 4 81 | 82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 83 | base_width=64, dilation=1, norm_layer=None): 84 | super(Bottleneck, self).__init__() 85 | if norm_layer is None: 86 | norm_layer = nn.BatchNorm2d 87 | width = int(planes * (base_width / 64.)) * groups 88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 89 | self.conv1 = conv1x1(inplanes, width) 90 | self.bn1 = norm_layer(width) 91 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 92 | self.bn2 = norm_layer(width) 93 | self.conv3 = conv1x1(width, planes * self.expansion) 94 | self.bn3 = norm_layer(planes * self.expansion) 95 | self.relu = nn.ReLU(inplace=True) 96 | self.downsample = downsample 97 | self.stride = stride 98 | 99 | def forward(self, x): 100 | identity = x 101 | 102 | out = self.conv1(x) 103 | out = self.bn1(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv2(out) 107 | out = self.bn2(out) 108 | out = self.relu(out) 109 | 110 | out = self.conv3(out) 111 | out = self.bn3(out) 112 | 113 | if self.downsample is not None: 114 | identity = self.downsample(x) 115 | 116 | out += identity 117 | out = self.relu(out) 118 | 119 | return out 120 | 121 | 122 | class ResNet(nn.Module): 123 | 124 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 125 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 126 | norm_layer=None): 127 | super(ResNet, self).__init__() 128 | if norm_layer is None: 129 | norm_layer = 
nn.BatchNorm2d 130 | self._norm_layer = norm_layer 131 | 132 | self.inplanes = 64 133 | self.dilation = 1 134 | if replace_stride_with_dilation is None: 135 | # each element in the tuple indicates if we should replace 136 | # the 2x2 stride with a dilated convolution instead 137 | replace_stride_with_dilation = [False, False, False] 138 | if len(replace_stride_with_dilation) != 3: 139 | raise ValueError("replace_stride_with_dilation should be None " 140 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 141 | self.groups = groups 142 | self.base_width = width_per_group 143 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = norm_layer(self.inplanes) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 150 | dilate=replace_stride_with_dilation[0]) 151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 152 | dilate=replace_stride_with_dilation[1]) 153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 154 | dilate=replace_stride_with_dilation[2]) 155 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 156 | self.fc = nn.Linear(512 * block.expansion, num_classes) 157 | 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 161 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 162 | nn.init.constant_(m.weight, 1) 163 | nn.init.constant_(m.bias, 0) 164 | 165 | # Zero-initialize the last BN in each residual branch, 166 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
167 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 168 | if zero_init_residual: 169 | for m in self.modules(): 170 | if isinstance(m, Bottleneck): 171 | nn.init.constant_(m.bn3.weight, 0) 172 | elif isinstance(m, BasicBlock): 173 | nn.init.constant_(m.bn2.weight, 0) 174 | 175 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 176 | norm_layer = self._norm_layer 177 | downsample = None 178 | previous_dilation = self.dilation 179 | if dilate: 180 | self.dilation *= stride 181 | stride = 1 182 | if stride != 1 or self.inplanes != planes * block.expansion: 183 | downsample = nn.Sequential( 184 | conv1x1(self.inplanes, planes * block.expansion, stride), 185 | norm_layer(planes * block.expansion), 186 | ) 187 | 188 | layers = [] 189 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 190 | self.base_width, previous_dilation, norm_layer)) 191 | self.inplanes = planes * block.expansion 192 | for _ in range(1, blocks): 193 | layers.append(block(self.inplanes, planes, groups=self.groups, 194 | base_width=self.base_width, dilation=self.dilation, 195 | norm_layer=norm_layer)) 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def forward(self, x): 200 | x = self.conv1(x) 201 | x = self.bn1(x) 202 | x = self.relu(x) 203 | x = self.maxpool(x) 204 | 205 | x = self.layer1(x) 206 | x = self.layer2(x) 207 | x = self.layer3(x) 208 | x = self.layer4(x) 209 | 210 | x = self.avgpool(x) 211 | x = torch.flatten(x, 1) 212 | x = self.fc(x) 213 | 214 | return x 215 | 216 | 217 | def _resnet(arch, block, layers, **kwargs): 218 | model = ResNet(block, layers, **kwargs) 219 | return model 220 | 221 | 222 | def resnet18(args): 223 | r"""ResNet-18 model from 224 | `"Deep Residual Learning for Image Recognition" `_ 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | progress (bool): If True, displays a progress bar of the download to stderr 229 | """ 230 | return 
_resnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs) 231 | 232 | 233 | #def resnet34(pretrained=False, progress=True, **kwargs): 234 | # r"""ResNet-34 model from 235 | # `"Deep Residual Learning for Image Recognition" `_ 236 | # 237 | # Args: 238 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 239 | # progress (bool): If True, displays a progress bar of the download to stderr 240 | # """ 241 | # return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 242 | # **kwargs) 243 | # 244 | # 245 | #def resnet50(pretrained=False, progress=True, **kwargs): 246 | # r"""ResNet-50 model from 247 | # `"Deep Residual Learning for Image Recognition" `_ 248 | # 249 | # Args: 250 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | # progress (bool): If True, displays a progress bar of the download to stderr 252 | # """ 253 | # return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 254 | # **kwargs) 255 | # 256 | # 257 | #def resnet101(pretrained=False, progress=True, **kwargs): 258 | # r"""ResNet-101 model from 259 | # `"Deep Residual Learning for Image Recognition" `_ 260 | # 261 | # Args: 262 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 263 | # progress (bool): If True, displays a progress bar of the download to stderr 264 | # """ 265 | # return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 266 | # **kwargs) 267 | # 268 | # 269 | #def resnet152(pretrained=False, progress=True, **kwargs): 270 | # r"""ResNet-152 model from 271 | # `"Deep Residual Learning for Image Recognition" `_ 272 | # 273 | # Args: 274 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | # progress (bool): If True, displays a progress bar of the download to stderr 276 | # """ 277 | # return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 278 | # **kwargs) 279 | # 280 | # 281 | #def resnext50_32x4d(pretrained=False, 
progress=True, **kwargs): 282 | # r"""ResNeXt-50 32x4d model from 283 | # `"Aggregated Residual Transformation for Deep Neural Networks" `_ 284 | # 285 | # Args: 286 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | # progress (bool): If True, displays a progress bar of the download to stderr 288 | # """ 289 | # kwargs['groups'] = 32 290 | # kwargs['width_per_group'] = 4 291 | # return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 292 | # pretrained, progress, **kwargs) 293 | # 294 | # 295 | #def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 296 | # r"""ResNeXt-101 32x8d model from 297 | # `"Aggregated Residual Transformation for Deep Neural Networks" `_ 298 | # 299 | # Args: 300 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 301 | # progress (bool): If True, displays a progress bar of the download to stderr 302 | # """ 303 | # kwargs['groups'] = 32 304 | # kwargs['width_per_group'] = 8 305 | # return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 306 | # pretrained, progress, **kwargs) 307 | # 308 | # 309 | #def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 310 | # r"""Wide ResNet-50-2 model from 311 | # `"Wide Residual Networks" `_ 312 | # 313 | # The model is the same as ResNet except for the bottleneck number of channels 314 | # which is twice larger in every block. The number of channels in outer 1x1 315 | # convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 316 | # channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
317 | # 318 | # Args: 319 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 320 | # progress (bool): If True, displays a progress bar of the download to stderr 321 | # """ 322 | # kwargs['width_per_group'] = 64 * 2 323 | # return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 324 | # pretrained, progress, **kwargs) 325 | # 326 | # 327 | #def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 328 | # r"""Wide ResNet-101-2 model from 329 | # `"Wide Residual Networks" `_ 330 | # 331 | # The model is the same as ResNet except for the bottleneck number of channels 332 | # which is twice larger in every block. The number of channels in outer 1x1 333 | # convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 334 | # channels, and in Wide ResNet-50-2 has 2048-1024-2048. 335 | # 336 | # Args: 337 | # pretrained (bool): If True, returns a model pre-trained on ImageNet 338 | # progress (bool): If True, displays a progress bar of the download to stderr 339 | # """ 340 | # kwargs['width_per_group'] = 64 * 2 341 | # return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 342 | # pretrained, progress, **kwargs) 343 | 344 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import os 7 | import shutil 8 | import math 9 | import time 10 | 11 | from functools import reduce 12 | import operator 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from torch.autograd import Variable 18 | 19 | 20 | ###################################### 21 | # measurement functions # 22 | ###################################### 23 | 24 | count_ops = 0 25 | count_params = 0 26 | 27 | def get_num_gen(gen): 28 | 
return sum(1 for x in gen) 29 | 30 | def is_pruned(layer): 31 | if hasattr(layer, 'mask'): 32 | return True 33 | elif hasattr(layer, 'is_pruned'): 34 | return True 35 | else: 36 | return False 37 | 38 | def is_leaf(model): 39 | return get_num_gen(model.children()) == 0 40 | 41 | 42 | def get_layer_info(layer): 43 | layer_str = str(layer) 44 | type_name = layer_str[:layer_str.find('(')].strip() 45 | return type_name 46 | 47 | 48 | def get_layer_param(model): 49 | return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()]) 50 | 51 | 52 | ### The input batch size should be 1 to call this function 53 | def measure_layer(layer, x): 54 | global count_ops, count_params 55 | delta_ops = 0 56 | delta_params = 0 57 | multi_add = 1 58 | type_name = get_layer_info(layer) 59 | 60 | ### ops_conv 61 | if type_name in ['Conv2d', 'Conv2d_lasso']: 62 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / 63 | layer.stride[0] + 1) 64 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / 65 | layer.stride[1] + 1) 66 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 67 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 68 | delta_params = get_layer_param(layer) 69 | 70 | ### ops_head_conv 71 | elif type_name in ['HeadConv']: 72 | x_ori = x 73 | x = F.adaptive_avg_pool2d(x, 1) 74 | b, c, _, _ = x.size() 75 | x = x.view(b, c) 76 | measure_layer(layer.fc1, x) 77 | x = layer.fc1(x) 78 | measure_layer(layer.relu_fc1, x) 79 | x = layer.relu_fc1(x) 80 | measure_layer(layer.fc2, x) 81 | x = layer.fc2(x) 82 | measure_layer(layer.relu_fc2, x) 83 | delta_ops = reduce(operator.mul, x.size(), 1) 84 | delta_params = 0 85 | 86 | x = x_ori 87 | conv = layer.conv 88 | out_h = int((x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) / 89 | conv.stride[0] + 1) 90 | out_w = int((x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) / 91 | conv.stride[1] + 1) 92 | delta_ops += conv.in_channels * 
conv.out_channels * conv.kernel_size[0] * \ 93 | conv.kernel_size[1] * out_h * out_w * layer.target_pruning_rate * multi_add 94 | delta_params += get_layer_param(conv) 95 | 96 | ### ops_dynamic_conv 97 | elif type_name in ['DynamicMultiHeadConv']: 98 | measure_layer(layer.relu, x) 99 | measure_layer(layer.norm, x) 100 | measure_layer(layer.avg_pool, x) 101 | for i in range(layer.heads): 102 | measure_layer(layer.__getattr__('headconv_%1d' % i), x) 103 | delta_ops = 0 104 | delta_params = 0 105 | 106 | ### ops_nonlinearity 107 | elif type_name in ['ReLU', 'ReLU6', 'Sigmoid']: 108 | delta_ops = x.numel() 109 | delta_params = get_layer_param(layer) 110 | 111 | ### ops_pooling 112 | elif type_name in ['AvgPool2d', 'MaxPool2d']: 113 | in_w = x.size()[2] 114 | kernel_ops = layer.kernel_size * layer.kernel_size 115 | out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 116 | out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 117 | delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops 118 | delta_params = get_layer_param(layer) 119 | 120 | elif type_name in ['AdaptiveAvgPool2d']: 121 | in_w = x.size()[2] 122 | kernel_size = in_w 123 | padding = 0 124 | kernel_ops = kernel_size * kernel_size 125 | out_w = int((in_w + 2 * padding - kernel_size) / 1 + 1) 126 | out_h = int((in_w + 2 * padding - kernel_size) / 1 + 1) 127 | delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops 128 | delta_params = get_layer_param(layer) 129 | 130 | elif type_name in ['AdaptiveAvgPool2d']: 131 | delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] 132 | delta_params = get_layer_param(layer) 133 | 134 | ### ops_linear 135 | elif type_name in ['Linear']: 136 | weight_ops = layer.weight.numel() * multi_add 137 | try: 138 | bias_ops = layer.bias.numel() 139 | except AttributeError: 140 | bias_ops = 0 141 | delta_ops = x.size()[0] * (weight_ops + bias_ops) 142 | delta_params = get_layer_param(layer) 143 | 144 
| ### ops_nothing 145 | elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']: 146 | delta_params = get_layer_param(layer) 147 | 148 | ### unknown layer type 149 | else: 150 | raise TypeError('unknown layer type: %s' % type_name) 151 | 152 | count_ops += delta_ops 153 | count_params += delta_params 154 | return 155 | 156 | 157 | def measure_model(model, H, W): 158 | global count_ops, count_params 159 | count_ops = 0 160 | count_params = 0 161 | data = Variable(torch.zeros(1, 3, H, W)) 162 | 163 | def should_measure(x): 164 | return is_leaf(x) or is_pruned(x) 165 | 166 | def modify_forward(model): 167 | for child in model.children(): 168 | if should_measure(child): 169 | def new_forward(m): 170 | def lambda_forward(x): 171 | measure_layer(m, x) 172 | return m.old_forward(x) 173 | return lambda_forward 174 | child.old_forward = child.forward 175 | child.forward = new_forward(child) 176 | else: 177 | modify_forward(child) 178 | 179 | def restore_forward(model): 180 | for child in model.children(): 181 | # leaf node 182 | if is_leaf(child) and hasattr(child, 'old_forward'): 183 | child.forward = child.old_forward 184 | child.old_forward = None 185 | else: 186 | restore_forward(child) 187 | 188 | modify_forward(model) 189 | model.forward(data) 190 | restore_forward(model) 191 | 192 | return count_ops, count_params 193 | 194 | 195 | 196 | ###################################### 197 | # basic functions # 198 | ###################################### 199 | 200 | 201 | def load_checkpoint(args): 202 | 203 | model_dir = os.path.join(args.savedir, 'save_models') 204 | latest_filename = os.path.join(model_dir, 'latest.txt') 205 | model_filename = '' 206 | 207 | if args.evaluate is not None: 208 | model_filename = args.evaluate 209 | else: 210 | if os.path.exists(latest_filename): 211 | with open(latest_filename, 'r') as fin: 212 | model_filename = fin.readlines()[0].strip() 213 | loadinfo = "=> loading checkpoint from '{}'".format(model_filename) 214 | 
print(loadinfo) 215 | 216 | state = None 217 | if os.path.exists(model_filename): 218 | state = torch.load(model_filename, map_location='cpu') 219 | loadinfo2 = "=> loaded checkpoint '{}' successfully".format(model_filename) 220 | else: 221 | loadinfo2 = "no checkpoint loaded" 222 | print(loadinfo2) 223 | 224 | return state 225 | 226 | 227 | def save_checkpoint(state, epoch, root, is_best, saveID, keep_freq=10): 228 | 229 | filename = 'checkpoint_%03d.pth.tar' % epoch 230 | model_dir = os.path.join(root, 'save_models') 231 | model_filename = os.path.join(model_dir, filename) 232 | latest_filename = os.path.join(model_dir, 'latest.txt') 233 | 234 | if not os.path.exists(model_dir): 235 | os.makedirs(model_dir) 236 | 237 | # write new checkpoint 238 | torch.save(state, model_filename) 239 | with open(latest_filename, 'w') as fout: 240 | fout.write(model_filename) 241 | print("=> saved checkpoint '{}'".format(model_filename)) 242 | 243 | # update best model 244 | if is_best: 245 | best_filename = os.path.join(model_dir, 'model_best.pth.tar') 246 | shutil.copyfile(model_filename, best_filename) 247 | 248 | # remove old model 249 | if saveID is not None and saveID % keep_freq != 0: 250 | filename = 'checkpoint_%03d.pth.tar' % saveID 251 | model_filename = os.path.join(model_dir, filename) 252 | if os.path.exists(model_filename): 253 | os.remove(model_filename) 254 | print('=> removed checkpoint %s' % model_filename) 255 | 256 | print('##########Time##########', time.strftime('%Y-%m-%d %H:%M:%S')) 257 | return epoch 258 | 259 | 260 | class AverageMeter(object): 261 | """Computes and stores the average and current value""" 262 | def __init__(self): 263 | self.reset() 264 | 265 | def reset(self): 266 | self.val = 0 267 | self.avg = 0 268 | self.sum = 0 269 | self.count = 0 270 | 271 | def update(self, val, n=1): 272 | self.val = val 273 | self.sum += val * n 274 | self.count += n 275 | self.avg = self.sum / self.count 276 | 277 | 278 | def adjust_learning_rate(optimizer, 
epoch, args, batch=None, 279 | nBatch=None, method='cosine'): 280 | if method == 'cosine': 281 | T_total = args.epochs * nBatch 282 | T_cur = (epoch % args.epochs) * nBatch + batch 283 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * T_cur / T_total)) 284 | elif method == 'multistep': 285 | if args.data in ['cifar10', 'cifar100']: 286 | lr, decay_rate = args.lr, 0.1 287 | if epoch >= args.epochs * 0.75: 288 | lr *= decay_rate**2 289 | elif epoch >= args.epochs * 0.5: 290 | lr *= decay_rate 291 | else: 292 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 293 | lr = args.lr * (0.1 ** (epoch // 30)) 294 | for param_group in optimizer.param_groups: 295 | param_group['lr'] = lr 296 | return lr 297 | 298 | 299 | def accuracy(output, target, topk=(1,)): 300 | """Computes the precision@k for the specified values of k""" 301 | maxk = max(topk) 302 | batch_size = target.size(0) 303 | 304 | _, pred = output.topk(maxk, 1, True, True) 305 | pred = pred.t() 306 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 307 | 308 | res = [] 309 | for k in topk: 310 | correct_k = correct[:k].view(-1).float().sum(0) 311 | res.append(correct_k.mul_(100.0 / batch_size)) 312 | return res 313 | --------------------------------------------------------------------------------