├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── layers.py
├── main.py
├── models
│   ├── __init__.py
│   ├── densenet.py
│   ├── dynamic_densenet.py
│   ├── dynamic_resnet.py
│   └── resnet.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # result files
104 | results/*
105 | *run*.sh
106 |
107 | .DS_Store
108 | *.swp
109 | results*
110 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Zhuo Su
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic Group Convolution
2 |
3 | This repository contains the PyTorch implementation for
4 | "Dynamic Group Convolution for Accelerating Convolutional Neural Networks"
5 | by
6 | [Zhuo Su](https://zhuogege1943.com/homepage/)\*,
7 | [Linpu Fang](https://dblp.org/pers/hd/f/Fang:Linpu)\*,
8 | [Wenxiong Kang](http://www.scholat.com/auwxkang.en),
9 | [Dewen Hu](https://dblp.org/pers/h/Hu:Dewen.html),
10 | [Matti Pietikäinen](https://en.wikipedia.org/wiki/Matti_Pietik%C3%A4inen_(academic)) and
11 | [Li Liu](http://www.ee.oulu.fi/~lili/LiLiuHomepage.html)
12 | (\* Equal contribution.) \[[arXiv](https://arxiv.org/abs/2007.04242)\]
13 |
14 | The code is based on [CondenseNet](https://github.com/ShichenLiu/CondenseNet).
15 |
16 |
17 | ### Citation
18 |
19 | If you find our project useful in your research, please consider citing:
20 |
21 | ```
22 | @inproceedings{su2020dgc,
23 | title={Dynamic Group Convolution for Accelerating Convolutional Neural Networks},
24 | author={Su, Zhuo and Fang, Linpu and Kang, Wenxiong and Hu, Dewen and Pietik{\"a}inen, Matti and Liu, Li},
25 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
26 | year={2020}
27 | }
28 | ```
29 |
30 | ## Introduction
31 |
32 |
33 | Dynamic Group Convolution (DGC) adaptively selects which input channels
34 | are connected within each group, individually for each sample and on the fly.
35 | Specifically, we equip each group with a small feature
36 | selector that automatically selects the most important input channels
37 | conditioned on the input image. Multiple groups can adaptively capture
38 | abundant and complementary visual/semantic features for each input
39 | image. DGC preserves the original network structure while achieving
40 | computational efficiency similar to that of conventional group convolutions.
41 | Extensive experiments on multiple image classification
42 | benchmarks, including CIFAR-10, CIFAR-100 and ImageNet, demonstrate its
43 | superiority over existing group convolution techniques and dynamic execution methods.
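
As a rough illustration of the per-sample selection, the pure-Python sketch below mirrors the thresholding in `HeadConv.forward` (`layers.py`) for a single sample and a single head. In the repository, the saliency scores are the ReLU output of a small two-layer fully connected selector applied to globally averaged features; here they are simply given. Channels whose score is at or below the largest of the `inactive_channels` smallest scores are zeroed, and the remaining channels are scaled by their scores. `select_channels` and `apply_gate` are hypothetical helpers, and each channel is reduced to a single scalar for brevity. As in the original code, ties at the threshold can deactivate more channels than requested.

```python
from typing import List


def select_channels(saliency: List[float], inactive_channels: int) -> List[float]:
    """Gating mask for one sample and one head (hypothetical helper).

    Assumes 0 <= inactive_channels <= len(saliency). Channels whose saliency
    is <= the largest of the `inactive_channels` smallest scores are zeroed;
    the others keep their saliency as the gate value.
    """
    if inactive_channels <= 0:
        return list(saliency)
    # Threshold = k-th smallest score (equivalent to topk(largest=False).max()).
    threshold = sorted(saliency)[inactive_channels - 1]
    return [0.0 if s <= threshold else s for s in saliency]


def apply_gate(features: List[float], mask: List[float]) -> List[float]:
    """Scale each input channel (one scalar per channel here) by its gate."""
    return [f * m for f, m in zip(features, mask)]


if __name__ == "__main__":
    saliency = [0.9, 0.1, 0.5, 0.0, 1.2, 0.3, 0.7, 0.2]  # one sample, 8 channels
    mask = select_channels(saliency, inactive_channels=6)  # keep 25% of channels
    print(mask)  # [0.9, 0.0, 0.0, 0.0, 1.2, 0.0, 0.0, 0.0]
    print(apply_gate([1.0] * 8, mask))
```

Because the kept channels are multiplied by their saliency rather than by 1, the selector stays in the forward path and receives gradients through the retained channels.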
44 |
45 |
46 |
47 | Figure 1: Overview of a DGC layer.
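
Inside a DGC layer, `DynamicMultiHeadConv.forward` (`layers.py`) concatenates the masked inputs of all heads, runs a single grouped convolution with `groups=heads`, then reshapes the output to `(b, heads, c // heads, h, w)` and transposes the head and channel axes, so that channels from different heads end up interleaved. The hypothetical pure-Python helper below reproduces only that final channel permutation on channel labels, as a sketch to help read the tensor code.

```python
from typing import List, TypeVar

T = TypeVar("T")


def interleave_heads(channels: List[T], heads: int) -> List[T]:
    """Reorder a head-major channel list into interleaved order (hypothetical helper).

    Input order: all channels of head 0, then head 1, ... (as produced by a
    grouped convolution with groups=heads). Output order matches
    view(b, heads, c // heads, h, w).transpose(1, 2).view(b, c, h, w):
    position j * heads + h holds channel j of head h.
    """
    c = len(channels)
    assert c % heads == 0, "channel count must be divisible by the number of heads"
    per_head = c // heads
    return [channels[h * per_head + j] for j in range(per_head) for h in range(heads)]


if __name__ == "__main__":
    chans = ["h0c0", "h0c1", "h0c2", "h1c0", "h1c1", "h1c2"]
    print(interleave_heads(chans, heads=2))
    # ['h0c0', 'h1c0', 'h0c1', 'h1c1', 'h0c2', 'h1c2']
```

This is the same reshape-transpose-flatten pattern as the channel shuffle used in ShuffleNet.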
48 |
49 |
50 | The DGC network can be trained from scratch in an
51 | end-to-end manner, without model pre-training. During backward
52 | propagation in a DGC layer, gradients are calculated
53 | only for the weights connected to channels selected during the forward pass, and
54 | are safely set to 0 for the others thanks to the unbiased gating strategy (refer to the paper).
55 | To avoid abrupt changes in the training loss while pruning,
56 | we gradually deactivate input channels throughout training
57 | with a cosine-shaped learning rate schedule.
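
The gradual deactivation can be made concrete with a small pure-Python sketch of the schedule hard-coded in `DynamicMultiHeadConv.forward` (`layers.py`), where `progress` is the fraction of training iterations completed. Nothing is deactivated up to 1/12 of training, the count then ramps linearly, and from 3/4 onward it stays at `round(in_channels * (1 - gate_factor))`. With `--gate-factor 0.25`, 75% of each head's input channels therefore end up deactivated. The function name is hypothetical, and one simplification is assumed: before 1/12 the original code simply does not update the count (the buffer keeps its previous value, 0 for a freshly built model), which the sketch models as returning 0.

```python
def scheduled_inactive_channels(progress: float, in_channels: int, gate_factor: float) -> int:
    """Deactivated input channels per head at a given training progress (hypothetical name).

    Piecewise rule mirrored from DynamicMultiHeadConv.forward:
      progress <= 1/12         -> 0 (nothing deactivated yet)
      1/12 < progress < 3/4    -> linear ramp towards the target
      progress >= 3/4          -> round(in_channels * (1 - gate_factor))
    """
    target = in_channels * (1 - gate_factor)
    if progress >= 3.0 / 4:
        return round(target)
    if progress > 1.0 / 12:
        # 3/2 * (3/4 - 1/12) == 1, so the ramp reaches the target exactly at 3/4.
        return round(target * 3.0 / 2 * (progress - 1.0 / 12))
    return 0


if __name__ == "__main__":
    # 64 input channels per head, gate factor 0.25 (25% of channels kept at the end).
    for p in (0.0, 0.25, 0.5, 0.75, 1.0):
        print(f"progress={p:.2f} -> inactive={scheduled_inactive_channels(p, 64, 0.25)}")
```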
58 |
59 |
60 |
61 | Figure 2: Training pipeline.
62 |
63 |
64 |
65 | ## Training and Evaluation
66 |
67 | Note: training on ImageNet is currently slow and memory-consuming, because this is still a coarse implementation. The model from the first script below (CondenseNet with DGC on ImageNet) was trained on two V100 GPUs with 32 GB of memory each.
68 |
69 | To train, remove `--evaluate xxx.tar` from the commands below; to evaluate, keep it. The trained models can be downloaded through the links below or from [baidunetdisk](https://pan.baidu.com/s/17BqJ4slwwNxRydj9RBT8yQ) (code: 9dtn).
70 |
71 | (CondenseNet with DGC on ImageNet, pruning rate=0.75, heads=4, ***top-1 error=25.4%, top-5 error=7.8%***)
72 |
73 | Links for `imagenet_dydensenet_h4.tar` (92.3M):
74 | [google drive](https://drive.google.com/file/d/1gKrugAFGLea7kjTa_nmhwVAsinoxze8T/view?usp=sharing),
75 | [onedrive](https://unioulu-my.sharepoint.com/:u:/g/personal/zsu18_univ_yo_oulu_fi/EeU7Lpe2AUBPsONNZYBVv5kBNAy0sdOlj94iuqCdRRkneQ?e=NaZpQF)
76 | ```bash
77 | python main.py --model dydensenet -b 256 -j 4 --data imagenet --datadir /path/to/imagenet \
78 | --epochs 120 --lr-type cosine --stages 4-6-8-10-8 --growth 8-16-32-64-128 --bottleneck 4 \
79 | --heads 4 --group-3x3 4 --gate-factor 0.25 --squeeze-rate 16 --resume --gpu 0 --savedir results/exp \
80 | --evaluate /path/to/imagenet_dydensenet_h4.tar
81 | ```
82 |
83 |
84 | (ResNet-18 with DGC on ImageNet, pruning rate=0.55, heads=4, ***top-1 error=31.22%, top-5 error=11.38%***)
85 |
86 | Links for `imagenet_dyresnet18_h4.tar` (47.2M):
87 | [google drive](https://drive.google.com/file/d/1rtSU3iUKlA0NhgnUJz-QksW5aL2Lt2Cg/view?usp=sharing),
88 | [onedrive](https://unioulu-my.sharepoint.com/:u:/g/personal/zsu18_univ_yo_oulu_fi/EaiXCgT7H7NBmBWObq1lOukBUYaQb5J6DOcD3RHFA4PLLQ?e=myQHRN)
89 | ```bash
90 | python main.py --model dyresnet18 -b 256 -j 4 --data imagenet --datadir /path/to/imagenet \
91 | --epochs 120 --lr-type cosine --heads 4 --gate-factor 0.45 --squeeze-rate 16 --resume \
92 | --gpu 0 --savedir results/exp --evaluate /path/to/imagenet_dyresnet18_h4.tar
93 | ```
94 |
95 | (DenseNet-86 with DGC on CIFAR-10, pruning rate=0.75, heads=4, ***top-1 error=4.77%***)
96 |
97 | Links for `cifar10_dydensenet86_h4.tar` (16.7M):
98 | [google drive](https://drive.google.com/file/d/1o1cVxqa7juDgNRK53dKpfTKEbfMhPSdG/view?usp=sharing),
99 | [onedrive](https://unioulu-my.sharepoint.com/:u:/g/personal/zsu18_univ_yo_oulu_fi/EZ6cmeLZGHdLtIJeFiM-FzYBVPDoaj70wZ1r4yT8X48Ivw?e=YocnXs)
100 | ```bash
101 | python main.py --model dydensenet -b 64 -j 4 --data cifar10 --datadir ../data --epochs 300 \
102 | --lr-type cosine --stages 14-14-14 --growth 8-16-32 --bottleneck 4 --heads 4 --group-3x3 4 \
103 | --gate-factor 0.25 --squeeze-rate 16 --resume --gpu 0 --savedir results/exp \
104 | --evaluate /path/to/cifar10_dydensenet86_h4.tar
105 | ```
106 |
107 |
108 | (DenseNet-86 with DGC on CIFAR-100, pruning rate=0.75, heads=4, ***top-1 error=23.41%***)
109 |
110 | Links for `cifar100_dydensenet86_h4.tar` (17.0M):
111 | [google drive](https://drive.google.com/file/d/1Wne46Znto-uivTV-Evc5RHywUEe7Emyn/view?usp=sharing),
112 | [onedrive](https://unioulu-my.sharepoint.com/:u:/g/personal/zsu18_univ_yo_oulu_fi/EXci72YYC_1CiA7GwOybIw0BJK9rUg48ZXaapPQvHq0Viw?e=ZKVXk9)
113 | ```bash
114 | python main.py --model dydensenet -b 64 -j 4 --data cifar100 --datadir ../data --epochs 300 \
115 | --lr-type cosine --stages 14-14-14 --growth 8-16-32 --bottleneck 4 --heads 4 --group-3x3 4 \
116 | --gate-factor 0.25 --squeeze-rate 16 --resume --gpu 0 --savedir results/exp \
117 | --evaluate /path/to/cifar100_dydensenet86_h4.tar
118 | ```
119 |
120 | ## Other notes
121 |
122 | Discussions and concerns are welcome in the [Issues](https://github.com/zhuogege1943/dgc/issues)!
123 |
124 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hellozhuo/dgc/86befbd7f7b685ab3bbfafcd027ca3551dda48e9/__init__.py
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | class DynamicMultiHeadConv(nn.Module):
11 | global_progress = 0.0
12 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
13 | padding=0, dilation=1, heads=4, squeeze_rate=16, gate_factor=0.25):
14 | super(DynamicMultiHeadConv, self).__init__()
15 | self.norm = nn.BatchNorm2d(in_channels)
16 | self.relu = nn.ReLU(inplace=True)
17 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
18 | self.in_channels = in_channels
19 | self.out_channels = out_channels
20 | self.heads = heads
21 | self.squeeze_rate = squeeze_rate
22 | self.gate_factor = gate_factor
23 | self.stride = stride
24 | self.padding = padding
25 | self.dilation = dilation
26 | self.is_pruned = True
27 | self.register_buffer('_inactive_channels', torch.zeros(1))
28 |
29 | ### Check if arguments are valid
30 | assert self.in_channels % self.heads == 0, \
31 | "input channels must be divisible by the number of heads"
32 | assert self.out_channels % self.heads == 0, \
33 | "output channels must be divisible by the number of heads"
34 | assert self.gate_factor <= 1.0, "gate factor is greater than 1"
35 |
36 | for i in range(self.heads):
37 | self.__setattr__('headconv_%1d' % i,
38 | HeadConv(in_channels, out_channels // self.heads, squeeze_rate,
39 | kernel_size, stride, padding, dilation, 1, gate_factor))
40 |
41 | def forward(self, x):
42 | """
43 | The code here is just a coarse implementation.
44 | The forward process can be quite slow and memory consuming, need to be optimized.
45 | """
46 | if self.training:
47 | progress = DynamicMultiHeadConv.global_progress
48 | # gradually deactivate input channels
49 | if progress < 3.0 / 4 and progress > 1.0 / 12:
50 | self.inactive_channels = round(self.in_channels * (1 - self.gate_factor) * 3.0 / 2 * (progress - 1.0 / 12))
51 | elif progress >= 3.0 / 4:
52 | self.inactive_channels = round(self.in_channels * (1 - self.gate_factor))
53 |
54 | _lasso_loss = 0.0
55 |
56 | x = self.norm(x)
57 | x = self.relu(x)
58 |
59 | x_averaged = self.avg_pool(x)
60 | x_mask = []
61 | weight = []
62 | for i in range(self.heads):
63 | i_x, i_lasso_loss= self.__getattr__('headconv_%1d' % i)(x, x_averaged, self.inactive_channels)
64 | x_mask.append(i_x)
65 | weight.append(self.__getattr__('headconv_%1d' % i).conv.weight)
66 | _lasso_loss = _lasso_loss + i_lasso_loss
67 |
68 | x_mask = torch.cat(x_mask, dim=1) # batch_size, heads x C_in, H, W
69 | weight = torch.cat(weight, dim=0) # C_out (= heads x C_out/heads), C_in, k, k
70 |
71 | out = F.conv2d(x_mask, weight, None, self.stride,
72 | self.padding, self.dilation, self.heads)
73 | b, c, h, w = out.size()
74 | out = out.view(b, self.heads, c // self.heads, h, w)
75 | out = out.transpose(1, 2).contiguous().view(b, c, h, w)
76 | return [out, _lasso_loss]
77 |
78 | @property
79 | def inactive_channels(self):
80 | return int(self._inactive_channels[0])
81 |
82 | @inactive_channels.setter
83 | def inactive_channels(self, val):
84 | self._inactive_channels.fill_(val)
85 |
86 | class HeadConv(nn.Module):
87 | def __init__(self, in_channels, out_channels, squeeze_rate, kernel_size, stride=1,
88 | padding=0, dilation=1, groups=1, gate_factor=0.25):
89 | super(HeadConv, self).__init__()
90 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
91 | padding, dilation, groups=1, bias=False)
92 | self.target_pruning_rate = gate_factor
93 | if in_channels < 80:
94 | squeeze_rate = squeeze_rate // 2
95 | self.fc1 = nn.Linear(in_channels, in_channels // squeeze_rate, bias=False)
96 | self.relu_fc1 = nn.ReLU(inplace=True)
97 | self.fc2 = nn.Linear(in_channels // squeeze_rate, in_channels, bias=True)
98 | self.relu_fc2 = nn.ReLU(inplace=True)
99 |
100 | nn.init.kaiming_normal_(self.fc1.weight)
101 | nn.init.kaiming_normal_(self.fc2.weight)
102 | nn.init.constant_(self.fc2.bias, 1.0)
103 |
104 | def forward(self, x, x_averaged, inactive_channels):
105 | b, c, _, _ = x.size()
106 | x_averaged = x_averaged.view(b, c)
107 | y = self.fc1(x_averaged)
108 | y = self.relu_fc1(y)
109 | y = self.fc2(y)
110 |
111 |
112 | mask = self.relu_fc2(y) # b, c
113 | _lasso_loss = mask.mean()
114 |
115 | mask_d = mask.detach()
116 | mask_c = mask
117 |
118 | if inactive_channels > 0:
119 | mask_c = mask.clone()
120 | topk_smallest, _ = mask_d.topk(inactive_channels, dim=1, largest=False, sorted=False)
121 | clamp_max, _ = topk_smallest.max(dim=1, keepdim=True)
122 | mask_index = mask_d.le(clamp_max)
123 | mask_c[mask_index] = 0
124 |
125 | mask_c = mask_c.view(b, c, 1, 1)
126 | x = x * mask_c.expand_as(x)
127 | return x, _lasso_loss
128 |
129 |
130 | class Conv(nn.Sequential):
131 | def __init__(self, in_channels, out_channels, kernel_size,
132 | stride=1, padding=0, groups=1):
133 | super(Conv, self).__init__()
134 | self.add_module('norm', nn.BatchNorm2d(in_channels))
135 | self.add_module('relu', nn.ReLU(inplace=True))
136 | self.add_module('conv', nn.Conv2d(in_channels, out_channels,
137 | kernel_size=kernel_size,
138 | stride=stride,
139 | padding=padding, bias=False,
140 | groups=groups))
141 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Dynamic Group Convolution
3 | date: July 5th, 2020
4 | authors: Zhuo Su, Linpu Fang
5 | paper: Dynamic Group Convolution for Accelerating Convolutional Neural Networks, ECCV 2020.
6 |
7 | Code forked from "https://github.com/ShichenLiu/CondenseNet"
8 | """
9 |
10 | from __future__ import absolute_import
11 | from __future__ import unicode_literals
12 | from __future__ import print_function
13 | from __future__ import division
14 |
15 | import argparse
16 | import os
17 | import time
18 | import models
19 | from utils import *
20 |
21 | import torch
22 | import torch.nn as nn
23 | import torch.backends.cudnn as cudnn
24 | import torchvision.transforms as transforms
25 | import torchvision.datasets as datasets
26 |
27 | parser = argparse.ArgumentParser(description='PyTorch main code for Dynamic Group Convolution')
28 | parser.add_argument('--data', type=str, default='imagenet',
29 | help='name of dataset',
30 | choices=['cifar10', 'cifar100', 'imagenet'])
31 | parser.add_argument('--datadir', type=str, default='../data',
32 | help='dir to the dataset')
33 | parser.add_argument('--savedir', type=str, default='results/exp',
34 | help='path to save result and checkpoint')
35 |
36 | parser.add_argument('--model', type=str, default='dydensenet',
37 | help='model to train the dataset')
38 | parser.add_argument('-j', '--workers', type=int, default=8,
39 | help='number of data loading workers')
40 | parser.add_argument('--epochs', type=int, default=120,
41 | help='number of total epochs to run')
42 | parser.add_argument('-b', '--batch-size', type=int, default=256,
43 | help='mini-batch size')
44 | parser.add_argument('--lr', '--learning-rate', type=float, default=0.1,
45 | help='initial learning rate')
46 | parser.add_argument('--lr-type', type=str, default='cosine',
47 | help='learning rate strategy',
48 | choices=['cosine', 'multistep'])
49 | parser.add_argument('--group-lasso-lambda', type=float, default=1e-5,
50 | help='group lasso loss weight')
51 | parser.add_argument('--momentum', type=float, default=0.9,
52 | help='momentum for sgd')
53 | parser.add_argument('--weight-decay', '--wd', type=float, default=1e-4,
54 | help='weight decay')
55 | parser.add_argument('--seed', type=int, default=None,
56 | help='manual seed')
57 | parser.add_argument('--gpu', type=str, default='',
58 | help='gpu available')
59 |
60 | parser.add_argument('--stages', type=str,
61 | help='number of dense layers per stage, e.g. 14-14-14')
62 | parser.add_argument('--squeeze-rate', type=int, default=16,
63 | help='squeeze rate in SE head')
64 | parser.add_argument('--heads', type=int, default=4,
65 | help='number of heads for 1x1 convolution')
66 | parser.add_argument('--group-3x3', type=int, default=4,
67 | help='3x3 group convolution')
68 | parser.add_argument('--gate-factor', type=float, default=0.25,
69 | help='gate factor')
70 | parser.add_argument('--growth', type=str,
71 | help='growth rate per stage, e.g. 8-16-32')
72 | parser.add_argument('--bottleneck', type=int, default=4,
73 | help='bottleneck in densenet')
74 |
75 | parser.add_argument('--print-freq', type=int, default=10,
76 | help='print frequency')
77 | parser.add_argument('--save-freq', type=int, default=10,
78 | help='save frequency')
79 | parser.add_argument('--resume', action='store_true',
80 | help='resume from the latest checkpoint if one exists')
81 | parser.add_argument('--evaluate', type=str, default=None,
82 | help="full path to checkpoint to be evaluated")
83 |
84 | args = parser.parse_args()
85 |
86 |
87 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
88 |
89 | best_prec1 = 0
90 |
91 | def main():
92 | global args, best_prec1
93 |
94 | if args.seed is None:
95 | args.seed = int(time.time())
96 | torch.manual_seed(args.seed)
97 | torch.cuda.manual_seed_all(args.seed)
98 |
99 | R = 32
100 | if args.data == 'cifar10':
101 | args.num_classes = 10
102 | elif args.data == 'cifar100':
103 | args.num_classes = 100
104 | else:
105 | args.num_classes = 1000
106 | R = 224
107 |
108 | if 'densenet' in args.model:
109 | args.stages = list(map(int, args.stages.split('-')))
110 | args.growth = list(map(int, args.growth.split('-')))
111 |
112 |
113 | ### Calculate FLOPs & Param
114 | model = getattr(models, args.model)(args)
115 | n_flops, n_params = measure_model(model, R, R)
116 | print('FLOPs: %.2fM, Params: %.2fM' % (n_flops / 1e6, n_params / 1e6))
117 |
118 | os.makedirs(args.savedir, exist_ok=True)
119 | log_file = os.path.join(args.savedir, "%s_%d_%d.txt" % \
120 | (args.model, int(n_params), int(n_flops)))
121 | del(model)
122 |
123 | ### Create model
124 | model = getattr(models, args.model)(args)
125 | model = torch.nn.DataParallel(model).cuda()
126 |
127 | ### Define loss function (criterion) and optimizer
128 | criterion = nn.CrossEntropyLoss().cuda()
129 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
130 | momentum=args.momentum,
131 | weight_decay=args.weight_decay,
132 | nesterov=True)
133 |
134 | cudnn.benchmark = True
135 |
136 | ### Data loading
137 | if args.data == "cifar10":
138 | normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467],
139 | std=[0.2471, 0.2435, 0.2616])
140 | train_set = datasets.CIFAR10(args.datadir, train=True, download=True,
141 | transform=transforms.Compose([
142 | transforms.RandomCrop(32, padding=4),
143 | transforms.RandomHorizontalFlip(),
144 | transforms.ToTensor(),
145 | normalize,
146 | ]))
147 | val_set = datasets.CIFAR10(args.datadir, train=False,
148 | transform=transforms.Compose([
149 | transforms.ToTensor(),
150 | normalize,
151 | ]))
152 | elif args.data == "cifar100":
153 | normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
154 | std=[0.2675, 0.2565, 0.2761])
155 | train_set = datasets.CIFAR100(args.datadir, train=True, download=True,
156 | transform=transforms.Compose([
157 | transforms.RandomCrop(32, padding=4),
158 | transforms.RandomHorizontalFlip(),
159 | transforms.ToTensor(),
160 | normalize,
161 | ]))
162 | val_set = datasets.CIFAR100(args.datadir, train=False,
163 | transform=transforms.Compose([
164 | transforms.ToTensor(),
165 | normalize,
166 | ]))
167 | else: #imagenet
168 | traindir = os.path.join(args.datadir, 'train')
169 | valdir = os.path.join(args.datadir, 'val')
170 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
171 | std=[0.229, 0.224, 0.225])
172 | train_set = datasets.ImageFolder(traindir, transforms.Compose([
173 | transforms.RandomResizedCrop(224),
174 | transforms.RandomHorizontalFlip(),
175 | transforms.ToTensor(),
176 | normalize,
177 | ]))
178 |
179 | val_set = datasets.ImageFolder(valdir, transforms.Compose([
180 | transforms.Resize(256),
181 | transforms.CenterCrop(224),
182 | transforms.ToTensor(),
183 | normalize,
184 | ]))
185 |
186 | train_loader = torch.utils.data.DataLoader(
187 | train_set,
188 | batch_size=args.batch_size, shuffle=True,
189 | num_workers=args.workers, pin_memory=True)
190 |
191 | val_loader = torch.utils.data.DataLoader(
192 | val_set,
193 | batch_size=args.batch_size, shuffle=False,
194 | num_workers=args.workers, pin_memory=True)
195 |
196 | ### Optionally resume from a checkpoint
197 | args.start_epoch = 0
198 | if args.resume or (args.evaluate is not None):
199 | checkpoint = load_checkpoint(args)
200 | if checkpoint is not None:
201 | model.load_state_dict(checkpoint['state_dict'])
202 | try:
203 | args.start_epoch = checkpoint['epoch'] + 1
204 | best_prec1 = checkpoint['best_prec1']
205 | optimizer.load_state_dict(checkpoint['optimizer'])
206 | except KeyError:
207 | pass
208 |
209 | ### Evaluate directly if required
210 | print(args)
211 | if args.evaluate is not None:
212 | validate(val_loader, model, criterion, args)
213 | return
214 |
215 | saveID = None
216 | for epoch in range(args.start_epoch, args.epochs):
217 | ### Train for one epoch
218 | tr_prec1, tr_prec5, loss, lr = \
219 | train(train_loader, model, criterion, optimizer, epoch, args)
220 |
221 | ### Evaluate on validation set
222 | val_prec1, val_prec5 = validate(val_loader, model, criterion, args)
223 |
224 | ### Remember best prec@1 and save checkpoint
225 | is_best = val_prec1 >= best_prec1
226 | best_prec1 = max(val_prec1, best_prec1)
227 |
228 | log = ("Epoch %03d/%03d: top1 %.4f | top5 %.4f" + \
229 | " | train-top1 %.4f | train-top5 %.4f | loss %.4f | lr %.5f | Time %s\n") \
230 | % (epoch, args.epochs, val_prec1, val_prec5, tr_prec1, \
231 | tr_prec5, loss, lr, time.strftime('%Y-%m-%d %H:%M:%S'))
232 | with open(log_file, 'a') as f:
233 | f.write(log)
234 |
235 | saveID = save_checkpoint({
236 | 'epoch': epoch,
237 | 'state_dict': model.state_dict(),
238 | 'best_prec1': best_prec1,
239 | 'optimizer': optimizer.state_dict(),
240 | }, epoch, args.savedir, is_best,
241 | saveID, keep_freq=args.save_freq)
242 |
243 | return
244 |
245 |
246 | def train(train_loader, model, criterion, optimizer, epoch, args):
247 | batch_time = AverageMeter()
248 | data_time = AverageMeter()
249 | losses = AverageMeter()
250 | lasso_losses = AverageMeter()
251 | top1 = AverageMeter()
252 | top5 = AverageMeter()
253 |
254 | ### Switch to train mode
255 | model.train()
256 | wD = len(str(len(train_loader)))
257 | wE = len(str(args.epochs))
258 |
259 | end = time.time()
260 | for i, (input, target) in enumerate(train_loader):
261 |
262 | progress = float(epoch * len(train_loader) + i) / \
263 | (args.epochs * len(train_loader))
264 | ## Adjust learning rate
265 | lr = adjust_learning_rate(optimizer, epoch, args, batch=i,
266 | nBatch=len(train_loader), method=args.lr_type)
267 |
268 | ## Measure data loading time
269 | data_time.update(time.time() - end)
270 |
271 | input = input.cuda(non_blocking=True)
272 | target = target.cuda(non_blocking=True)
273 |
274 | ## Compute output
275 | output, _lasso_list = model(input, progress)
276 | loss = criterion(output, target)
277 |
278 | ## Add group lasso loss
279 | lasso_loss = 0
280 | if args.group_lasso_lambda > 0:
281 | for lasso_m in _lasso_list:
282 | lasso_loss = lasso_loss + lasso_m.mean()
283 | loss = loss + args.group_lasso_lambda * lasso_loss
284 |
285 | ## Measure accuracy and record loss
286 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
287 | losses.update(loss.item(), input.size(0))
288 | lasso_losses.update(lasso_loss.item() if torch.is_tensor(lasso_loss) else lasso_loss)
289 | top1.update(prec1.item(), input.size(0))
290 | top5.update(prec5.item(), input.size(0))
291 |
292 | ## Compute gradient and do SGD step
293 | optimizer.zero_grad()
294 | loss.backward()
295 | optimizer.step()
296 |
297 | ## Measure elapsed time
298 | batch_time.update(time.time() - end)
299 | end = time.time()
300 |
301 | ## Record
302 | if i % args.print_freq == 0:
303 | print(('Epoch: [{0}/{1}][{2}/{3}]\t' + \
304 | 'Time {batch_time.val:.3f}\t' + \
305 | 'Data {data_time.val:.3f}\t' + \
306 | 'Loss (lasso_loss) {loss.val:.4f} ({lasso_loss.val:.4f})\t' + \
307 | 'Prec@1 {top1.val:.3f}\t' + \
308 | 'Prec@5 {top5.val:.3f}\t' + \
309 | 'lr {lr: .5f}\t').format(
310 | epoch, args.epochs, i, len(train_loader), batch_time=batch_time,
311 | data_time=data_time, loss=losses, lasso_loss=lasso_losses,
312 | top1=top1, top5=top5, lr=lr))
313 |
314 | return top1.avg, top5.avg, losses.avg, lr
315 |
316 |
317 | def validate(val_loader, model, criterion, args):
318 | batch_time = AverageMeter()
319 | losses = AverageMeter()
320 | top1 = AverageMeter()
321 | top5 = AverageMeter()
322 |
323 | ## Switch to evaluate mode
324 | model.eval()
325 |
326 | end = time.time()
327 | for i, (input, target) in enumerate(val_loader):
328 | ## Compute output
329 | with torch.no_grad():
330 | target = target.cuda(non_blocking=True)
331 | input = input.cuda(non_blocking=True)
332 |
333 | output, _ = model(input)
334 | loss = criterion(output, target)
335 |
336 | ## Measure accuracy and record loss
337 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
338 | losses.update(loss.data.item(), input.size(0))
339 | top1.update(prec1.item(), input.size(0))
340 | top5.update(prec5.item(), input.size(0))
341 |
342 | ## Measure elapsed time
343 | batch_time.update(time.time() - end)
344 | end = time.time()
345 |
346 | if i % args.print_freq == 0:
347 | print('Test: [{0}/{1}]\t'
348 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
349 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
350 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
351 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
352 | i, len(val_loader), batch_time=batch_time, loss=losses,
353 | top1=top1, top5=top5))
354 |
355 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
356 | .format(top1=top1, top5=top5))
357 |
358 | return top1.avg, top5.avg
359 |
360 |
361 |
362 | if __name__ == '__main__':
363 | main()
364 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .densenet import DenseNet as densenet
2 | from .resnet import resnet18
3 |
4 |
5 | from .dynamic_densenet import DydenseNet as dydensenet
6 | from .dynamic_resnet import dyresnet18
7 |
--------------------------------------------------------------------------------
/models/densenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.autograd import Variable
10 | import math
11 | from layers import Conv
12 |
13 | __all__ = ['DenseNet']
14 |
15 |
16 | def make_divisible(x, y):
17 | return int((x // y + 1) * y) if x % y else int(x)
18 |
19 |
20 | class _DenseLayer(nn.Module):
21 | def __init__(self, in_channels, growth_rate, args):
22 | super(_DenseLayer, self).__init__()
23 | self.group_1x1 = args.group_1x1
24 | self.group_3x3 = args.group_3x3
25 | ### 1x1 conv i --> b*k
26 | self.conv_1 = Conv(in_channels, args.bottleneck * growth_rate,
27 | kernel_size=1, groups=self.group_1x1)
28 | ### 3x3 conv b*k --> k
29 | self.conv_2 = Conv(args.bottleneck * growth_rate, growth_rate,
30 | kernel_size=3, padding=1, groups=self.group_3x3)
31 |
32 | def forward(self, x):
33 | x_ = x
34 | x = self.conv_1(x)
35 | x = self.conv_2(x)
36 | return torch.cat([x_, x], 1)
37 |
38 |
39 | class _DenseBlock(nn.Sequential):
40 | def __init__(self, num_layers, in_channels, growth_rate, args):
41 | super(_DenseBlock, self).__init__()
42 | for i in range(num_layers):
43 | layer = _DenseLayer(in_channels + i * growth_rate, growth_rate, args)
44 | self.add_module('denselayer_%d' % (i + 1), layer)
45 |
46 |
47 | class _Transition(nn.Module):
48 | def __init__(self, in_channels, out_channels, args):
49 | super(_Transition, self).__init__()
50 | #self.conv = Conv(in_channels, out_channels,
51 | # kernel_size=1, groups=args.group_1x1)
52 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
53 |
54 | def forward(self, x):
55 | #x = self.conv(x)
56 | x = self.pool(x)
57 | return x
58 |
59 |
60 | class DenseNet(nn.Module):
61 | def __init__(self, args):
62 |
63 | super(DenseNet, self).__init__()
64 |
65 | self.stages = args.stages
66 | self.growth = args.growth
67 | self.reduction = args.reduction
68 | assert len(self.stages) == len(self.growth)
69 | self.args = args
70 | self.progress = 0.0
71 | if args.data in ['cifar10', 'cifar100']:
72 | self.init_stride = 1
73 | self.pool_size = 8
74 | else:
75 | self.init_stride = 2
76 | self.pool_size = 7
77 |
78 | self.features = nn.Sequential()
79 | ### Set initial width to 2 x growth_rate[0]
80 | self.num_features = 2 * self.growth[0]
81 | ### Dense-block 1 (224x224)
82 | self.features.add_module('init_conv', nn.Conv2d(3, self.num_features,
83 | kernel_size=3,
84 | stride=self.init_stride,
85 | padding=1,
86 | bias=False))
87 | for i in range(len(self.stages)):
88 | ### Dense-block i
89 | self.add_block(i)
90 | ### Linear layer
91 | self.classifier = nn.Linear(self.num_features, args.num_classes)
92 |
93 | ### initialize
94 | for m in self.modules():
95 | if isinstance(m, nn.Conv2d):
96 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
97 | m.weight.data.normal_(0, math.sqrt(2. / n))
98 | elif isinstance(m, nn.BatchNorm2d):
99 | m.weight.data.fill_(1)
100 | m.bias.data.zero_()
101 | elif isinstance(m, nn.Linear):
102 | m.bias.data.zero_()
103 |
104 | def add_block(self, i):
105 | ### Check if ith is the last one
106 | last = (i == len(self.stages) - 1)
107 | block = _DenseBlock(
108 | num_layers=self.stages[i],
109 | in_channels=self.num_features,
110 | growth_rate=self.growth[i],
111 | args=self.args
112 | )
113 | self.features.add_module('denseblock_%d' % (i + 1), block)
114 | self.num_features += self.stages[i] * self.growth[i]
115 | if not last:
116 | out_features = make_divisible(math.ceil(self.num_features * self.reduction),
117 | self.args.group_1x1)
118 | trans = _Transition(in_channels=self.num_features,
119 | out_channels=out_features,
120 | args=self.args)
121 | self.features.add_module('transition_%d' % (i + 1), trans)
122 | #self.num_features = out_features
123 | else:
124 | self.features.add_module('norm_last',
125 | nn.BatchNorm2d(self.num_features))
126 | self.features.add_module('relu_last',
127 | nn.ReLU(inplace=True))
128 | ### Use adaptive ave pool as global pool
129 | self.features.add_module('pool_last',
130 | nn.AvgPool2d(self.pool_size))
131 |
132 | def forward(self, x, progress=None):
133 | features = self.features(x)
134 | out = features.view(features.size(0), -1)
135 | out = self.classifier(out)
136 | return out
137 |
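In `DenseNet.__init__`/`add_block` the classifier width follows from simple bookkeeping. The stem outputs `2 * growth[0]` channels, each `_DenseLayer` concatenates `growth[i]` new maps, and `_Transition` only average-pools (its 1x1 conv and the `self.num_features = out_features` update are commented out), so the width never shrinks. The plain-Python sketch below reproduces that arithmetic and the spatial size that reaches `pool_last`. The stage/growth values and the five-stage ImageNet case are hypothetical, not taken from the repository's configs.

```python
from typing import List, Sequence, Tuple


def densenet_channels(stages: Sequence[int], growth: Sequence[int]) -> Tuple[List[int], int]:
    """Mirror DenseNet's channel bookkeeping.

    Returns the input width of every dense block and the classifier's in_features.
    """
    assert len(stages) == len(growth)
    num_features = 2 * growth[0]           # width of init_conv's output
    block_inputs = []
    for n_layers, k in zip(stages, growth):
        block_inputs.append(num_features)
        num_features += n_layers * k       # each _DenseLayer concatenates k new maps
        # _Transition only pools here, so the width is carried over unchanged
    return block_inputs, num_features


def final_map_size(input_size: int, init_stride: int, num_stages: int) -> int:
    """Spatial size reaching pool_last: 3x3/pad-1 stem, then one 2x2/stride-2 pool per transition."""
    size = (input_size - 1) // init_stride + 1   # floor((H + 2 - 3) / s) + 1
    for _ in range(num_stages - 1):              # transitions sit between stages
        size //= 2
    return size


inputs, width = densenet_channels([14, 14, 14], [8, 16, 32])
# inputs == [16, 128, 352], width == 800 (the nn.Linear in_features)
cifar = final_map_size(32, 1, 3)      # 8, matching pool_size for CIFAR
imagenet = final_map_size(224, 2, 5)  # 7, matching pool_size for ImageNet
```

If the final map is not equal to `pool_size`, `features.view(features.size(0), -1)` yields more than `num_features` values per sample and the `nn.Linear` call fails, so this check is worth running for any new stage configuration.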
--------------------------------------------------------------------------------
/models/dynamic_densenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | import math
10 | from layers import Conv, DynamicMultiHeadConv
11 |
12 | __all__ = ['DydenseNet']
13 |
14 |
15 | class _DenseLayer(nn.Module):
16 | def __init__(self, in_channels, growth_rate, args):
17 | super(_DenseLayer, self).__init__()
18 | ### 1x1 conv: i --> bottleneck * k
19 | self.conv_1 = DynamicMultiHeadConv(
20 | in_channels, args.bottleneck * growth_rate,
21 | kernel_size=1, heads=args.heads, squeeze_rate=args.squeeze_rate,
22 | gate_factor=args.gate_factor)
23 |
24 | ### 3x3 conv: bottleneck * k --> k
25 | self.conv_2 = Conv(args.bottleneck * growth_rate, growth_rate,
26 | kernel_size=3, padding=1, groups=args.group_3x3)
27 |
28 | def forward(self, x):
29 | _lasso_loss = x[1]
30 | x_ = x[0]
31 | x, lasso_loss = self.conv_1(x[0])
32 | x = self.conv_2(x)
33 | x = torch.cat([x_, x], 1)
34 | _lasso_loss.append(lasso_loss)
35 | return [x, _lasso_loss]
36 |
37 |
38 | class _DenseBlock(nn.Sequential):
39 | def __init__(self, num_layers, in_channels, growth_rate, args):
40 | super(_DenseBlock, self).__init__()
41 | for i in range(num_layers):
42 | layer = _DenseLayer(in_channels + i * growth_rate, growth_rate, args)
43 | self.add_module('denselayer_%d' % (i + 1), layer)
44 |
45 |
46 | class _Transition(nn.Module):
47 | def __init__(self, in_channels, args):
48 | super(_Transition, self).__init__()
49 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
50 |
51 | def forward(self, x):
52 | _lasso_loss = x[1]
53 | x_ = x[0]
54 | x = self.pool(x_)
55 | return [x, _lasso_loss]
56 |
57 | class Conv2d_lasso(nn.Conv2d):
58 | def forward(self, x):
59 | x = super(Conv2d_lasso, self).forward(x)
60 | return [x, []]
61 |
62 | class DydenseNet(nn.Module):
63 | def __init__(self, args):
64 |
65 | super(DydenseNet, self).__init__()
66 |
67 | self.stages = args.stages
68 | self.growth = args.growth
69 | assert len(self.stages) == len(self.growth)
70 | self.args = args
71 | self.progress = 0.0
72 | if args.data in ['cifar10', 'cifar100']:
73 | self.init_stride = 1
74 | self.pool_size = 8
75 | else:
76 | self.init_stride = 2
77 | self.pool_size = 7
78 |
79 | self.features = nn.Sequential()
80 |         ### Set initial width to 2 x growth_rate[0] (the input image has 3 channels)
81 |         self.num_features = 2 * self.growth[0]
82 |         ### Stem: 3x3 conv (stride 1 for CIFAR, stride 2 otherwise)
83 | self.features.add_module('init_conv', Conv2d_lasso(3, self.num_features,
84 | kernel_size=3,
85 | stride=self.init_stride,
86 | padding=1,
87 | bias=False))
88 | for i in range(len(self.stages)):
89 | ### Dense-block i
90 | self.add_block(i)
91 |
92 |         ### Final BN-ReLU, global average pooling and linear classifier
93 | self.bn_last = nn.BatchNorm2d(self.num_features)
94 | self.relu_last = nn.ReLU(inplace=True)
95 | self.pool_last = nn.AvgPool2d(self.pool_size)
96 | self.classifier = nn.Linear(self.num_features, args.num_classes)
97 | self.classifier.bias.data.zero_()
98 |
99 | ### initialize
100 | for m in self.modules():
101 | if isinstance(m, nn.Conv2d):
102 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
103 | m.weight.data.normal_(0, math.sqrt(2. / n))
104 | elif isinstance(m, nn.BatchNorm2d):
105 | m.weight.data.fill_(1)
106 | m.bias.data.zero_()
107 | return
108 |
109 | def add_block(self, i):
110 | ### Check if ith is the last one
111 | last = (i == len(self.stages) - 1)
112 | block = _DenseBlock(
113 | num_layers=self.stages[i],
114 | in_channels=self.num_features,
115 | growth_rate=self.growth[i],
116 | args=self.args,
117 | )
118 | self.features.add_module('denseblock_%d' % (i + 1), block)
119 | self.num_features += self.stages[i] * self.growth[i]
120 | if not last:
121 | trans = _Transition(in_channels=self.num_features,
122 | args=self.args)
123 | self.features.add_module('transition_%d' % (i + 1), trans)
124 |
125 | def forward(self, x, progress=None, threshold=None):
126 |         if progress is not None:
127 | DynamicMultiHeadConv.global_progress = progress
128 | features, _lasso_loss = self.features(x)
129 | features = self.bn_last(features)
130 | features = self.relu_last(features)
131 | features = self.pool_last(features)
132 | out = features.view(features.size(0), -1)
133 | out = self.classifier(out)
134 | return out, _lasso_loss
135 |
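Every module in `self.features` receives and returns a `[tensor, list]` pair: `Conv2d_lasso` opens a fresh list, each `_DenseLayer` appends the lasso term produced by its `DynamicMultiHeadConv`, and `_Transition` passes the list through untouched. The plain-Python sketch below mimics that protocol with floats standing in for tensors. The stand-in functions and penalty values are hypothetical, and because the training script that consumes `_lasso_loss` is not part of this file, summing the terms is only an assumption about how they are combined.

```python
from typing import Callable, List, Sequence, Tuple

# Mirrors the [x, lasso_losses] pair threaded through nn.Sequential
State = Tuple[float, List[float]]


def start(x: float) -> State:
    # Like Conv2d_lasso.forward: compute the output and open a fresh, empty loss list
    return (x * 2.0, [])


def dynamic_layer(penalty: float) -> Callable[[State], State]:
    # Like _DenseLayer.forward: produce a new output and append this layer's term
    def forward(state: State) -> State:
        x, losses = state
        losses.append(penalty)  # the list is shared and mutated in place
        return (x + 1.0, losses)
    return forward


def transition(state: State) -> State:
    # Like _Transition.forward: transform x, forward the loss list unchanged
    x, losses = state
    return (x / 2.0, losses)


def run(x: float, modules: Sequence[Callable[[State], State]]) -> State:
    state = start(x)
    for m in modules:
        state = m(state)
    return state


net = [dynamic_layer(0.1), dynamic_layer(0.2), transition, dynamic_layer(0.3)]
out, losses = run(1.0, net)
total_lasso = sum(losses)  # assumption: terms are summed before weighting
```

Because the list is created inside the first module on every call, each forward pass returns exactly one term per dynamic layer instead of accumulating terms across iterations.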
--------------------------------------------------------------------------------
/models/dynamic_resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from the official PyTorch (torchvision) ResNet implementation
3 | """
4 |
5 |
6 | import torch
7 | import torch.nn as nn
8 | from layers import DynamicMultiHeadConv
9 |
10 | __all__ = ['dyresnet18']
11 |
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=dilation, groups=groups, bias=False, dilation=dilation)
18 |
19 |
20 | def conv1x1(in_planes, out_planes, stride=1):
21 | """1x1 convolution"""
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, args, inplanes, planes, stride=1, downsample=None, groups=1,
29 | base_width=64, dilation=1, norm_layer=None):
30 | super(BasicBlock, self).__init__()
31 | if norm_layer is None:
32 | norm_layer = nn.BatchNorm2d
33 | if groups != 1 or base_width != 64:
34 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
35 | if dilation > 1:
36 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
37 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
38 | self.conv1 = DynamicMultiHeadConv(inplanes, planes, kernel_size=3, stride=stride,
39 | padding=1, heads=args.heads, squeeze_rate=args.squeeze_rate,
40 | gate_factor=args.gate_factor)
41 | self.bn1 = norm_layer(planes)
42 | self.relu = nn.ReLU(inplace=True)
43 | self.conv2 = DynamicMultiHeadConv(planes, planes, kernel_size=3, stride=1,
44 | padding=1, heads=args.heads, squeeze_rate=args.squeeze_rate,
45 | gate_factor=args.gate_factor)
46 | self.bn2 = norm_layer(planes)
47 | self.downsample = downsample
48 | self.stride = stride
49 |
50 | def forward(self, x):
51 | _lasso_loss = x[1]
52 | identity = x[0]
53 |
54 | out = self.conv1(x[0])
55 | _lasso_loss.append(out[1])
56 | out = self.bn1(out[0])
57 | out = self.relu(out)
58 |
59 | out = self.conv2(out)
60 | _lasso_loss.append(out[1])
61 | out = self.bn2(out[0])
62 |
63 | if self.downsample is not None:
64 | x_down = self.downsample(x[0])
65 | identity = x_down[0]
66 | _lasso_loss.append(x_down[1])
67 |
68 | out += identity
69 | out = self.relu(out)
70 |
71 | return [out, _lasso_loss]
72 |
73 | class Norm_after_downsample(nn.Module):
74 |
75 | def __init__(self, norm_layer, planes):
76 | super(Norm_after_downsample, self).__init__()
77 | self.norm = norm_layer(planes)
78 |
79 | def forward(self, x):
80 | _lasso_loss = x[1]
81 | out = self.norm(x[0])
82 | return [out, _lasso_loss]
83 |
84 |
85 | class ResNet(nn.Module):
86 |
87 | def __init__(self, args, block, layers, num_classes=1000, zero_init_residual=False,
88 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
89 | norm_layer=None):
90 | super(ResNet, self).__init__()
91 | if norm_layer is None:
92 | norm_layer = nn.BatchNorm2d
93 | self._norm_layer = norm_layer
94 | self.args = args
95 |
96 | self.inplanes = 64
97 | self.dilation = 1
98 | if replace_stride_with_dilation is None:
99 | # each element in the tuple indicates if we should replace
100 | # the 2x2 stride with a dilated convolution instead
101 | replace_stride_with_dilation = [False, False, False]
102 | if len(replace_stride_with_dilation) != 3:
103 | raise ValueError("replace_stride_with_dilation should be None "
104 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
105 | self.groups = groups
106 | self.base_width = width_per_group
107 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
108 | bias=False)
109 | self.bn1 = norm_layer(self.inplanes)
110 | self.relu = nn.ReLU(inplace=True)
111 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
112 | self.layer1 = self._make_layer(block, 64, layers[0])
113 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
114 | dilate=replace_stride_with_dilation[0])
115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
116 | dilate=replace_stride_with_dilation[1])
117 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
118 | dilate=replace_stride_with_dilation[2])
119 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
120 | self.fc = nn.Linear(512 * block.expansion, num_classes)
121 |
122 | for m in self.modules():
123 | if isinstance(m, nn.Conv2d):
124 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
125 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
126 | nn.init.constant_(m.weight, 1)
127 | nn.init.constant_(m.bias, 0)
128 |
129 | # Zero-initialize the last BN in each residual branch,
130 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
131 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
132 |         if zero_init_residual:
133 |             # Bottleneck is not defined in this file, so only BasicBlock
134 |             # residual branches are zero-initialized here.
135 |             for m in self.modules():
136 |                 if isinstance(m, BasicBlock):
137 |                     nn.init.constant_(m.bn2.weight, 0)
138 |
139 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
140 | norm_layer = self._norm_layer
141 | downsample = None
142 | previous_dilation = self.dilation
143 | if dilate:
144 | self.dilation *= stride
145 | stride = 1
146 | if stride != 1 or self.inplanes != planes * block.expansion:
147 | downsample = nn.Sequential(
148 | DynamicMultiHeadConv(self.inplanes, planes * block.expansion,
149 | kernel_size=1, stride=stride, padding=0, heads=self.args.heads,
150 | squeeze_rate=self.args.squeeze_rate, gate_factor=self.args.gate_factor
151 | ),
152 | Norm_after_downsample(norm_layer, planes * block.expansion),
153 | )
154 |
155 | layers = []
156 | layers.append(block(self.args, self.inplanes, planes, stride, downsample, self.groups,
157 | self.base_width, previous_dilation, norm_layer))
158 | self.inplanes = planes * block.expansion
159 | for _ in range(1, blocks):
160 | layers.append(block(self.args, self.inplanes, planes, groups=self.groups,
161 | base_width=self.base_width, dilation=self.dilation,
162 | norm_layer=norm_layer))
163 |
164 | return nn.Sequential(*layers)
165 |
166 | def forward(self, x, progress=None, threshold=None):
167 |         if progress is not None:
168 | DynamicMultiHeadConv.global_progress = progress
169 | x = self.conv1(x)
170 | x = self.bn1(x)
171 | x = self.relu(x)
172 | x = self.maxpool(x)
173 |
174 | x = self.layer1([x,[]])
175 | x = self.layer2(x)
176 | x = self.layer3(x)
177 | x = self.layer4(x)
178 | _lasso_loss = x[1]
179 |
180 | x = self.avgpool(x[0])
181 | x = torch.flatten(x, 1)
182 | x = self.fc(x)
183 |
184 | return x, _lasso_loss
185 |
186 | def _resnet(args, arch, block, layers, pretrained, progress, **kwargs):
187 | model = ResNet(args, block, layers, **kwargs)
188 | return model
189 |
190 |
191 | def dyresnet18(args):
192 |     r"""Dynamic ResNet-18 based on the model from
193 |     "Deep Residual Learning for Image Recognition" (He et al., 2016).
194 |
195 |     Args:
196 |         args: parsed options; must provide ``heads``, ``squeeze_rate`` and
197 |             ``gate_factor``. Pretrained weights are not supported.
198 |     """
199 | return _resnet(args, 'resnet18', BasicBlock, [2, 2, 2, 2], pretrained=False, progress=True)
200 |
201 |
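`ResNet.forward` seeds `layer1` with `[x, []]`, and every `BasicBlock` appends one lasso term per `DynamicMultiHeadConv` it runs: `conv1`, `conv2` and, in blocks that have a `downsample`, the dynamic 1x1 conv inside it (`Norm_after_downsample` forwards that conv's term). The sketch below reproduces `_make_layer`'s downsample rule in plain Python to predict how many terms `dyresnet18` returns per forward pass. It assumes the default `replace_stride_with_dilation` (no dilation), and the helper name `plan_resnet` is hypothetical.

```python
from typing import List, Sequence, Tuple

EXPANSION = 1  # BasicBlock.expansion


def plan_resnet(layers: Sequence[int],
                widths: Sequence[int] = (64, 128, 256, 512)) -> Tuple[List[bool], int]:
    """Return per-stage 'first block has downsample' flags and the lasso-term count."""
    inplanes = 64  # channels after the stem (conv1 + maxpool)
    has_downsample = []
    entries = 0
    for stage, (planes, blocks) in enumerate(zip(widths, layers)):
        stride = 1 if stage == 0 else 2  # layer1 keeps resolution, later stages halve it
        needs_ds = stride != 1 or inplanes != planes * EXPANSION
        has_downsample.append(needs_ds)
        # first block: conv1 + conv2, plus the dynamic 1x1 conv in downsample
        entries += 2 + (1 if needs_ds else 0)
        inplanes = planes * EXPANSION
        # remaining blocks: conv1 + conv2 each, never a downsample
        entries += 2 * (blocks - 1)
    return has_downsample, entries


flags, n_terms = plan_resnet([2, 2, 2, 2])
# flags == [False, True, True, True], n_terms == 19
```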
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from the official PyTorch (torchvision) ResNet implementation
3 | """
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | #__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9 | # 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
10 | # 'wide_resnet50_2', 'wide_resnet101_2']
11 |
12 | __all__ = ['ResNet', 'resnet18']
13 |
14 |
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
25 | }
26 |
27 |
28 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
29 | """3x3 convolution with padding"""
30 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
31 | padding=dilation, groups=groups, bias=False, dilation=dilation)
32 |
33 |
34 | def conv1x1(in_planes, out_planes, stride=1):
35 | """1x1 convolution"""
36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
37 |
38 |
39 | class BasicBlock(nn.Module):
40 | expansion = 1
41 |
42 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
43 | base_width=64, dilation=1, norm_layer=None):
44 | super(BasicBlock, self).__init__()
45 | if norm_layer is None:
46 | norm_layer = nn.BatchNorm2d
47 | if groups != 1 or base_width != 64:
48 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
49 | if dilation > 1:
50 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
51 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
52 | self.conv1 = conv3x3(inplanes, planes, stride)
53 | self.bn1 = norm_layer(planes)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.conv2 = conv3x3(planes, planes)
56 | self.bn2 = norm_layer(planes)
57 | self.downsample = downsample
58 | self.stride = stride
59 |
60 | def forward(self, x):
61 | identity = x
62 |
63 | out = self.conv1(x)
64 | out = self.bn1(out)
65 | out = self.relu(out)
66 |
67 | out = self.conv2(out)
68 | out = self.bn2(out)
69 |
70 | if self.downsample is not None:
71 | identity = self.downsample(x)
72 |
73 | out += identity
74 | out = self.relu(out)
75 |
76 | return out
77 |
78 |
79 | class Bottleneck(nn.Module):
80 | expansion = 4
81 |
82 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
83 | base_width=64, dilation=1, norm_layer=None):
84 | super(Bottleneck, self).__init__()
85 | if norm_layer is None:
86 | norm_layer = nn.BatchNorm2d
87 | width = int(planes * (base_width / 64.)) * groups
88 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
89 | self.conv1 = conv1x1(inplanes, width)
90 | self.bn1 = norm_layer(width)
91 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
92 | self.bn2 = norm_layer(width)
93 | self.conv3 = conv1x1(width, planes * self.expansion)
94 | self.bn3 = norm_layer(planes * self.expansion)
95 | self.relu = nn.ReLU(inplace=True)
96 | self.downsample = downsample
97 | self.stride = stride
98 |
99 | def forward(self, x):
100 | identity = x
101 |
102 | out = self.conv1(x)
103 | out = self.bn1(out)
104 | out = self.relu(out)
105 |
106 | out = self.conv2(out)
107 | out = self.bn2(out)
108 | out = self.relu(out)
109 |
110 | out = self.conv3(out)
111 | out = self.bn3(out)
112 |
113 | if self.downsample is not None:
114 | identity = self.downsample(x)
115 |
116 | out += identity
117 | out = self.relu(out)
118 |
119 | return out
120 |
121 |
122 | class ResNet(nn.Module):
123 |
124 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
125 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
126 | norm_layer=None):
127 | super(ResNet, self).__init__()
128 | if norm_layer is None:
129 | norm_layer = nn.BatchNorm2d
130 | self._norm_layer = norm_layer
131 |
132 | self.inplanes = 64
133 | self.dilation = 1
134 | if replace_stride_with_dilation is None:
135 | # each element in the tuple indicates if we should replace
136 | # the 2x2 stride with a dilated convolution instead
137 | replace_stride_with_dilation = [False, False, False]
138 | if len(replace_stride_with_dilation) != 3:
139 | raise ValueError("replace_stride_with_dilation should be None "
140 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
141 | self.groups = groups
142 | self.base_width = width_per_group
143 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
144 | bias=False)
145 | self.bn1 = norm_layer(self.inplanes)
146 | self.relu = nn.ReLU(inplace=True)
147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
148 | self.layer1 = self._make_layer(block, 64, layers[0])
149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
150 | dilate=replace_stride_with_dilation[0])
151 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
152 | dilate=replace_stride_with_dilation[1])
153 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
154 | dilate=replace_stride_with_dilation[2])
155 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
156 | self.fc = nn.Linear(512 * block.expansion, num_classes)
157 |
158 | for m in self.modules():
159 | if isinstance(m, nn.Conv2d):
160 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
161 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
162 | nn.init.constant_(m.weight, 1)
163 | nn.init.constant_(m.bias, 0)
164 |
165 | # Zero-initialize the last BN in each residual branch,
166 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
167 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
168 | if zero_init_residual:
169 | for m in self.modules():
170 | if isinstance(m, Bottleneck):
171 | nn.init.constant_(m.bn3.weight, 0)
172 | elif isinstance(m, BasicBlock):
173 | nn.init.constant_(m.bn2.weight, 0)
174 |
175 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
176 | norm_layer = self._norm_layer
177 | downsample = None
178 | previous_dilation = self.dilation
179 | if dilate:
180 | self.dilation *= stride
181 | stride = 1
182 | if stride != 1 or self.inplanes != planes * block.expansion:
183 | downsample = nn.Sequential(
184 | conv1x1(self.inplanes, planes * block.expansion, stride),
185 | norm_layer(planes * block.expansion),
186 | )
187 |
188 | layers = []
189 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
190 | self.base_width, previous_dilation, norm_layer))
191 | self.inplanes = planes * block.expansion
192 | for _ in range(1, blocks):
193 | layers.append(block(self.inplanes, planes, groups=self.groups,
194 | base_width=self.base_width, dilation=self.dilation,
195 | norm_layer=norm_layer))
196 |
197 | return nn.Sequential(*layers)
198 |
199 | def forward(self, x):
200 | x = self.conv1(x)
201 | x = self.bn1(x)
202 | x = self.relu(x)
203 | x = self.maxpool(x)
204 |
205 | x = self.layer1(x)
206 | x = self.layer2(x)
207 | x = self.layer3(x)
208 | x = self.layer4(x)
209 |
210 | x = self.avgpool(x)
211 | x = torch.flatten(x, 1)
212 | x = self.fc(x)
213 |
214 | return x
215 |
216 |
217 | def _resnet(arch, block, layers, **kwargs):
218 | model = ResNet(block, layers, **kwargs)
219 | return model
220 |
221 |
222 | def resnet18(args):
223 |     r"""ResNet-18 model from
224 |     "Deep Residual Learning for Image Recognition" (He et al., 2016).
225 |
226 |     Args:
227 |         args: parsed options (currently unused). Pretrained weights are
228 |             not supported.
229 |     """
230 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2])
231 |
232 |
233 | #def resnet34(pretrained=False, progress=True, **kwargs):
234 | # r"""ResNet-34 model from
235 | # `"Deep Residual Learning for Image Recognition" `_
236 | #
237 | # Args:
238 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
239 | # progress (bool): If True, displays a progress bar of the download to stderr
240 | # """
241 | # return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
242 | # **kwargs)
243 | #
244 | #
245 | #def resnet50(pretrained=False, progress=True, **kwargs):
246 | # r"""ResNet-50 model from
247 | # `"Deep Residual Learning for Image Recognition" `_
248 | #
249 | # Args:
250 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
251 | # progress (bool): If True, displays a progress bar of the download to stderr
252 | # """
253 | # return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
254 | # **kwargs)
255 | #
256 | #
257 | #def resnet101(pretrained=False, progress=True, **kwargs):
258 | # r"""ResNet-101 model from
259 | # `"Deep Residual Learning for Image Recognition" `_
260 | #
261 | # Args:
262 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
263 | # progress (bool): If True, displays a progress bar of the download to stderr
264 | # """
265 | # return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
266 | # **kwargs)
267 | #
268 | #
269 | #def resnet152(pretrained=False, progress=True, **kwargs):
270 | # r"""ResNet-152 model from
271 | # `"Deep Residual Learning for Image Recognition" `_
272 | #
273 | # Args:
274 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
275 | # progress (bool): If True, displays a progress bar of the download to stderr
276 | # """
277 | # return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
278 | # **kwargs)
279 | #
280 | #
281 | #def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
282 | # r"""ResNeXt-50 32x4d model from
283 | # `"Aggregated Residual Transformation for Deep Neural Networks" `_
284 | #
285 | # Args:
286 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
287 | # progress (bool): If True, displays a progress bar of the download to stderr
288 | # """
289 | # kwargs['groups'] = 32
290 | # kwargs['width_per_group'] = 4
291 | # return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
292 | # pretrained, progress, **kwargs)
293 | #
294 | #
295 | #def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
296 | # r"""ResNeXt-101 32x8d model from
297 | # `"Aggregated Residual Transformation for Deep Neural Networks" `_
298 | #
299 | # Args:
300 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
301 | # progress (bool): If True, displays a progress bar of the download to stderr
302 | # """
303 | # kwargs['groups'] = 32
304 | # kwargs['width_per_group'] = 8
305 | # return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
306 | # pretrained, progress, **kwargs)
307 | #
308 | #
309 | #def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
310 | # r"""Wide ResNet-50-2 model from
311 | # `"Wide Residual Networks" `_
312 | #
313 | # The model is the same as ResNet except for the bottleneck number of channels
314 | # which is twice larger in every block. The number of channels in outer 1x1
315 | # convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
316 | # channels, and in Wide ResNet-50-2 has 2048-1024-2048.
317 | #
318 | # Args:
319 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
320 | # progress (bool): If True, displays a progress bar of the download to stderr
321 | # """
322 | # kwargs['width_per_group'] = 64 * 2
323 | # return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
324 | # pretrained, progress, **kwargs)
325 | #
326 | #
327 | #def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
328 | # r"""Wide ResNet-101-2 model from
329 | # `"Wide Residual Networks" `_
330 | #
331 | # The model is the same as ResNet except for the bottleneck number of channels
332 | # which is twice larger in every block. The number of channels in outer 1x1
333 | # convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
334 | # channels, and in Wide ResNet-50-2 has 2048-1024-2048.
335 | #
336 | # Args:
337 | # pretrained (bool): If True, returns a model pre-trained on ImageNet
338 | # progress (bool): If True, displays a progress bar of the download to stderr
339 | # """
340 | # kwargs['width_per_group'] = 64 * 2
341 | # return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
342 | # pretrained, progress, **kwargs)
343 |
344 |
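The commented-out builders differ only in block type, per-stage block counts and the `groups`/`width_per_group` overrides. The sketch below reuses the `Bottleneck` width formula and the usual depth count (stem conv plus fc, plus 2 convs per `BasicBlock` or 3 per `Bottleneck`; downsample convs are not counted) to check the names and the channel figures quoted in the Wide ResNet docstrings. The helper names are hypothetical.

```python
from typing import Sequence

EXPANSION = 4  # Bottleneck.expansion


def bottleneck_width(planes: int, base_width: int = 64, groups: int = 1) -> int:
    # Same formula as Bottleneck.__init__
    return int(planes * (base_width / 64.0)) * groups


def resnet_depth(block: str, layers: Sequence[int]) -> int:
    convs_per_block = {"basic": 2, "bottleneck": 3}[block]
    return 2 + convs_per_block * sum(layers)  # +2 for the stem conv and the fc layer


# Last stage (planes=512): ResNet-50 is 2048-512-2048, Wide ResNet-50-2 is 2048-1024-2048
plain = bottleneck_width(512)                    # 512
wide = bottleneck_width(512, base_width=64 * 2)  # 1024
outer = 512 * EXPANSION                          # 2048
resnext = bottleneck_width(64, base_width=4, groups=32)  # 128 for resnext50_32x4d's first stage
```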
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import os
7 | import shutil
8 | import math
9 | import time
10 |
11 | from functools import reduce
12 | import operator
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from torch.autograd import Variable
18 |
19 |
20 | ######################################
21 | # measurement functions #
22 | ######################################
23 |
24 | count_ops = 0
25 | count_params = 0
26 |
27 | def get_num_gen(gen):
28 | return sum(1 for x in gen)
29 |
30 | def is_pruned(layer):
31 | if hasattr(layer, 'mask'):
32 | return True
33 | elif hasattr(layer, 'is_pruned'):
34 | return True
35 | else:
36 | return False
37 |
38 | def is_leaf(model):
39 | return get_num_gen(model.children()) == 0
40 |
41 |
42 | def get_layer_info(layer):
43 | layer_str = str(layer)
44 | type_name = layer_str[:layer_str.find('(')].strip()
45 | return type_name
46 |
47 |
48 | def get_layer_param(model):
49 | return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()])
50 |
51 |
52 | ### The input batch size should be 1 to call this function
53 | def measure_layer(layer, x):
54 | global count_ops, count_params
55 | delta_ops = 0
56 | delta_params = 0
57 | multi_add = 1
58 | type_name = get_layer_info(layer)
59 |
60 | ### ops_conv
61 | if type_name in ['Conv2d', 'Conv2d_lasso']:
62 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /
63 | layer.stride[0] + 1)
64 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /
65 | layer.stride[1] + 1)
66 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
67 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
68 | delta_params = get_layer_param(layer)
69 |
70 | ### ops_head_conv
71 | elif type_name in ['HeadConv']:
72 | x_ori = x
73 | x = F.adaptive_avg_pool2d(x, 1)
74 | b, c, _, _ = x.size()
75 | x = x.view(b, c)
76 | measure_layer(layer.fc1, x)
77 | x = layer.fc1(x)
78 | measure_layer(layer.relu_fc1, x)
79 | x = layer.relu_fc1(x)
80 | measure_layer(layer.fc2, x)
81 | x = layer.fc2(x)
82 | measure_layer(layer.relu_fc2, x)
83 | delta_ops = reduce(operator.mul, x.size(), 1)
84 | delta_params = 0
85 |
86 | x = x_ori
87 | conv = layer.conv
88 | out_h = int((x.size()[2] + 2 * conv.padding[0] - conv.kernel_size[0]) /
89 | conv.stride[0] + 1)
90 | out_w = int((x.size()[3] + 2 * conv.padding[1] - conv.kernel_size[1]) /
91 | conv.stride[1] + 1)
92 | delta_ops += conv.in_channels * conv.out_channels * conv.kernel_size[0] * \
93 | conv.kernel_size[1] * out_h * out_w * layer.target_pruning_rate * multi_add
94 | delta_params += get_layer_param(conv)
95 |
96 | ### ops_dynamic_conv
97 | elif type_name in ['DynamicMultiHeadConv']:
98 | measure_layer(layer.relu, x)
99 | measure_layer(layer.norm, x)
100 | measure_layer(layer.avg_pool, x)
101 | for i in range(layer.heads):
102 | measure_layer(layer.__getattr__('headconv_%1d' % i), x)
103 | delta_ops = 0
104 | delta_params = 0
105 |
106 | ### ops_nonlinearity
107 | elif type_name in ['ReLU', 'ReLU6', 'Sigmoid']:
108 | delta_ops = x.numel()
109 | delta_params = get_layer_param(layer)
110 |
111 | ### ops_pooling
112 | elif type_name in ['AvgPool2d', 'MaxPool2d']:
113 | in_w = x.size()[2]
114 | kernel_ops = layer.kernel_size * layer.kernel_size
115 | out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
116 | out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
117 | delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops
118 | delta_params = get_layer_param(layer)
119 |
120 | elif type_name in ['AdaptiveAvgPool2d']:
121 | in_w = x.size()[2]
122 | kernel_size = in_w
123 | padding = 0
124 | kernel_ops = kernel_size * kernel_size
125 | out_w = int((in_w + 2 * padding - kernel_size) / 1 + 1)
126 | out_h = int((in_w + 2 * padding - kernel_size) / 1 + 1)
127 | delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops
128 | delta_params = get_layer_param(layer)
129 |
134 | ### ops_linear
135 | elif type_name in ['Linear']:
136 | weight_ops = layer.weight.numel() * multi_add
137 | try:
138 | bias_ops = layer.bias.numel()
139 | except AttributeError:
140 | bias_ops = 0
141 | delta_ops = x.size()[0] * (weight_ops + bias_ops)
142 | delta_params = get_layer_param(layer)
143 |
144 | ### ops_nothing
145 | elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']:
146 | delta_params = get_layer_param(layer)
147 |
148 | ### unknown layer type
149 | else:
150 | raise TypeError('unknown layer type: %s' % type_name)
151 |
152 | count_ops += delta_ops
153 | count_params += delta_params
154 | return
155 |
156 |
157 | def measure_model(model, H, W):
158 | global count_ops, count_params
159 | count_ops = 0
160 | count_params = 0
161 | data = Variable(torch.zeros(1, 3, H, W))
162 |
163 | def should_measure(x):
164 | return is_leaf(x) or is_pruned(x)
165 |
166 | def modify_forward(model):
167 | for child in model.children():
168 | if should_measure(child):
169 | def new_forward(m):
170 | def lambda_forward(x):
171 | measure_layer(m, x)
172 | return m.old_forward(x)
173 | return lambda_forward
174 | child.old_forward = child.forward
175 | child.forward = new_forward(child)
176 | else:
177 | modify_forward(child)
178 |
179 | def restore_forward(model):
180 | for child in model.children():
181 | # measured node (leaf or pruned): undo the patch applied in modify_forward
182 | if should_measure(child) and hasattr(child, 'old_forward'):
183 | child.forward = child.old_forward
184 | child.old_forward = None
185 | else:
186 | restore_forward(child)
187 |
188 | modify_forward(model)
189 | model.forward(data)
190 | restore_forward(model)
191 |
192 | return count_ops, count_params
193 |
194 |
195 |
196 | ######################################
197 | # basic functions #
198 | ######################################
199 |
200 |
201 | def load_checkpoint(args):
202 |
203 | model_dir = os.path.join(args.savedir, 'save_models')
204 | latest_filename = os.path.join(model_dir, 'latest.txt')
205 | model_filename = ''
206 |
207 | if args.evaluate is not None:
208 | model_filename = args.evaluate
209 | else:
210 | if os.path.exists(latest_filename):
211 | with open(latest_filename, 'r') as fin:
212 | model_filename = fin.readlines()[0].strip()
213 | loadinfo = "=> loading checkpoint from '{}'".format(model_filename)
214 | print(loadinfo)
215 |
216 | state = None
217 | if os.path.exists(model_filename):
218 | state = torch.load(model_filename, map_location='cpu')
219 | loadinfo2 = "=> loaded checkpoint '{}' successfully".format(model_filename)
220 | else:
221 | loadinfo2 = "no checkpoint loaded"
222 | print(loadinfo2)
223 |
224 | return state
225 |
226 |
227 | def save_checkpoint(state, epoch, root, is_best, saveID, keep_freq=10):
228 |
229 | filename = 'checkpoint_%03d.pth.tar' % epoch
230 | model_dir = os.path.join(root, 'save_models')
231 | model_filename = os.path.join(model_dir, filename)
232 | latest_filename = os.path.join(model_dir, 'latest.txt')
233 |
234 | if not os.path.exists(model_dir):
235 | os.makedirs(model_dir)
236 |
237 | # write new checkpoint
238 | torch.save(state, model_filename)
239 | with open(latest_filename, 'w') as fout:
240 | fout.write(model_filename)
241 | print("=> saved checkpoint '{}'".format(model_filename))
242 |
243 | # update best model
244 | if is_best:
245 | best_filename = os.path.join(model_dir, 'model_best.pth.tar')
246 | shutil.copyfile(model_filename, best_filename)
247 |
248 | # remove old model
249 | if saveID is not None and saveID % keep_freq != 0:
250 | filename = 'checkpoint_%03d.pth.tar' % saveID
251 | model_filename = os.path.join(model_dir, filename)
252 | if os.path.exists(model_filename):
253 | os.remove(model_filename)
254 | print('=> removed checkpoint %s' % model_filename)
255 |
256 | print('##########Time##########', time.strftime('%Y-%m-%d %H:%M:%S'))
257 | return epoch
258 |
259 |
260 | class AverageMeter(object):
261 | """Computes and stores the average and current value"""
262 | def __init__(self):
263 | self.reset()
264 |
265 | def reset(self):
266 | self.val = 0
267 | self.avg = 0
268 | self.sum = 0
269 | self.count = 0
270 |
271 | def update(self, val, n=1):
272 | self.val = val
273 | self.sum += val * n
274 | self.count += n
275 | self.avg = self.sum / self.count
276 |
277 |
278 | def adjust_learning_rate(optimizer, epoch, args, batch=None,
279 | nBatch=None, method='cosine'):
280 | if method == 'cosine':
281 | T_total = args.epochs * nBatch
282 | T_cur = (epoch % args.epochs) * nBatch + batch
283 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * T_cur / T_total))
284 | elif method == 'multistep':
285 | if args.data in ['cifar10', 'cifar100']:
286 | lr, decay_rate = args.lr, 0.1
287 | if epoch >= args.epochs * 0.75:
288 | lr *= decay_rate**2
289 | elif epoch >= args.epochs * 0.5:
290 | lr *= decay_rate
291 | else:
292 | # decay the initial LR by 10x every 30 epochs
293 | lr = args.lr * (0.1 ** (epoch // 30))
294 | for param_group in optimizer.param_groups:
295 | param_group['lr'] = lr
296 | return lr
297 |
298 |
299 | def accuracy(output, target, topk=(1,)):
300 | """Computes the precision@k for the specified values of k"""
301 | maxk = max(topk)
302 | batch_size = target.size(0)
303 |
304 | _, pred = output.topk(maxk, 1, True, True)
305 | pred = pred.t()
306 | correct = pred.eq(target.view(1, -1).expand_as(pred))
307 |
308 | res = []
309 | for k in topk:
310 | correct_k = correct[:k].contiguous().view(-1).float().sum(0)
311 | res.append(correct_k.mul_(100.0 / batch_size))
312 | return res
313 |
--------------------------------------------------------------------------------
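The pooling, adaptive-pooling and linear branches of `measure_layer` above reduce to a few shape formulas. As a minimal sketch, the hypothetical helpers below restate them in plain Python on `(N, C, H, W)` tuples so the arithmetic can be checked without building a model. They assume square kernels and integer `stride`/`padding`, and use floor division, which matches `int(... / stride + 1)` whenever the numerator is non-negative.

```python
# Hypothetical pure-Python restatement of measure_layer's op-count formulas.
# Works on plain shape tuples instead of torch layers; square kernels assumed.

def pool_out_size(in_size, kernel, stride, padding=0):
    """Output length of a pooling window along one axis (floor mode)."""
    return (in_size + 2 * padding - kernel) // stride + 1


def pool_ops(shape, kernel, stride, padding=0):
    """AvgPool2d/MaxPool2d: one op per kernel element per output position."""
    n, c, h, w = shape
    out_h = pool_out_size(h, kernel, stride, padding)
    out_w = pool_out_size(w, kernel, stride, padding)
    return n * c * out_h * out_w * kernel * kernel


def adaptive_avg_pool_ops(shape):
    """AdaptiveAvgPool2d: counted as one op per input element."""
    n, c, h, w = shape
    return n * c * h * w


def linear_ops(batch, in_features, out_features, bias=True):
    """Linear: one multiply-add per weight plus one add per bias, per sample."""
    weight_ops = in_features * out_features
    bias_ops = out_features if bias else 0
    return batch * (weight_ops + bias_ops)


print(pool_ops((1, 64, 32, 32), kernel=2, stride=2))  # 65536 = 64 * 16 * 16 * 4
print(pool_ops((1, 3, 32, 16), kernel=2, stride=2))   # 1536 = 3 * 16 * 8 * 4
print(linear_ops(1, 256, 10))                         # 2570 = 256 * 10 + 10
```

The non-square call shows why the height and width of the output have to be taken from `x.size()[2]` and `x.size()[3]` separately.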
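`measure_model` counts operations by temporarily replacing each measured child's `forward` with a wrapper that calls `measure_layer` first and then the saved `old_forward`. The sketch below shows the same patch-and-restore pattern on a hypothetical `Node` class standing in for `torch.nn.Module`; only `children()` and `forward()` are modelled and pruned modules are ignored, so `is_leaf` is the only predicate. `restore` deliberately tests the same condition `patch` used, so every wrapped node gets its original `forward` back.

```python
class Node:
    """Hypothetical stand-in for torch.nn.Module (children() and forward() only)."""

    def __init__(self, name, children=()):
        self.name = name
        self._children = list(children)

    def children(self):
        return iter(self._children)

    def forward(self, x):
        if not self._children:        # leaf: add one
            return x + 1
        for child in self._children:  # container: run children in order
            x = child.forward(x)
        return x


def is_leaf(module):
    return next(module.children(), None) is None


def patch(model, log):
    """Wrap every leaf's forward so it is recorded before it runs."""
    for child in model.children():
        if is_leaf(child):
            def make_wrapper(m):
                def wrapper(x):
                    log.append(m.name)        # where measure_layer would run
                    return m.old_forward(x)
                return wrapper
            child.old_forward = child.forward
            child.forward = make_wrapper(child)
        else:
            patch(child, log)


def restore(model):
    """Undo patch(); uses the same predicate so no wrapper is left behind."""
    for child in model.children():
        if is_leaf(child) and hasattr(child, 'old_forward'):
            child.forward = child.old_forward
            del child.old_forward
        else:
            restore(child)


conv, relu, fc = Node('conv'), Node('relu'), Node('fc')
model = Node('net', [conv, Node('block', [relu, fc])])
log = []
patch(model, log)
out = model.forward(0)   # each leaf adds 1 and is logged on the way
restore(model)
print(out, log)          # 3 ['conv', 'relu', 'fc']
```

Patching the instance attribute works because `Module.__call__` (and `Node.forward` here) looks up `forward` on the instance first, so the wrapper shadows the class method until it is restored.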
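`save_checkpoint` and `load_checkpoint` cooperate through a `latest.txt` pointer file, and `save_checkpoint` deletes the checkpoint named by `saveID` unless that ID is a multiple of `keep_freq`. The following sketch replays that bookkeeping with empty placeholder files in a temporary directory. It assumes the caller passes the previous call's return value as `saveID` (the calling code in `main.py` is not shown here); `save` and `latest` are hypothetical names, and `torch.save` is replaced by creating an empty file.

```python
import os
import tempfile


def save(root, epoch, save_id, keep_freq=10):
    """Replay save_checkpoint's file bookkeeping with empty placeholder files."""
    model_dir = os.path.join(root, 'save_models')
    os.makedirs(model_dir, exist_ok=True)
    path = os.path.join(model_dir, 'checkpoint_%03d.pth.tar' % epoch)
    open(path, 'w').close()  # stand-in for torch.save(state, path)
    with open(os.path.join(model_dir, 'latest.txt'), 'w') as fout:
        fout.write(path)
    # Drop the checkpoint named by save_id unless it is a multiple of keep_freq.
    if save_id is not None and save_id % keep_freq != 0:
        old = os.path.join(model_dir, 'checkpoint_%03d.pth.tar' % save_id)
        if os.path.exists(old):
            os.remove(old)
    return epoch


def latest(root):
    """Read the path stored in latest.txt, as load_checkpoint does."""
    with open(os.path.join(root, 'save_models', 'latest.txt')) as fin:
        return fin.readline().strip()


with tempfile.TemporaryDirectory() as root:
    save_id = None  # assumption: previous epoch's ID is fed back in
    for epoch in range(25):
        save_id = save(root, epoch, save_id)
    kept = sorted(os.listdir(os.path.join(root, 'save_models')))
    latest_path = latest(root)

print(kept)
# ['checkpoint_000.pth.tar', 'checkpoint_010.pth.tar',
#  'checkpoint_020.pth.tar', 'checkpoint_024.pth.tar', 'latest.txt']
print(os.path.basename(latest_path))  # checkpoint_024.pth.tar
```

Under that assumption, only every `keep_freq`-th checkpoint plus the most recent one survives, and `latest.txt` always points at the newest file.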
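`adjust_learning_rate` implements two schedules: per-iteration cosine annealing from `args.lr` towards zero over `args.epochs * nBatch` steps, and, for CIFAR, a step schedule that divides the rate by 10 at 50% and again at 75% of training. A minimal sketch of both formulas as standalone functions follows. The names are hypothetical, and the numbers in the usage lines (`base_lr=0.1`, 300 epochs, 391 batches per epoch) are illustrative only, not values taken from this repository.

```python
import math


def cosine_lr(base_lr, epoch, batch, n_batch, epochs):
    """Per-iteration cosine annealing, as in the 'cosine' branch."""
    t_total = epochs * n_batch
    t_cur = (epoch % epochs) * n_batch + batch
    return 0.5 * base_lr * (1 + math.cos(math.pi * t_cur / t_total))


def cifar_multistep_lr(base_lr, epoch, epochs, decay_rate=0.1):
    """10x drops at 50% and 75% of training, as in the CIFAR 'multistep' branch."""
    if epoch >= epochs * 0.75:
        return base_lr * decay_rate ** 2
    if epoch >= epochs * 0.5:
        return base_lr * decay_rate
    return base_lr


print(cosine_lr(0.1, 150, 0, 391, 300))   # about 0.05, halfway through training
print(cifar_multistep_lr(0.1, 225, 300))  # about 0.001, after the second drop
```

Because the cosine variant is indexed by `batch` as well as `epoch`, it must be called once per iteration with both `batch` and `nBatch` supplied; the step variant only changes at epoch boundaries.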
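`accuracy` returns, for each `k` in `topk`, the percentage of the batch whose target label is among the `k` highest-scoring classes. The plain-Python sketch below (a hypothetical `topk_accuracy` operating on nested lists rather than tensors) computes the same quantity so the definition can be checked by hand; its tie-breaking among equal scores may differ from `torch.topk`.

```python
def topk_accuracy(scores, targets, topk=(1,)):
    """Percentage of samples whose target is among the k highest scores."""
    results = []
    for k in topk:
        hits = 0
        for row, target in zip(scores, targets):
            # Indices of the k largest scores in this row.
            top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
            if target in top:
                hits += 1
        results.append(100.0 * hits / len(targets))
    return results


scores = [[0.1, 0.7, 0.2],   # sample 0: class 1 ranked first
          [0.6, 0.1, 0.3]]   # sample 1: class 2 ranked second
print(topk_accuracy(scores, [1, 2], topk=(1, 2)))  # [50.0, 100.0]
```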