├── .gitignore
├── LICENSE
├── README.md
├── dataloader.py
├── img
└── DGNet.png
├── main.py
├── models
├── __init__.py
├── cifar
│ ├── __init__.py
│ └── resdg.py
├── imagenet
│ ├── __init__.py
│ └── resdg.py
├── mask.py
└── mobilenet_v2
│ ├── __init__.py
│ ├── mobilenet_v2_dg.py
│ └── mobilenet_v2_dg_util.py
├── options.py
├── regularization.py
├── requirements.txt
├── scripts
├── cifar_e.sh
├── cifar_t.sh
├── imagenet_e.sh
├── imagenet_t.sh
├── mobilenet_v2_e.sh
└── mobilenet_v2_t.sh
└── utils
├── __init__.py
├── logger.py
├── misc.py
└── progress
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.rst
├── demo.gif
├── progress
├── __init__.py
├── bar.py
├── counter.py
├── helpers.py
└── spinner.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # dataset
7 | data/
8 |
9 | # log
10 | logs/
11 |
12 | # jupyter notebook
13 | *.ipynb
14 |
15 | # IDE
16 | .vscode
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 anonymous-9800
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dynamic Dual Gating Neural Networks
2 |
3 | This repository contains the PyTorch implementation for
4 |
5 | > **Dynamic Dual Gating Neural Networks**
6 | > Fanrong Li, Gang Li, Xiangyu He, Jian Cheng
7 | > ICCV 2021 Oral
8 |
9 | 
10 |
11 | ## Getting Started
12 |
13 | ### Requirements
14 |
15 | The main requirements of this work are:
16 |
17 | - Python 3.7
18 | - PyTorch == 1.5.0
19 | - Torchvision == 0.6.0
20 | - CUDA 10.2
21 |
22 | We recommend using a conda environment to set up the experimental environment.
23 |
24 |
25 | ```shell script
26 | # Create environment
27 | conda create -n DGNet python=3.7
28 | conda activate DGNet
29 |
30 | # Install PyTorch & Torchvision
31 | pip install torch==1.5.0 torchvision==0.6.0
32 |
33 | # Clone repo
34 | git clone https://github.com/anonymous-9800/DGNet.git ./DGNet
35 | cd ./DGNet
36 |
37 | # Install other requirements
38 | pip install -r requirements.txt
39 | ```
40 |
41 | ### Trained models
42 | Our trained models are available on [Google Drive](https://drive.google.com/file/d/1_-G5eHm3PUrrorjzp8w17W7ogZZoTElk/view?usp=sharing), and the pretrained CIFAR-10 models are available on [Google Drive](https://drive.google.com/file/d/15sM2W2ADqtq5Gr8RTdaFalPK7qIw0VXF/view?usp=sharing). Unzip both archives and place them in the DGNet folder.
43 |
44 | ### Evaluate a trained DGNet
45 |
46 | ```shell script
47 | # CIFAR-10
48 | sh ./scripts/cifar_e.sh [ARCH] [PATH-TO-DATASET] [GPU-IDs] [PATH-TO-SAVE] [PATH-TO-TRAINED-MODEL]
49 |
50 | # ResNet on ImageNet
51 | sh ./scripts/imagenet_e.sh [ARCH] [PATH-TO-DATASET] [GPU-IDs] [PATH-TO-SAVE] [PATH-TO-TRAINED-MODEL]
52 |
53 | # Example
54 | sh ./scripts/imagenet_e.sh resdg34 [PATH-TO-DATASET] 0 imagenet/resdg34-04-e ./trained_models_cls/imagenet_results/resdg34/sparse06/resdg34_04.pth.tar
55 | ```
56 |
57 | ### Train a DGNet
58 | ```shell script
59 | # CIFAR-10
60 | sh ./scripts/cifar_t.sh [ARCH] [PATH-TO-DATASET] [TARGET-DENSITY] [GPU-IDs] [PATH-TO-SAVE] [PATH-TO-PRETRAINED-MODEL]
61 |
62 | # ResNet on ImageNet
63 | sh ./scripts/imagenet_t.sh [ARCH] [PATH-TO-DATASET] [TARGET-DENSITY] [GPU-IDs] [PATH-TO-SAVE]
64 |
65 | # Example
66 | sh ./scripts/imagenet_t.sh resdg34 [PATH-TO-DATASET] 0.4 0,1 imagenet/resdg34-04
67 | ```
68 |
69 | ## Main results
70 |
71 |
72 |
| Model | Method | Top-1 (%) | Top-5 (%) | FLOPs | Google Drive |
| --- | --- | --- | --- | --- | --- |
| ResNet-18 | DGNet (50%) | 70.12 | 89.22 | 9.54E8 | Link |
| ResNet-18 | DGNet (60%) | 69.38 | 88.94 | 7.88E8 | Link |
| ResNet-34 | DGNet (60%) | 73.01 | 90.99 | 1.50E9 | Link |
| ResNet-34 | DGNet (70%) | 71.95 | 90.46 | 1.21E9 | Link |
| ResNet-50 | DGNet (60%) | 76.41 | 93.05 | 1.65E9 | Link |
| ResNet-50 | DGNet (70%) | 75.12 | 92.34 | 1.31E9 | Link |
| MobileNet-V2 | DGNet (50%) | 71.62 | 90.05 | 1.60E8 | Link |
132 |
133 |
134 |
135 | ## Citation
136 |
137 | If you find this project useful for your research, please use the following BibTeX entry.
138 |
139 | @inproceedings{dgnet,
140 | title={Dynamic Dual Gating Neural Networks},
141 | author={Li, Fanrong and Li, Gang and He, Xiangyu and Cheng, Jian},
142 | booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
143 | year={2021}
144 | }
145 |
146 | ## Contact
147 | For any questions, feel free to contact:
148 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import torch
4 | import torchvision.datasets as datasets
5 | import torchvision.transforms as transforms
6 |
7 |
def _getCifarLoader(data, dataset, batch_size, workers):
    """Build the CIFAR train/test ``DataLoader`` pair.

    Args:
        data: root directory of the dataset; it is not downloaded here
            (``download=False``), so the files must already be present.
        dataset: ``'cifar10'`` selects CIFAR-10; any other value selects
            CIFAR-100.
        batch_size: batch size used by both loaders.
        workers: number of worker processes used by both loaders.

    Returns:
        ``(trainloader, testloader)``; only the training loader shuffles.
    """
    # The train and test splits are read from the same root directory.
    root = os.path.join(data)

    # Per-dataset channel statistics and torchvision dataset class.
    if dataset == 'cifar10':
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        dataset_cls = datasets.CIFAR10
    else:
        mean, std = (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)
        dataset_cls = datasets.CIFAR100
    normalize = transforms.Normalize(mean=mean, std=std)

    logging.info('=> Preparing dataset %s' % dataset)

    # Training images are augmented (padded random crop + horizontal flip);
    # test images are only converted and normalized.
    augment_and_normalize = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    normalize_only = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = dataset_cls(root=root, train=True, download=False,
                           transform=augment_and_normalize)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=workers)

    testset = dataset_cls(root=root, train=False, download=False,
                          transform=normalize_only)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=workers)

    return trainloader, testloader
56 |
57 |
def _getImageNetLoader(data, batch_size, workers):
    """Build the ImageNet train/val ``DataLoader`` pair.

    Args:
        data: directory containing ``train`` and ``val`` sub-folders in the
            layout expected by ``torchvision.datasets.ImageFolder``.
        batch_size: batch size used by both loaders.
        workers: number of worker processes used by both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_root = os.path.join(data, 'train')
    val_root = os.path.join(data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Options shared by both loaders.
    # NOTE(review): drop_last=True is also applied to the validation loader,
    # so a trailing partial batch is never evaluated -- confirm this is
    # intended.
    common = dict(batch_size=batch_size, num_workers=workers,
                  pin_memory=True, drop_last=True)

    # Training: random-resized crop and horizontal flip.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    train_set = datasets.ImageFolder(train_root, train_transform)
    train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **common)

    # Validation: resize to 256 then centre-crop to 224.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    val_set = datasets.ImageFolder(val_root, val_transform)
    val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **common)

    return train_loader, val_loader
86 |
87 |
def getDataLoader(data, dataset, batch_size, workers):
    """Return ``(train_loader, test_loader)`` for the requested dataset.

    ``dataset == 'imagenet'`` uses the ImageNet loaders; every other value is
    handed to the CIFAR loader builder together with the dataset name.
    """
    if dataset != 'imagenet':
        return _getCifarLoader(data, dataset, batch_size, workers)
    return _getImageNetLoader(data, batch_size, workers)
93 |
--------------------------------------------------------------------------------
/img/DGNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CAS-CLab/DGNet/6b709a388c463d7468fbad953ad0112bc3abe66d/img/DGNet.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import math
4 | import random
5 | import shutil
6 | import logging
7 | import torch
8 | import torch.nn as nn
9 | import models
10 | import numpy as np
11 | from options import parser
12 | from collections import OrderedDict
13 | from dataloader import getDataLoader
14 | from utils import *
15 | from regularization import *
16 |
# Parse the command-line options once, at import time; the resulting
# namespace is read by the functions below.
args = parser.parse_args()
# Mutable dict copy of the options. 'learning_rate' and 'den_target' are
# overwritten during training (see adjust_learning_rate / train_process).
state = {k: v for k, v in args._get_kwargs()}
print('Parameters:')
for key, value in state.items():
    print(' {key} : {value}'.format(key=key, value=value))

# Use CUDA
# CUDA_VISIBLE_DEVICES is assigned before the availability query below.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
use_cuda = torch.cuda.is_available()

# Random seed
# When no seed is given one is drawn and stored back into args.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
np.random.seed(args.manualSeed)
if use_cuda:
    torch.cuda.manual_seed_all(args.manualSeed)
    torch.backends.cudnn.deterministic = True

best_acc = 0 # best test accuracy

# Get loggers and save the config information
# train_log / test_log are written once per epoch by train() / validate().
train_log, test_log, checkpoint_dir, log_dir = get_loggers(args)
41 |
def main():
    """Build the model, data loaders and optimizer, then evaluate or train.

    All settings come from the module-level ``args``. With ``args.evaluate``
    set, a single validation pass is run and its top-1/top-5 accuracy is
    logged. Otherwise the model is trained for ``args.epochs`` epochs through
    ``train_process`` and the best test accuracy is logged.

    The train/test log writers are closed on every exit path (evaluation,
    normal completion, or an exception).
    """
    global best_acc, train_log, test_log, checkpoint_dir, log_dir
    try:
        # create model
        logging.info("=" * 89)
        logging.info("=> creating model '{}'".format(args.arch))
        model = models.get_model(pretrained=args.pretrained, dataset=args.dataset,
                                 arch=args.arch, bias=args.bias)
        # define loss function (criterion) and optimizer
        criterion = Loss()
        model.set_criterion(criterion)
        # Data loader
        trainloader, testloader = getDataLoader(args.data, args.dataset,
                                                args.batch_size, args.workers)
        # to cuda
        # gpu_id is also assigned to CUDA_VISIBLE_DEVICES, so it is a string;
        # compare it as text (comparing against the int -1 is always unequal).
        if torch.cuda.is_available() and str(args.gpu_id) != '-1':
            os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
            model = torch.nn.DataParallel(model).cuda()
            logging.info('=> running the model on gpu{}.'.format(args.gpu_id))
        else:
            logging.info('=> running the model on cpu.')
        # define optimizer
        # Names of all BatchNorm2d affine parameters; a set keeps the
        # per-parameter membership test below O(1).
        bn_param_names = set()
        for m_name, m in model.named_modules():
            if isinstance(m, nn.BatchNorm2d):
                bn_param_names.add(m_name + '.weight')
                bn_param_names.add(m_name + '.bias')
        is_mobilenet = 'mobilenet' in args.arch
        # One param group per parameter: mask parameters (and, for mobilenet
        # archs, BatchNorm parameters) are excluded from weight decay.
        params = []
        for key, value in model.named_parameters():
            if (is_mobilenet and key in bn_param_names) or 'mask' in key:
                params.append({'params': [value], 'lr': args.learning_rate,
                               'weight_decay': 0.})
            else:
                params.append({'params': [value]})
        optimizer = torch.optim.SGD(params, lr=args.learning_rate,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum, nesterov=True)
        p_anneal = ExpAnnealing(0, 1, 0, alpha=args.alpha)
        # ready
        logging.info("=" * 89)
        # Evaluate
        if args.evaluate:
            logging.info('Evaluate model')
            top1, top5 = validate(testloader, model, criterion, 0, use_cuda,
                                  (args.lbda, 0), args.den_target)
            logging.info('Test Acc (Top-1): %.2f, Test Acc (Top-5): %.2f' % (top1, top5))
            return
        # training
        logging.info('\n Train for {} epochs'.format(args.epochs))
        train_process(model, args.epochs, testloader, trainloader, criterion, optimizer,
                      use_cuda, args.lbda, args.gamma, p_anneal, checkpoint_dir,
                      args.den_target)
        logging.info('Best acc: {}'.format(best_acc))
    finally:
        # Release the log writers whichever way we leave.
        train_log.close()
        test_log.close()
    return
95 |
96 |
def train_process(model, total_epochs, testloader, trainloader, criterion, optimizer,
                  use_cuda, lbda, gamma, p_anneal, checkpoint_dir, den_target):
    """Run the epoch loop: train, validate, track the best accuracy, checkpoint.

    Args:
        model: network to train (possibly wrapped in ``DataParallel``).
        total_epochs: number of epochs to run.
        testloader, trainloader: validation and training data loaders.
        criterion: forwarded to ``train`` / ``validate``.
        optimizer: optimizer whose learning rate is adjusted every epoch.
        use_cuda: forwarded to ``train`` / ``validate``.
        lbda, gamma: loss weights, forwarded as the pair ``(lbda, gamma)``.
        p_anneal: object whose ``get_lr(epoch)`` yields the per-epoch ``p``.
        checkpoint_dir: directory for checkpoints, or ``None`` to skip saving.
        den_target: target density forwarded to the model.

    The module-level ``best_acc`` is updated every epoch, whether or not a
    checkpoint directory is configured.
    """
    global best_acc
    for epoch in range(total_epochs):
        p = p_anneal.get_lr(epoch)
        # get target density
        state['den_target'] = den_target
        # update lr
        adjust_learning_rate(optimizer, epoch=epoch)
        # Training
        train(trainloader, model, criterion, optimizer, epoch, use_cuda, (lbda, gamma),
              den_target, p)
        test_acc, _ = validate(testloader, model, criterion, epoch, use_cuda,
                               (lbda, gamma), den_target, p=p)
        # Track the best accuracy independently of checkpointing, so the
        # final "Best acc" report is meaningful without a checkpoint dir.
        is_best = test_acc > best_acc
        best_acc = max(test_acc, best_acc)
        # save checkpoint
        if checkpoint_dir is None:
            continue
        # Save the bare network's weights: unwrap by inspecting the model
        # itself rather than trusting the use_cuda flag.
        wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
        net = model.module if isinstance(model, wrappers) else model
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'acc': test_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict()
            },
            is_best=is_best,
            checkpoint_dir=checkpoint_dir)
    return
127 |
128 |
def train(train_loader, model, criterion, optimizer, epoch, use_cuda, param,
          den_target, p):
    """Train ``model`` for one epoch and append a summary line to ``train_log``.

    Args:
        train_loader: iterable yielding ``(x, targets)`` batches.
        model: called as ``model(x=..., label=..., den_target=..., lbda=...,
            gamma=..., p=...)``; its output is indexed with the keys
            ``"closs"``, ``"rloss"``, ``"bloss"`` and ``"out"``.
        criterion: not referenced in this function.
        optimizer: stepped once per batch.
        epoch: epoch index, written to the log line.
        use_cuda: when true, each batch is moved to the GPU.
        param: pair ``(lbda, gamma)`` forwarded to the model.
        den_target: forwarded to the model.
        p: forwarded to the model.
    """
    lbda, gamma = param
    # switch to train mode
    model.train()
    logging.info("=" * 89)

    # Running averages: timings, the three loss terms, their sum, and accuracy.
    batch_time, data_time, closses, rlosses, blosses, losses, top1, top5 = getAvgMeter(8)

    end = time.time()
    bar = Bar('Processing', max=len(train_loader))
    for batch_idx, (x, targets) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        # get inputs
        if use_cuda:
            x, targets = x.cuda(), targets.cuda()
        x, targets = torch.autograd.Variable(x), torch.autograd.Variable(targets)
        batch_size = x.size(0)
        # inference
        inputs = {"x": x, "label": targets, "den_target": den_target, "lbda": lbda,
                  "gamma": gamma, "p": p}
        outputs= model(**inputs)
        # Total loss is the sum of the three loss terms returned by the model,
        # each reduced with .mean().
        loss = outputs["closs"].mean() + outputs["rloss"].mean() + outputs["bloss"].mean()
        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs["out"].data, targets.data, topk=(1, 5))
        closses.update(outputs["closs"].mean().item(), batch_size)
        rlosses.update(outputs["rloss"].mean().item(), batch_size)
        blosses.update(outputs["bloss"].mean().item(), batch_size)
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # plot progress
        bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | '.format(
            batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val,
            )+'Total: {total:} | (C,R,B)Loss: {closs:.2f}, {rloss:.2f}, {bloss:.2f}'.format(
            total=bar.elapsed_td, closs=closses.avg, rloss=rlosses.avg, bloss=blosses.avg,
            )+' | Loss: {loss:.2f} | top1: {top1:.2f} | top5: {top5:.2f}'.format(top1=top1.avg,
            top5=top5.avg, loss=losses.avg)
        bar.next()
    bar.finish()
    # One tab-separated line per epoch: epoch, top1, top5, loss, closs, rloss, bloss.
    train_log.write(content="{epoch}\t{top1.avg:.4f}\t{top5.avg:.4f}\t{loss.avg:.4f}\t"
                            "{closs.avg:.4f}\t{rloss.avg:.4f}\t{bloss.avg:.4f}".format(
                            epoch=epoch, top1=top1, top5=top5,loss=losses, closs=closses,
                            rloss=rlosses, bloss=blosses),
                    wrap=True, flush=True)
    return
183 |
184 |
def validate(val_loader, model, criterion, epoch, use_cuda, param, den_target, p=0):
    """Evaluate ``model`` on ``val_loader`` and append a summary line to ``test_log``.

    Args:
        val_loader: iterable yielding ``(x, targets)`` batches.
        model: called with the same keyword inputs as in training; its output
            is additionally indexed with ``"flops_real"``, ``"flops_mask"``
            and ``"flops_ori"``.
        criterion: not referenced in this function.
        epoch: epoch index, written to the log line.
        use_cuda: when true, batches are moved to the GPU and flops are
            recorded through ``model.module``.
        param: pair ``(lbda, gamma)`` forwarded to the model.
        den_target: forwarded to the model.
        p: forwarded to the model (default 0).

    Returns:
        ``(top1.avg, top5.avg)`` accumulated over the loader.
    """
    global log_dir
    lbda, gamma = param
    # switch to evaluate mode
    model.eval()
    logging.info("=" * 89)

    # Running averages; block_flops accumulates the per-batch flops_conv value.
    (batch_time, data_time, closses, rlosses, blosses, losses,
     top1, top5, block_flops)= getAvgMeter(9)

    with torch.no_grad():
        end = time.time()
        bar = Bar('Processing', max=len(val_loader))
        for batch_idx, (x, targets) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            # get inputs
            if use_cuda:
                x, targets = x.cuda(), targets.cuda(non_blocking=True)
            x, targets = torch.autograd.Variable(x), torch.autograd.Variable(targets)
            batch_size = x.size(0)
            # inference
            inputs = {"x": x, "label": targets, "den_target": den_target, "lbda": lbda,
                      "gamma": gamma, "p": p}
            outputs= model(**inputs)
            loss = outputs["closs"].mean() + outputs["rloss"].mean() + outputs["bloss"].mean()
            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs["out"].data, targets.data, topk=(1, 5))
            closses.update(outputs["closs"].mean().item(), batch_size)
            rlosses.update(outputs["rloss"].mean().item(), batch_size)
            blosses.update(outputs["bloss"].mean().item(), batch_size)
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # get flops
            flops_real = outputs["flops_real"]
            flops_mask = outputs["flops_mask"]
            flops_ori = outputs["flops_ori"]
            flops_conv, flops_mask, flops_ori, flops_conv1, flops_fc = analyse_flops(
                flops_real, flops_mask, flops_ori, batch_size)
            block_flops.update(flops_conv, batch_size)
            # plot progress
            bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:}'.format(
                batch=batch_idx+1, size=len(val_loader), bt=batch_time.avg, total=bar.elapsed_td
                )+' | (C,R,B)Loss: {closs:.2f}, {rloss:.2f}, {bloss:.2f}'.format(
                closs=closses.avg, rloss=rlosses.avg, bloss=blosses.avg,
                )+' | Loss: {loss:.2f} | top1: {top1:.2f} | top5: {top5:.2f}'.format(
                top1=top1.avg, top5=top5.avg, loss=losses.avg)
            bar.next()
        bar.finish()
    # log
    # flops_mask, flops_ori, flops_conv1 and flops_fc below hold the values
    # from the last batch only; block_flops.avg is averaged over all batches.
    # With an empty loader these names would be unbound.
    if use_cuda:
        model.module.record_flops(block_flops.avg, flops_mask, flops_ori, flops_conv1, flops_fc)
    else:
        model.record_flops(block_flops.avg, flops_mask, flops_ori, flops_conv1, flops_fc)
    # Flops including mask overhead, reported in units of 1024 and as a
    # percentage of the flops_ori-based total.
    flops = (block_flops.avg[-1]+flops_mask[-1]+flops_conv1.mean()+flops_fc.mean())/1024
    flops_per = (block_flops.avg[-1]+flops_mask[-1]+flops_conv1.mean()+flops_fc.mean())/(
        flops_ori[-1]+flops_conv1.mean()+flops_fc.mean())*100
    test_log.write(content="{epoch}\t{top1.avg:.4f}\t{top5.avg:.4f}\t{loss.avg:.4f}\t"
                           "{closs.avg:.4f}\t{rloss.avg:.4f}\t{bloss.avg:.4f}\t"
                           "{flops_per:.2f}%\t{flops:.2f}K\t".format(epoch=epoch, top1=top1,
                           top5=top5, loss=losses, closs=closses, rloss=rlosses,
                           bloss=blosses, flops_per=flops_per, flops=flops),
                   wrap=True, flush=True)
    return (top1.avg, top5.avg)
253 |
254 |
def getAvgMeter(num):
    """Return a list holding ``num`` independent, freshly created ``AverageMeter`` objects."""
    meters = []
    for _ in range(num):
        meters.append(AverageMeter())
    return meters
257 |
258 |
def adjust_learning_rate(optimizer, epoch):
    """Set the learning rate of every param group for ``epoch``.

    ``args.lr_mode == 'cosine'`` recomputes a half-cosine schedule from
    ``args.learning_rate`` every epoch; ``'step'`` multiplies the current rate
    by ``args.lr_decay`` on the epochs listed in ``args.schedule``. The
    current rate is kept in ``state['learning_rate']`` and logged.

    Raises:
        NotImplementedError: for any other ``args.lr_mode``.
    """
    global state
    mode = args.lr_mode
    if mode == 'cosine':
        state['learning_rate'] = 0.5*args.learning_rate*(
            1+math.cos(math.pi*float(epoch)/float(args.epochs)))
        rate_changed = True
    elif mode == 'step':
        rate_changed = epoch in args.schedule
        if rate_changed:
            state['learning_rate'] *= args.lr_decay
    else:
        raise NotImplementedError('can not support lr mode {}'.format(args.lr_mode))

    # Push the (possibly new) rate into every optimizer group.
    if rate_changed:
        for group in optimizer.param_groups:
            group['lr'] = state['learning_rate']

    logging.info("\nEpoch: {epoch:3d} | learning rate = {lr:.6f}".format(
        epoch=epoch, lr=state['learning_rate']))
275 |
276 |
def save_checkpoint(state,
                    is_best,
                    filename='checkpoint.pth.tar',
                    checkpoint_dir='.'):
    """Serialize ``state`` to ``checkpoint_dir/filename``.

    When ``is_best`` is true the written file is also copied to
    ``checkpoint_dir/model_best.pth.tar``.
    """
    target = os.path.join(checkpoint_dir, filename)
    torch.save(state, target, pickle_protocol=4)
    if not is_best:
        return
    best_target = os.path.join(checkpoint_dir, 'model_best.pth.tar')
    shutil.copyfile(target, best_target)
285 |
# Run training/evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
288 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 | import logging
5 | import pretrainedmodels
6 | import torchvision.models as torch_models
7 | import torch.backends.cudnn as cudnn
8 | from torchvision.models.utils import load_state_dict_from_url
9 | from . import cifar as cifar_models
10 | from . import imagenet as imagenet_extra_models
11 | from . import mobilenet_v2 as mobilenet_models
12 |
13 |
# Datasets accepted by get_model().
SUPPORTED_DATASETS = ('imagenet', 'cifar10')

# Every non-dunder callable exposed by torchvision.models.
TORCHVISION_MODEL_NAMES = sorted(
    name for name, obj in vars(torch_models).items()
    if not name.startswith("__") and callable(obj))

# torchvision names followed by the lower-case callables of the local
# imagenet package.
IMAGENET_MODEL_NAMES = list(TORCHVISION_MODEL_NAMES) + sorted(
    name for name, obj in vars(imagenet_extra_models).items()
    if name.islower() and not name.startswith("__") and callable(obj))

# Lower-case callables of the local cifar package.
CIFAR_MODEL_NAMES = sorted(
    name for name, obj in vars(cifar_models).items()
    if name.islower() and not name.startswith("__") and callable(obj))

# Lower-case callables of the local mobilenet_v2 package.
MOBILENET_MODEL_NAMES = sorted(
    name for name, obj in vars(mobilenet_models).items()
    if name.islower() and not name.startswith("__") and callable(obj))

# Union of all names above, lower-cased after de-duplication.
ALL_MODEL_NAMES = sorted(
    name.lower()
    for name in set(IMAGENET_MODEL_NAMES + CIFAR_MODEL_NAMES + MOBILENET_MODEL_NAMES))

# Download locations of upstream pretrained weights.
model_urls = {
    'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}
39 |
40 |
def get_model(pretrained, dataset, arch, **kwargs):
    """Create a PyTorch model for the given architecture and dataset.

    Args:
        pretrained: weight initialisation selector, forwarded unchanged to the
            per-dataset builder. The builders in this module compare it with
            the string ``'pytorch'`` and otherwise test it as a checkpoint
            file path.
        dataset: dataset name, case-insensitive; must be one of
            ``SUPPORTED_DATASETS`` (``'imagenet'`` or ``'cifar10'``).
        arch: architecture name.
        **kwargs: extra keyword arguments forwarded to the model constructor.
            ``num_classes`` is set here (1000 for non-mobilenet ImageNet
            models, 10 for CIFAR-10).

    Returns:
        The constructed model.

    Raises:
        ValueError: if the dataset is unsupported or the builder rejects the
            dataset/architecture pair.
    """
    dataset = dataset.lower()
    if dataset not in SUPPORTED_DATASETS:
        raise ValueError('Dataset {} is not supported'.format(dataset))

    model = None
    try:
        if dataset == 'imagenet':
            if 'mobilenet' in arch:
                model = _create_mobilenet_model(arch, pretrained, **kwargs)
            else:
                kwargs['num_classes'] = 1000
                model = _create_imagenet_model(arch, pretrained, **kwargs)
        elif dataset == 'cifar10':
            kwargs['num_classes'] = 10
            model = _create_cifar10_model(arch, pretrained, **kwargs)
    except ValueError as err:
        # Keep the builder's own error as the explicit cause.
        raise ValueError('Could not recognize dataset {} and model {} pair'.format(
            dataset, arch)) from err

    logging.info("=> created a %s%s model with the %s dataset" % (
        'pretrained ' if pretrained else '', arch, dataset))
    logging.info(' Total params: %.2fM' % (
        sum(p.numel() for p in model.parameters())/1000000.0))
    return model
73 |
74 |
def _create_imagenet_model(arch, pretrained, **kwargs):
    """Build an ImageNet model, optionally initialised from pretrained weights.

    Lookup order: torchvision models, then the local ``imagenet`` package,
    then the ``pretrainedmodels`` package.

    Args:
        arch: architecture name.
        pretrained: ``'pytorch'`` to initialise from torchvision weights, a
            path to a checkpoint file to load its ``'state_dict'``, or
            anything else for no initialisation. Non-path values (e.g.
            ``None``) are treated as "no checkpoint".
        **kwargs: forwarded to the constructor of local models.

    Returns:
        The constructed model.

    Raises:
        ValueError: if no model can be built for ``arch``, or if
            ``'pytorch'`` weights are requested for a local architecture that
            does not start with ``'resdg'``.
    """
    dataset = "imagenet"
    model = None
    pretrained_pytorch = pretrained == 'pytorch'
    # os.path.isfile raises TypeError on None, so only probe path-like values.
    pretrained_checkpoint = (isinstance(pretrained, (str, bytes, os.PathLike))
                             and os.path.isfile(pretrained))
    if arch in TORCHVISION_MODEL_NAMES:
        try:
            model = getattr(torch_models, arch)(pretrained=pretrained_pytorch)
        except NotImplementedError:
            # In torchvision 0.3, trying to download a model that has no
            # pretrained image available will raise NotImplementedError
            if not pretrained_pytorch:
                raise
    if model is None and (arch in imagenet_extra_models.__dict__):
        model = imagenet_extra_models.__dict__[arch](**kwargs)
        if pretrained_pytorch:
            model_dict = model.state_dict()
            # get pretrained model
            if not arch.startswith('resdg'):
                raise ValueError("There is no pretrained model for {} in pytorch".format(arch))
            # Drop the literal prefix; str.lstrip would strip a character
            # *set*, not the prefix.
            arch_pretrained = 'resnet' + arch[len('resdg'):]
            logging.info("=> use a pretrained %s model to initialize." % (arch_pretrained))
            pretrained_model = getattr(torch_models, arch_pretrained)(pretrained=True)
            pretrained_dict = pretrained_model.state_dict()
            # filter out unnecessary keys
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            # overwrite entries in the existing state dict
            model_dict.update(pretrained_dict)
            # load the new state dict
            model.load_state_dict(model_dict)
        elif pretrained_checkpoint:
            # Always deserialize onto the CPU first.
            checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage)
            logging.info("=> loaded checkpoint (prec {:.2f})".format(checkpoint['best_acc']))
            model_dict = model.state_dict()
            pretrained_dict = checkpoint['state_dict']
            # Keep only the entries whose names exist in this model.
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

    if model is None and (arch in pretrainedmodels.model_names):
        model = pretrainedmodels.__dict__[arch](
            num_classes=1000,
            pretrained=(dataset if pretrained else None))

    if model is None:
        error_message = ''
        if arch not in IMAGENET_MODEL_NAMES:
            error_message = "Model {} is not supported for dataset ImageNet".format(arch)
        elif pretrained:
            error_message = "Model {} (ImageNet) does not have a pretrained model".format(arch)
        raise ValueError(error_message or 'Failed to find model {}'.format(arch))
    return model
128 |
129 |
def _create_cifar10_model(arch, pretrained, **kwargs):
    """Build a CIFAR-10 model, optionally initialised from a checkpoint file.

    Args:
        arch: name of a constructor in the local ``cifar`` package.
        pretrained: path to a checkpoint file; when it is an existing file
            its ``'state_dict'`` entries whose names exist in the model are
            loaded. Any other value (including ``None``) means no loading.
        **kwargs: forwarded to the model constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``arch`` is not provided by the ``cifar`` package.
    """
    # Only the name lookup is guarded, so a KeyError raised inside the
    # constructor is not misreported as an unsupported architecture.
    try:
        factory = cifar_models.__dict__[arch]
    except KeyError as err:
        raise ValueError(
            "Model {} is not supported for dataset CIFAR10".format(arch)) from err
    model = factory(**kwargs)

    # os.path.isfile raises TypeError on None, so only probe path-like values.
    has_checkpoint = (isinstance(pretrained, (str, bytes, os.PathLike))
                      and os.path.isfile(pretrained))
    # load pretrained model
    if has_checkpoint:
        # Always deserialize onto the CPU first.
        checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage)
        logging.info("=> loaded checkpoint (prec {:.2f})".format(checkpoint['best_acc']))
        model_dict = model.state_dict()
        pretrained_dict = checkpoint['state_dict']
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    return model
147 |
def _create_mobilenet_model(arch, pretrained, **kwargs):
    """Build a MobileNet model, optionally initialised from pretrained weights.

    Args:
        arch: name of a constructor in the local ``mobilenet_v2`` package.
        pretrained: ``'pytorch'`` to initialise from upstream weights, a path
            to a checkpoint file to load its ``'state_dict'``, or anything
            else for no initialisation (a message is logged in that case).
        **kwargs: forwarded to the model constructor.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``arch`` is not provided by the ``mobilenet_v2``
            package (same convention as the CIFAR-10 builder, so
            ``get_model`` can report it).
    """
    try:
        factory = mobilenet_models.__dict__[arch]
    except KeyError as err:
        raise ValueError(
            "Model {} is not supported for dataset ImageNet".format(arch)) from err
    model = factory(**kwargs)

    pretrained_pytorch = pretrained == 'pytorch'
    # os.path.isfile raises TypeError on None, so only probe path-like values.
    pretrained_checkpoint = (isinstance(pretrained, (str, bytes, os.PathLike))
                             and os.path.isfile(pretrained))
    if pretrained_pytorch:
        model_dict = model.state_dict()
        # Drop the last three characters of the name (e.g. a '_dg' suffix).
        arch_pretrained = arch[:-3]
        logging.info("=> use a pretrained %s model to initialize." % (arch_pretrained))
        if 'mobilenet_v2' in arch:
            pretrained_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], progress=True)
        else:
            pretrained_model = mobilenet_models.__dict__[arch_pretrained](pretrained=True, **kwargs)
            pretrained_dict = pretrained_model.state_dict()
        # filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # load the new state dict
        model.load_state_dict(model_dict)
    elif pretrained_checkpoint:
        # Always deserialize onto the CPU first.
        checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage)
        logging.info("=> loaded checkpoint (prec {:.2f})".format(checkpoint['best_acc']))
        model_dict = model.state_dict()
        pretrained_dict = checkpoint['state_dict']
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    else:
        logging.info("=> no checkpoint found at '{}'".format(pretrained))
    return model
178 |
--------------------------------------------------------------------------------
/models/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | from .resdg import *
2 |
--------------------------------------------------------------------------------
/models/cifar/resdg.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 | import torch
4 | import torch.nn as nn
5 | from prettytable import PrettyTable
6 | from ..mask import Mask_s, Mask_c
7 |
8 |
# Names exported by `from .resdg import *`.
__all__ = ['resdg20_cifar10', 'resdg32_cifar10', 'resdg56_cifar10',
           'resdg110_cifar10']
11 |
12 |
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals its dilation."""
    conv_options = dict(kernel_size=3, stride=stride, padding=dilation,
                        groups=groups, bias=False, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
17 |
18 |
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (pointwise convolution)."""
    pointwise = nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False)
    return pointwise
22 |
23 |
def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False):
    """Return the output length of one spatial axis after a 2-D convolution.

    Evaluates ``(dim + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1``
    and rounds it up when ``ceil_mode`` is true, down otherwise.
    """
    effective_span = dim + 2 * padding - dilation * (kernel_size - 1) - 1
    rounding = math.ceil if ceil_mode else math.floor
    return int(rounding(effective_span / stride + 1))
29 |
30 |
class BasicBlock(nn.Module):
    """Residual block of two 3x3 convolutions gated by a spatial and a channel mask.

    The block consumes and returns a 4-tuple ``(x, norm_1, norm_2, flops)``,
    so the mask norms and per-block flops accumulate across stacked blocks.
    """
    expansion = 1
    def __init__(self, inplanes, planes, h, w, eta=4,
                 stride=1, downsample=None, **kwargs):
        """Create the block.

        Args:
            inplanes: number of input channels.
            planes: number of output channels of both convolutions.
            h, w: input spatial size; the output size is derived from it with
                ``conv2d_out_dim`` (3x3 kernel, padding 1, given stride).
            eta: passed twice to ``Mask_s``; ``eta * eta`` is the factor used
                in the flops estimate.
            stride: stride of the first convolution.
            downsample: optional module applied to the input for the shortcut.
            **kwargs: forwarded to ``Mask_s`` and ``Mask_c``.
        """
        super(BasicBlock, self).__init__()
        # gating modules
        self.height = conv2d_out_dim(h, kernel_size=3, stride=stride, padding=1)
        self.width = conv2d_out_dim(w, kernel_size=3, stride=stride, padding=1)
        self.mask_s = Mask_s(self.height, self.width, inplanes, eta, eta, **kwargs)
        self.mask_c = Mask_c(inplanes, planes, **kwargs)
        # Nearest-neighbour resize of the spatial mask to the output resolution.
        self.upsample = nn.Upsample(size=(self.height, self.width), mode='nearest')
        # conv 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv 2
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # misc
        self.downsample = downsample
        self.inplanes, self.planes = inplanes, planes
        self.b = eta * eta
        # b_reduce is stored but not read anywhere in this class.
        self.b_reduce = (eta-1) * (eta-1)
        # Dense cost of each conv: 9 kernel taps x output positions x
        # output channels x input channels.
        flops_conv1_full = torch.Tensor([9 * self.height * self.width * planes * inplanes])
        flops_conv2_full = torch.Tensor([9 * self.height * self.width * planes * planes])
        # downsample flops
        # No kernel-size factor here -- presumably the shortcut is a 1x1
        # conv; verify against the caller that builds `downsample`.
        self.flops_downsample = torch.Tensor([self.height*self.width*planes*inplanes]
                                             )if downsample is not None else torch.Tensor([0])
        # full flops
        self.flops_full = flops_conv1_full + flops_conv2_full + self.flops_downsample
        # mask flops
        flops_mks = self.mask_s.get_flops()
        flops_mkc = self.mask_c.get_flops()
        self.flops_mask = torch.Tensor([flops_mks + flops_mkc])

    def forward(self, input):
        """Apply the gated block.

        Args:
            input: tuple ``(x, norm_1, norm_2, flops)``.

        Returns:
            ``(out, norm_1, norm_2, flops)`` where one row for this block has
            been concatenated onto each of ``norm_1``, ``norm_2`` and
            ``flops``.
        """
        x, norm_1, norm_2, flops = input
        residual = x
        # spatial mask
        mask_s_m, norm_s, norm_s_t = self.mask_s(x) # [N, 1, h, w]
        mask_s = self.upsample(mask_s_m) # [N, 1, H, W]
        # conv 1
        mask_c, norm_c, norm_c_t = self.mask_c(x) # [N, C_out, 1, 1]
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        # In eval mode the intermediate activation is gated by both masks;
        # in training only the channel mask is applied at this point.
        if not self.training:
            out = out * mask_c * mask_s
        else:
            out = out * mask_c
        # conv 2
        out = self.conv2(out)
        out = self.bn2(out)
        # The spatial mask gates the residual branch before the shortcut is added.
        out = out * mask_s
        # identity
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        # flops
        flops_blk = self.get_flops(mask_s_m, mask_s, mask_c)
        flops = torch.cat((flops, flops_blk.unsqueeze(0)))
        # norm
        norm_1 = torch.cat((norm_1, torch.cat((norm_s, norm_s_t)).unsqueeze(0)))
        norm_2 = torch.cat((norm_2, torch.cat((norm_c, norm_c_t)).unsqueeze(0)))
        return (out, norm_1, norm_2, flops)

    def get_flops(self, mask_s, mask_s_up, mask_c):
        """Estimate the flops of this block under the given masks.

        Args:
            mask_s: spatial mask before upsampling.
            mask_s_up: upsampled spatial mask; not used in the computation.
            mask_c: channel mask.

        Returns:
            1-D tensor: the per-sample masked flops, followed by the mask
            overhead (``flops_mask``) and the dense flops (``flops_full``).
        """
        # Per-sample mask sums (assumes 4-D masks, reduced over dims 1-3).
        s_sum = mask_s.sum((1,2,3))
        c_sum = mask_c.sum((1,2,3))
        # conv1
        flops_conv1 = 9 * self.b * s_sum * c_sum * self.inplanes
        # conv2
        flops_conv2 = 9 * self.b * s_sum * self.planes * c_sum
        # total
        flops = flops_conv1 + flops_conv2 + self.flops_downsample.to(flops_conv1.device)
        return torch.cat((flops, self.flops_mask.to(flops.device), self.flops_full.to(flops.device)))
108 |
109 |
class ResNetCifar10(nn.Module):
    """CIFAR-style ResNet built from gated ``BasicBlock`` modules.

    Besides the classifier logits, ``forward`` gathers the per-block mask
    norms and FLOPs statistics and feeds them to the criterion installed with
    ``set_criterion``.
    """

    def __init__(self, depth, num_classes=10, h=32, w=32, **kwargs):
        """Build the stem, three residual stages, pooling and classifier.

        Args:
            depth: network depth; each stage gets ``(depth - 2) // 6`` blocks.
            num_classes: size of the final linear layer's output.
            h, w: spatial size of the input images.
            **kwargs: forwarded to every ``BasicBlock``.
        """
        super(ResNetCifar10, self).__init__()
        self.height, self.width = h, w
        # Three stages of n two-conv blocks each.
        n = (depth - 2) // 6
        block = BasicBlock
        # norm
        self._norm_layer = nn.BatchNorm2d
        # conv1
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # residual blocks; h, w are updated to each stage's output size
        self.layer1, h, w = self._make_layer(block, 16, n, h, w, 4, **kwargs)
        self.layer2, h, w = self._make_layer(block, 32, n, h, w, 2, stride=2, **kwargs)
        self.layer3, h, w = self._make_layer(block, 64, n, h, w, 2, stride=2, **kwargs)
        # Pool over the whole final feature map. For the default 32x32 input
        # this is the same 8x8 window as before, but it now follows h and w
        # instead of being hard-coded, so other input sizes reach the fc layer
        # with the expected 64 * expansion features.
        self.avgpool = nn.AvgPool2d((h, w))
        self.fc = nn.Linear(64 * block.expansion, num_classes)
        # flops of the stem conv and the classifier
        self.flops_conv1 = torch.Tensor([9 * self.height * self.width * 16 * 3])
        self.flops_fc = torch.Tensor([64 * block.expansion * num_classes])
        # criterion is attached later through set_criterion()
        self.criterion = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None and m.bias is not None:
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, h, w, tile, stride=1, **kwargs):
        """Stack ``blocks`` blocks and return ``(stage, out_h, out_w)``.

        Only the first block receives ``stride`` and the optional downsample
        shortcut; ``tile`` is passed to each block as its ``eta``.
        """
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        layers = [block(self.inplanes, planes, h, w, tile,
                        stride, downsample, **kwargs)]
        # Spatial size seen by the remaining blocks of the stage.
        h = conv2d_out_dim(h, kernel_size=1, stride=stride, padding=0)
        w = conv2d_out_dim(w, kernel_size=1, stride=stride, padding=0)
        self.inplanes = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, h, w, tile, **kwargs))
        return nn.Sequential(*layers), h, w

    def forward(self, x, label, den_target, lbda, gamma, p):
        """Classify ``x`` and compute the losses through ``self.criterion``.

        Returns a dict with keys ``closs``, ``rloss``, ``bloss``, ``out``,
        ``flops_real``, ``flops_mask`` and ``flops_ori``.
        """
        batch_num, _, _, _ = x.shape
        # conv1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # Seed rows for the statistics; every block appends one row. The norm
        # rows hold one value per sample plus one trailing total, the flops
        # rows one value per sample plus mask and full FLOPs.
        norm1 = torch.zeros(1, batch_num + 1).to(x.device)
        norm2 = torch.zeros(1, batch_num + 1).to(x.device)
        flops = torch.zeros(1, batch_num + 2).to(x.device)
        x = self.layer1((x, norm1, norm2, flops))
        x = self.layer2(x)
        x, norm1, norm2, flops = self.layer3(x)
        # fc layer
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # flops: drop the seed row, then split per-sample / mask / full columns
        flops_real = [flops[1:, 0:batch_num].permute(1, 0).contiguous(),
                      self.flops_conv1.to(x.device), self.flops_fc.to(x.device)]
        flops_mask = flops[1:, -2].unsqueeze(0)
        flops_ori = flops[1:, -1].unsqueeze(0)
        # norm: same layout, the last column is the per-block total
        norm_s = norm1[1:, 0:batch_num].permute(1, 0).contiguous()
        norm_c = norm2[1:, 0:batch_num].permute(1, 0).contiguous()
        norm_s_t = norm1[1:, -1].unsqueeze(0)
        norm_c_t = norm2[1:, -1].unsqueeze(0)
        # get outputs
        outputs = {}
        outputs["closs"], outputs["rloss"], outputs["bloss"] = self.get_loss(
            x, label, batch_num, den_target, lbda, gamma, p,
            norm_s, norm_c, norm_s_t, norm_c_t,
            flops_real, flops_mask, flops_ori)
        outputs["out"] = x
        outputs["flops_real"] = flops_real
        outputs["flops_mask"] = flops_mask
        outputs["flops_ori"] = flops_ori
        return outputs

    def set_criterion(self, criterion):
        """Install the callable used by ``get_loss``."""
        self.criterion = criterion
        return

    def get_loss(self, output, label, batch_size, den_target, lbda, gamma, p,
                 mask_norm_s, mask_norm_c, norm_s_t, norm_c_t,
                 flops_real, flops_mask, flops_ori):
        """Call ``self.criterion`` and return ``(closs, rloss, bloss)``.

        The arguments are reordered into the positional layout the criterion
        is called with: output, label, the three FLOPs values, batch size,
        ``den_target``, ``lbda``, the four norm values, ``gamma`` and ``p``.
        """
        closs, rloss, bloss = self.criterion(
            output, label, flops_real, flops_mask, flops_ori, batch_size,
            den_target, lbda, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t,
            gamma, p)
        return closs, rloss, bloss

    def record_flops(self, flops_conv, flops_mask, flops_ori, flops_conv1, flops_fc):
        """Log a per-block FLOPs table.

        ``flops_conv``, ``flops_mask`` and ``flops_ori`` are indexed once per
        ``BasicBlock`` in ``named_modules()`` order; the entry after the last
        block is read for the ``Total`` row.
        """
        i = 0
        blank = [' ' for _ in range(5)]
        table = PrettyTable(['Layer', 'Conv FLOPs', 'Conv %', 'Mask FLOPs',
                             'Total FLOPs', 'Total %', 'Original FLOPs'])
        table.add_row(['layer0', '{flops:.2f}K'.format(flops=flops_conv1/1024)] + blank)
        for name, m in self.named_modules():
            if isinstance(m, BasicBlock):
                table.add_row([
                    name,
                    '{flops:.2f}K'.format(flops=flops_conv[i]/1024),
                    '{per_f:.2f}%'.format(per_f=flops_conv[i]/flops_ori[i]*100),
                    '{mask:.2f}K'.format(mask=flops_mask[i]/1024),
                    '{total:.2f}K'.format(total=(flops_conv[i]+flops_mask[i])/1024),
                    '{per_t:.2f}%'.format(
                        per_t=(flops_conv[i]+flops_mask[i])/flops_ori[i]*100),
                    '{ori:.2f}K'.format(ori=flops_ori[i]/1024),
                ])
                i += 1
        table.add_row(['fc', '{flops:.2f}K'.format(flops=flops_fc/1024)] + blank)
        # i now points at the entry following the last block.
        table.add_row([
            'Total',
            '{flops:.2f}K'.format(flops=(flops_conv[i]+flops_conv1+flops_fc)/1024),
            '{per_f:.2f}%'.format(
                per_f=(flops_conv[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100),
            '{mask:.2f}K'.format(mask=flops_mask[i]/1024),
            '{total:.2f}K'.format(
                total=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/1024),
            '{per_t:.2f}%'.format(
                per_t=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100),
            '{ori:.2f}K'.format(ori=(flops_ori[i]+flops_conv1+flops_fc)/1024),
        ])
        logging.info('\n{}'.format(table))
233 |
234 |
def resdg20_cifar10(**kwargs):
    """Build a depth-20 ``ResNetCifar10``; keyword arguments are forwarded."""
    depth = 20
    return ResNetCifar10(depth, **kwargs)
240 |
241 |
def resdg32_cifar10(**kwargs):
    """Build a depth-32 ``ResNetCifar10``; keyword arguments are forwarded."""
    depth = 32
    return ResNetCifar10(depth, **kwargs)
247 |
248 |
def resdg56_cifar10(**kwargs):
    """Build a depth-56 ``ResNetCifar10``; keyword arguments are forwarded."""
    depth = 56
    return ResNetCifar10(depth, **kwargs)
254 |
255 |
def resdg110_cifar10(**kwargs):
    """Build a depth-110 ``ResNetCifar10``; keyword arguments are forwarded."""
    depth = 110
    return ResNetCifar10(depth, **kwargs)
261 |
--------------------------------------------------------------------------------
/models/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | from .resdg import *
2 |
--------------------------------------------------------------------------------
/models/imagenet/resdg.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 | import torch
4 | import torch.nn as nn
5 | from prettytable import PrettyTable
6 | from ..mask import Mask_s, Mask_c
7 |
8 | __all__ = ['resdg18', 'resdg34', 'resdg50']
9 |
10 |
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` with the given stride."""
    options = dict(kernel_size=1, stride=stride, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)
14 |
15 |
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
20 |
21 |
def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False):
    """Compute the size of one output axis of a strided, dilated window op.

    The un-rounded value is
    ``(dim + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1``;
    ``ceil_mode`` selects rounding up instead of down.
    """
    raw = (dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
    if ceil_mode:
        return int(math.ceil(raw))
    return int(math.floor(raw))
27 |
28 |
class BasicBlock(nn.Module):
    """Two-conv residual block with a spatial gate and a channel gate.

    ``forward`` takes and returns ``(x, norm_1, norm_2, flops)`` so a stage of
    blocks can run inside ``nn.Sequential`` while every block appends one row
    to each statistics tensor.
    """
    expansion = 1

    def __init__(self, inplanes, planes, h, w, eta=8, stride=1,
                 downsample=None, groups=1, base_width=64, dilation=1,
                 norm_layer=None, **kwargs):
        """Create gates, convs, norms and the static FLOPs counts.

        Args:
            inplanes: channels of the incoming feature map.
            planes: channels produced by both convs.
            h, w: spatial size of the incoming feature map.
            eta: handed to ``Mask_s`` as both block sizes.
            stride: stride of the first conv.
            downsample: optional module applied to the input on the shortcut.
            groups, base_width: must be 1 and 64 respectively.
            dilation: must not exceed 1.
            norm_layer: normalisation constructor; ``nn.BatchNorm2d`` if None.
            **kwargs: forwarded to ``Mask_s`` and ``Mask_c``.

        Raises:
            ValueError: if ``groups != 1`` or ``base_width != 64``.
            NotImplementedError: if ``dilation > 1``.
        """
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Spatial size after the first (possibly strided) conv.
        out_h = conv2d_out_dim(h, kernel_size=3, stride=stride, padding=1)
        out_w = conv2d_out_dim(w, kernel_size=3, stride=stride, padding=1)
        self.height, self.width = out_h, out_w
        # Gating modules.
        self.mask_s = Mask_s(out_h, out_w, inplanes, eta, eta, **kwargs)
        self.mask_c = Mask_c(inplanes, planes, **kwargs)
        self.upsample = nn.Upsample(size=(out_h, out_w), mode='nearest')
        # First conv / norm / ReLU.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        # Second conv / norm.
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        # Shortcut and bookkeeping.
        self.downsample = downsample
        self.inplanes, self.planes = inplanes, planes
        # Unmasked cost of the block's convs.
        pixels = out_h * out_w
        conv1_full = torch.Tensor([9 * pixels * planes * inplanes])
        conv2_full = torch.Tensor([9 * pixels * planes * planes])
        if downsample is not None:
            self.flops_downsample = torch.Tensor([pixels * planes * inplanes])
        else:
            self.flops_downsample = torch.Tensor([0])
        self.flops_full = conv1_full + conv2_full + self.flops_downsample
        # Cost of producing the two masks.
        self.flops_mask = torch.Tensor(
            [self.mask_s.get_flops() + self.mask_c.get_flops()])

    def forward(self, input):
        """Apply the gated block and extend the statistics tensors by one row."""
        x, norm_1, norm_2, flops = input
        gate_s_coarse, norm_s, norm_s_t = self.mask_s(x)  # [N, 1, h, w]
        gate_c, norm_c, norm_c_t = self.mask_c(x)  # [N, C_out, 1, 1]
        gate_s = self.upsample(gate_s_coarse)  # [N, 1, H, W]
        # conv 1: channel gate always, spatial gate only outside training.
        out = self.relu(self.bn1(self.conv1(x)))
        out = out * gate_c
        if not self.training:
            out = out * gate_s
        # conv 2: spatial gate in both modes.
        out = self.bn2(self.conv2(out)) * gate_s
        # Shortcut connection.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        out = self.relu(out)
        # Statistics rows.
        spatial_row = torch.cat((norm_s, norm_s_t)).unsqueeze(0)
        channel_row = torch.cat((norm_c, norm_c_t)).unsqueeze(0)
        norm_1 = torch.cat((norm_1, spatial_row))
        norm_2 = torch.cat((norm_2, channel_row))
        block_flops = self.get_flops(gate_s, gate_c)
        flops = torch.cat((flops, block_flops.unsqueeze(0)))
        return (out, norm_1, norm_2, flops)

    def get_flops(self, mask_s_up, mask_c):
        """Return per-sample masked FLOPs followed by mask and full FLOPs.

        ``mask_s_up`` is the upsampled spatial mask and ``mask_c`` the channel
        mask; both are summed over every non-batch dimension.
        """
        active_pixels = mask_s_up.sum((1, 2, 3))
        active_channels = mask_c.sum((1, 2, 3))
        shared = 9 * active_pixels * active_channels
        conv1_cost = shared * self.inplanes
        conv2_cost = shared * self.planes
        device = conv1_cost.device
        total = conv1_cost + conv2_cost + self.flops_downsample.to(device)
        return torch.cat(
            (total, self.flops_mask.to(device), self.flops_full.to(device)))
105 |
106 |
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block with one spatial and two channel gates.

    ``forward`` takes and returns ``(x, norm_1, norm_2, flops)``; it appends
    one spatial-norm row, two channel-norm rows and one FLOPs row.
    """
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, h, w, eta=8, stride=1,
                 downsample=None, groups=1, base_width=64, dilation=1,
                 norm_layer=None, **kwargs):
        """Create gates, convs, norms and the static FLOPs counts.

        Args:
            inplanes: channels of the incoming feature map.
            planes: base width; the block outputs ``planes * expansion``.
            h, w: spatial size of the incoming feature map.
            eta: handed to ``Mask_s`` as both block sizes.
            stride, groups, dilation: settings of the 3x3 conv.
            downsample: optional module applied to the input on the shortcut.
            base_width: scales the inner width relative to 64.
            norm_layer: normalisation constructor; ``nn.BatchNorm2d`` if None.
            **kwargs: forwarded to ``Mask_s`` and ``Mask_c``.
        """
        super(Bottleneck, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion
        # Resolution before (1) and after (2) the 3x3 conv.
        self.height_1, self.width_1 = h, w
        self.height_2 = conv2d_out_dim(
            h, kernel_size=3, padding=dilation, stride=stride, dilation=dilation)
        self.width_2 = conv2d_out_dim(
            w, kernel_size=3, padding=dilation, stride=stride, dilation=dilation)
        # Spatial gate, produced once and resized to both resolutions.
        self.mask_s = Mask_s(self.height_2, self.width_2, inplanes, eta, eta, **kwargs)
        self.upsample_1 = nn.Upsample(size=(self.height_1, self.width_1), mode='nearest')
        self.upsample_2 = nn.Upsample(size=(self.height_2, self.width_2), mode='nearest')
        # conv 1 and its channel gate
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.mask_c1 = Mask_c(inplanes, width, **kwargs)
        # conv 2 and its channel gate
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.mask_c2 = Mask_c(width, width, **kwargs)
        # conv 3
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes)
        # misc
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        # Note: self.width holds the inner channel width, not a spatial size.
        self.inplanes, self.width, self.planes = inplanes, width, out_planes
        # Unmasked cost of each conv.
        pixels_in = self.height_1 * self.width_1
        pixels_out = self.height_2 * self.width_2
        conv1_full = torch.Tensor([pixels_in * width * inplanes])
        conv2_full = torch.Tensor([9 * pixels_out * width * width])
        conv3_full = torch.Tensor([pixels_out * width * out_planes])
        if self.downsample is not None:
            self.flops_downsample = torch.Tensor([pixels_out * out_planes * inplanes])
        else:
            self.flops_downsample = torch.Tensor([0])
        self.flops_full = conv1_full + conv2_full + conv3_full + self.flops_downsample
        # Cost of producing the three masks.
        self.flops_mask = torch.Tensor([
            self.mask_s.get_flops()
            + self.mask_c1.get_flops()
            + self.mask_c2.get_flops()
        ])

    def forward(self, input):
        """Apply the gated bottleneck and extend the statistics tensors."""
        x, norm_1, norm_2, flops = input
        # Gates computed from the block input.
        gate_s_coarse, norm_s, norm_s_t = self.mask_s(x)  # [N, 1, h, w]
        gate_c1, norm_c1, norm_c1_t = self.mask_c1(x)
        gate_s_in = self.upsample_1(gate_s_coarse)  # [N, 1, H1, W1]
        gate_s_out = self.upsample_2(gate_s_coarse)  # [N, 1, H2, W2]
        # conv 1: channel gate always, spatial gate only outside training.
        out = self.relu(self.bn1(self.conv1(x)))
        out = out * gate_c1
        if not self.training:
            out = out * gate_s_in
        # conv 2: its channel gate is computed from the gated conv-1 output.
        gate_c2, norm_c2, norm_c2_t = self.mask_c2(out)
        out = self.relu(self.bn2(self.conv2(out)))
        out = out * gate_c2
        if not self.training:
            out = out * gate_s_out
        # conv 3: spatial gate in both modes.
        out = self.bn3(self.conv3(out)) * gate_s_out
        # Shortcut connection.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        out = self.relu(out)
        # One spatial row and two channel rows per block.
        norm_1 = torch.cat((norm_1, torch.cat((norm_s, norm_s_t)).unsqueeze(0)))
        for c_norm, c_total in ((norm_c1, norm_c1_t), (norm_c2, norm_c2_t)):
            norm_2 = torch.cat((norm_2, torch.cat((c_norm, c_total)).unsqueeze(0)))
        # flops
        block_flops = self.get_flops(gate_s_out, gate_s_in, gate_c1, gate_c2)
        flops = torch.cat((flops, block_flops.unsqueeze(0)))
        return (out, norm_1, norm_2, flops)

    def get_flops(self, mask_s, mask_s1, mask_c1, mask_c2):
        """Return per-sample masked FLOPs followed by mask and full FLOPs.

        ``mask_s1`` / ``mask_s`` are the spatial masks at the resolution of
        conv 1 and of conv 2/3; ``mask_c1`` / ``mask_c2`` are the channel
        masks of conv 1 and conv 2.
        """
        pixels_out = mask_s.sum((1, 2, 3))
        pixels_in = mask_s1.sum((1, 2, 3))
        channels_1 = mask_c1.sum((1, 2, 3))
        channels_2 = mask_c2.sum((1, 2, 3))
        conv1_cost = pixels_in * channels_1 * self.inplanes
        conv2_cost = 9 * pixels_out * channels_2 * channels_1
        conv3_cost = pixels_out * self.planes * channels_2
        device = conv1_cost.device
        total = conv1_cost + conv2_cost + conv3_cost + self.flops_downsample.to(device)
        return torch.cat(
            (total, self.flops_mask.to(device), self.flops_full.to(device)))
200 |
201 |
class ResDG(nn.Module):
    """ResNet-style network assembled from gated residual blocks.

    ``forward`` returns the logits together with the losses computed by the
    criterion installed through ``set_criterion`` and the FLOPs statistics
    collected from every block.
    """

    def __init__(self, block, layers, h=224, w=224, num_classes=1000,
                 zero_init_residual=False, groups=1, width_per_group=64,
                 replace_stride_with_dilation=None, norm_layer=None, **kwargs):
        """Build the stem, four residual stages, pooling and classifier.

        Args:
            block: block class (``BasicBlock`` or ``Bottleneck``).
            layers: number of blocks in each of the four stages.
            h, w: spatial size of the input images.
            num_classes: size of the final linear layer's output.
            zero_init_residual: zero the last norm weight of every block.
            groups, width_per_group: forwarded to every block.
            replace_stride_with_dilation: ``None`` or three booleans, one for
                each of stages 2-4, selecting dilation instead of stride.
            norm_layer: normalisation constructor; ``nn.BatchNorm2d`` if None.
            **kwargs: forwarded to every block.

        Raises:
            ValueError: if ``replace_stride_with_dilation`` is given and does
                not have exactly three elements.
        """
        super(ResDG, self).__init__()
        self.height, self.width = h, w
        # norm layer
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # conv1; h and w follow the feature-map size through stem and pooling
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        h = conv2d_out_dim(h, kernel_size=7, stride=2, padding=3)
        w = conv2d_out_dim(w, kernel_size=7, stride=2, padding=3)
        self.flops_conv1 = torch.Tensor([49 * h * w * self.inplanes * 3])
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        h = conv2d_out_dim(h, kernel_size=3, stride=2, padding=1)
        w = conv2d_out_dim(w, kernel_size=3, stride=2, padding=1)
        # residual blocks; the sixth positional argument is each block's eta
        self.layer1, h, w = self._make_layer(block, 64, layers[0], h, w, 8, **kwargs)
        self.layer2, h, w = self._make_layer(block, 128, layers[1], h, w, 4, stride=2,
                                             dilate=replace_stride_with_dilation[0], **kwargs)
        self.layer3, h, w = self._make_layer(block, 256, layers[2], h, w, 2, stride=2,
                                             dilate=replace_stride_with_dilation[1], **kwargs)
        self.layer4, h, w = self._make_layer(block, 512, layers[3], h, w, 1, stride=2,
                                             dilate=replace_stride_with_dilation[2], **kwargs)
        # fc layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        self.flops_fc = torch.Tensor([512 * block.expansion * num_classes])
        # criterion is attached later through set_criterion()
        self.criterion = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, h, w, tile, stride=1, dilate=False, **kwargs):
        """Stack ``blocks`` blocks and return ``(stage, out_h, out_w)``.

        Only the first block receives ``stride`` and the optional downsample
        shortcut. With ``dilate`` set, the stride is folded into
        ``self.dilation`` and the stage runs at stride 1. ``tile`` is passed
        to each block as its ``eta``.
        """
        norm_layer = self._norm_layer
        previous_dilation = self.dilation
        downsample = None
        if dilate:
            self.dilation *= stride
            stride = 1
        out_channels = planes * block.expansion
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, out_channels, stride),
                norm_layer(out_channels),
            )
        # The first block still uses the dilation in effect before this stage.
        layers = [block(self.inplanes, planes, h, w, tile, stride, downsample,
                        self.groups, self.base_width, previous_dilation,
                        norm_layer, **kwargs)]
        # Spatial size seen by the remaining blocks of the stage.
        h = conv2d_out_dim(h, kernel_size=1, stride=stride, padding=0)
        w = conv2d_out_dim(w, kernel_size=1, stride=stride, padding=0)
        self.inplanes = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, h, w, tile, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, **kwargs))
        return nn.Sequential(*layers), h, w

    def forward(self, x, label, den_target, lbda, gamma, p):
        """Classify ``x`` and compute the losses through ``self.criterion``.

        Returns a dict with keys ``closs``, ``rloss``, ``bloss``, ``out``,
        ``flops_real``, ``flops_mask`` and ``flops_ori``.
        """
        batch_num, _, _, _ = x.shape
        # conv1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        # Seed rows for the statistics; every block appends to them. The norm
        # rows hold one value per sample plus one trailing total, the flops
        # rows one value per sample plus mask and full FLOPs.
        norm1 = torch.zeros(1, batch_num + 1).to(x.device)
        norm2 = torch.zeros(1, batch_num + 1).to(x.device)
        flops = torch.zeros(1, batch_num + 2).to(x.device)
        x = self.layer1((x, norm1, norm2, flops))
        x = self.layer2(x)
        x = self.layer3(x)
        x, norm1, norm2, flops = self.layer4(x)
        # fc layer
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        # norm and flops: drop the seed row, then split the columns
        norm_s = norm1[1:, 0:batch_num].permute(1, 0).contiguous()
        norm_c = norm2[1:, 0:batch_num].permute(1, 0).contiguous()
        norm_s_t = norm1[1:, -1].unsqueeze(0)
        norm_c_t = norm2[1:, -1].unsqueeze(0)
        flops_real = [flops[1:, 0:batch_num].permute(1, 0).contiguous(),
                      self.flops_conv1.to(x.device), self.flops_fc.to(x.device)]
        flops_mask = flops[1:, -2].unsqueeze(0)
        flops_ori = flops[1:, -1].unsqueeze(0)
        # get outputs
        outputs = {}
        outputs["closs"], outputs["rloss"], outputs["bloss"] = self.get_loss(
            x, label, batch_num, den_target, lbda, gamma, p,
            norm_s, norm_c, norm_s_t, norm_c_t,
            flops_real, flops_mask, flops_ori)
        outputs["out"] = x
        outputs["flops_real"] = flops_real
        outputs["flops_mask"] = flops_mask
        outputs["flops_ori"] = flops_ori
        return outputs

    def set_criterion(self, criterion):
        """Install the callable used by ``get_loss``."""
        self.criterion = criterion
        return

    def get_loss(self, output, label, batch_size, den_target, lbda, gamma, p,
                 mask_norm_s, mask_norm_c, norm_s_t, norm_c_t,
                 flops_real, flops_mask, flops_ori):
        """Call ``self.criterion`` and return ``(closs, rloss, bloss)``.

        The arguments are reordered into the positional layout the criterion
        is called with: output, label, the three FLOPs values, batch size,
        ``den_target``, ``lbda``, the four norm values, ``gamma`` and ``p``.
        """
        closs, rloss, bloss = self.criterion(
            output, label, flops_real, flops_mask, flops_ori, batch_size,
            den_target, lbda, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t,
            gamma, p)
        return closs, rloss, bloss

    def record_flops(self, flops_conv, flops_mask, flops_ori, flops_conv1, flops_fc):
        """Log a per-block FLOPs table.

        ``flops_conv``, ``flops_mask`` and ``flops_ori`` are indexed once per
        ``BasicBlock``/``Bottleneck`` in ``named_modules()`` order; the entry
        after the last block is read for the ``Total`` row.
        """
        i = 0
        blank = [' ' for _ in range(5)]
        table = PrettyTable(['Layer', 'Conv FLOPs', 'Conv %', 'Mask FLOPs',
                             'Total FLOPs', 'Total %', 'Original FLOPs'])
        table.add_row(['layer0', '{flops:.2f}K'.format(flops=flops_conv1/1024)] + blank)
        for name, m in self.named_modules():
            if isinstance(m, (BasicBlock, Bottleneck)):
                table.add_row([
                    name,
                    '{flops:.2f}K'.format(flops=flops_conv[i]/1024),
                    '{per_f:.2f}%'.format(per_f=flops_conv[i]/flops_ori[i]*100),
                    '{mask:.2f}K'.format(mask=flops_mask[i]/1024),
                    '{total:.2f}K'.format(total=(flops_conv[i]+flops_mask[i])/1024),
                    '{per_t:.2f}%'.format(
                        per_t=(flops_conv[i]+flops_mask[i])/flops_ori[i]*100),
                    '{ori:.2f}K'.format(ori=flops_ori[i]/1024),
                ])
                i += 1
        table.add_row(['fc', '{flops:.2f}K'.format(flops=flops_fc/1024)] + blank)
        # i now points at the entry following the last block.
        table.add_row([
            'Total',
            '{flops:.2f}K'.format(flops=(flops_conv[i]+flops_conv1+flops_fc)/1024),
            '{per_f:.2f}%'.format(
                per_f=(flops_conv[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100),
            '{mask:.2f}K'.format(mask=flops_mask[i]/1024),
            '{total:.2f}K'.format(
                total=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/1024),
            '{per_t:.2f}%'.format(
                per_t=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100),
            '{ori:.2f}K'.format(ori=(flops_ori[i]+flops_conv1+flops_fc)/1024),
        ])
        logging.info('\n{}'.format(table))
364 |
365 |
def _resdg(arch, block, layers, **kwargs):
    """Instantiate ``ResDG`` from a block class and per-stage block counts.

    ``arch`` is not used by this function; ``kwargs`` are forwarded to the
    ``ResDG`` constructor.
    """
    return ResDG(block, layers, **kwargs)
369 |
370 |
def resdg18(**kwargs):
    """Build the 18-layer variant: ``BasicBlock`` with stages ``[2, 2, 2, 2]``.

    The layout follows "Deep Residual Learning for Image Recognition";
    keyword arguments are forwarded unchanged through ``_resdg``.
    """
    stage_depths = [2, 2, 2, 2]
    return _resdg('resdg18', BasicBlock, stage_depths, **kwargs)
376 |
377 |
def resdg34(**kwargs):
    """Build the 34-layer variant: ``BasicBlock`` with stages ``[3, 4, 6, 3]``.

    The layout follows "Deep Residual Learning for Image Recognition";
    keyword arguments are forwarded unchanged through ``_resdg``.
    """
    stage_depths = [3, 4, 6, 3]
    return _resdg('resdg34', BasicBlock, stage_depths, **kwargs)
383 |
384 |
def resdg50(**kwargs):
    """Build the 50-layer variant: ``Bottleneck`` with stages ``[3, 4, 6, 3]``.

    The layout follows "Deep Residual Learning for Image Recognition";
    keyword arguments are forwarded unchanged through ``_resdg``.
    """
    stage_depths = [3, 4, 6, 3]
    return _resdg('resdg50', Bottleneck, stage_depths, **kwargs)
390 |
--------------------------------------------------------------------------------
/models/mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 |
7 |
class GumbelSoftmax(nn.Module):
    '''
    Binary gate: hard threshold on the logits in eval mode, noisy sigmoid
    gate with a straight-through hard output in training mode.
    '''
    def __init__(self, eps=1):
        super(GumbelSoftmax, self).__init__()
        # eps divides the noisy logits before the sigmoid.
        self.eps = eps
        self.sigmoid = nn.Sigmoid()

    def gumbel_sample(self, template_tensor, eps=1e-8):
        """Return noise shaped like ``template_tensor``.

        Draws ``u`` uniformly in a copy of the template and returns
        ``log(u + eps) - log(1 - u + eps)``; ``eps`` keeps both logs finite.
        """
        u = template_tensor.clone().uniform_()
        return torch.log(u + eps) - torch.log(1 - u + eps)

    def gumbel_softmax(self, logits):
        """Return ``(soft_samples, noisy_logits)`` for the given logits."""
        # The noise is sampled from detached data, so it carries no gradient.
        noise = self.gumbel_sample(logits.data)
        noisy_logits = logits + Variable(noise)
        return self.sigmoid(noisy_logits / self.eps), noisy_logits

    def forward(self, logits):
        if self.training:
            soft, _ = self.gumbel_softmax(logits)
            hard = (soft >= 0.5).float()
            # Forward value is the hard gate; gradients flow through `soft`.
            return (hard - soft).detach() + soft
        # Eval mode: deterministic threshold at zero.
        return (logits >= 0).float()
37 |
38 |
class Mask_s(nn.Module):
    '''
    Spatial attention mask: pools the input to a coarse grid, scores every
    cell with a 3x3 conv and gates the scores with ``GumbelSoftmax``.
    '''
    def __init__(self, h, w, planes, block_w, block_h, eps=0.66667,
                 bias=-1, **kwargs):
        """Set up the coarse grid, the scoring conv and the gate.

        Args:
            h, w: spatial size used to derive the mask grid.
            planes: input channels of the scoring conv.
            block_w, block_h: cell size; the grid is ``ceil(h / block_h)`` by
                ``ceil(w / block_w)``.
            eps: forwarded to ``GumbelSoftmax``.
            bias: the conv gets a bias, initialised to this value, only when
                it is non-negative.
        """
        super(Mask_s, self).__init__()
        # Parameter
        self.width, self.height, self.channel = w, h, planes
        self.mask_h = int(np.ceil(h / block_h))
        self.mask_w = int(np.ceil(w / block_w))
        # Number of mask cells, returned as the norm's upper bound.
        self.eleNum_s = torch.Tensor([self.mask_h*self.mask_w])
        # spatial attention
        use_bias = bias >= 0
        self.atten_s = nn.Conv2d(planes, 1, kernel_size=3, stride=1,
                                 bias=use_bias, padding=1)
        if use_bias:
            nn.init.constant_(self.atten_s.bias, bias)
        # Gate
        self.gate_s = GumbelSoftmax(eps=eps)
        # Norm: L1 norm over every non-batch dimension
        self.norm = lambda x: torch.norm(x, p=1, dim=(1,2,3))

    def forward(self, x):
        """Return ``(mask, per-sample L1 norm of the mask, cell count)``."""
        # Unpacking requires x to be 4-D; the individual sizes are not used.
        _, _, _, _ = x.size()
        # Pooling down to the mask grid
        pooled = F.adaptive_avg_pool2d(input=x, output_size=(self.mask_h, self.mask_w))
        # spatial attention followed by the gate
        mask_s = self.gate_s(self.atten_s(pooled))  # [N, 1, h, w]
        return mask_s, self.norm(mask_s), self.eleNum_s.to(x.device)

    def get_flops(self):
        """Return the cost of the 3x3 scoring conv over the mask grid."""
        return self.mask_h * self.mask_w * self.channel * 9
75 |
76 |
class Mask_c(nn.Module):
    '''
    Channel attention mask: global-average-pools the input, maps it through a
    two-layer 1x1-conv bottleneck and gates the result with ``GumbelSoftmax``.
    '''
    def __init__(self, inplanes, outplanes, fc_reduction=4, eps=0.66667, bias=-1, **kwargs):
        """Set up the pooling, the bottleneck and the gate.

        Args:
            inplanes: channels of the input feature map.
            outplanes: number of gate values produced per sample.
            fc_reduction: the bottleneck has ``inplanes // fc_reduction``
                channels.
            eps: forwarded to ``GumbelSoftmax``.
            bias: the last conv gets a bias, initialised to this value, only
                when it is non-negative.
        """
        super(Mask_c, self).__init__()
        # Parameter
        self.bottleneck = inplanes // fc_reduction
        self.inplanes, self.outplanes = inplanes, outplanes
        # Number of gated channels, returned as the norm's upper bound.
        self.eleNum_c = torch.Tensor([outplanes])
        # channel attention
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        use_bias = bias >= 0
        squeeze = nn.Conv2d(inplanes, self.bottleneck, kernel_size=1, stride=1, bias=False)
        squeeze_bn = nn.BatchNorm2d(self.bottleneck)
        expand = nn.Conv2d(self.bottleneck, outplanes, kernel_size=1, stride=1, bias=use_bias)
        self.atten_c = nn.Sequential(squeeze, squeeze_bn, nn.ReLU(inplace=True), expand)
        if use_bias:
            nn.init.constant_(expand.bias, bias)
        # Gate
        self.gate_c = GumbelSoftmax(eps=eps)
        # Norm: L1 norm over every non-batch dimension
        self.norm = lambda x: torch.norm(x, p=1, dim=(1,2,3))

    def forward(self, x):
        """Return ``(mask, per-sample L1 norm of the mask, channel count)``."""
        # Unpacking requires x to be 4-D; the individual sizes are not used.
        _, _, _, _ = x.size()
        context = self.avg_pool(x)  # [N, C, 1, 1]
        # transform, then gate
        mask_c = self.gate_c(self.atten_c(context))  # [N, C_out, 1, 1]
        return mask_c, self.norm(mask_c), self.eleNum_c.to(x.device)

    def get_flops(self):
        """Return the cost of the two 1x1 convs applied to the pooled input."""
        squeeze_cost = self.inplanes * self.bottleneck
        expand_cost = self.bottleneck * self.outplanes
        return squeeze_cost + expand_cost
117 |
--------------------------------------------------------------------------------
/models/mobilenet_v2/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .mobilenet_v2_dg import *
4 |
--------------------------------------------------------------------------------
/models/mobilenet_v2/mobilenet_v2_dg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import logging
4 | from torch import nn
5 | from prettytable import PrettyTable
6 | from .mobilenet_v2_dg_util import ConvBNReLU_1st, InvertedResidual
7 |
8 |
9 | __all__ = ['mobilenet_v2_dg']
10 |
11 |
def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False):
    """Return the length of one output axis of a conv/pool window.

    Computes ``(dim + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1``
    and converts it to ``int`` after rounding up (``ceil_mode``) or down.
    """
    padded = dim + 2 * padding
    reach = dilation * (kernel_size - 1)
    steps = (padded - reach - 1) / stride + 1
    return int(math.ceil(steps) if ceil_mode else math.floor(steps))
17 |
18 |
19 | def _make_divisible(v, divisor, min_value=None):
20 | """
21 | This function is taken from the original tf repo.
22 | It ensures that all layers have a channel number that is divisible by 8
23 | It can be seen here:
24 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
25 | :param v:
26 | :param divisor:
27 | :param min_value:
28 | :return:
29 | """
30 | if min_value is None:
31 | min_value = divisor
32 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
33 | # Make sure that round down does not go down by more than 10%.
34 | if new_v < 0.9 * v:
35 | new_v += divisor
36 | return new_v
37 |
38 |
class MobileNetV2(nn.Module):
    """MobileNetV2 variant whose blocks also report mask norms and FLOPs.

    ``self.features`` passes a 4-tuple ``(x, norm1, norm2, flops)`` from layer
    to layer. ``forward`` seeds the three bookkeeping tensors with a single
    zero row; the blocks used here (see ``InvertedResidual.forward``)
    concatenate one row per block, which is why every read below skips row 0.

    A criterion must be installed with ``set_criterion`` before ``forward`` is
    called: ``self.criterion`` starts as ``None`` and ``get_loss`` calls it.
    """
    def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None,
                 round_nearest=8, in_size=(224, 224), block = InvertedResidual, **kwargs):
        """
        MobileNet V2 main class

        Args:
            num_classes (int): Number of classes
            width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
            inverted_residual_setting: Network structure, a list of 5-element
                rows ``[t, c, n, s, tile]`` (only the first row's length is validated)
            round_nearest (int): Round the number of channels in each layer to be a multiple of this number
            Set to 1 to turn off rounding
            in_size (tuple): ``(h, w)`` of the input; used to track the spatial
                size after every strided layer for the FLOPs bookkeeping
            block: class instantiated for every inverted-residual stage
            **kwargs: forwarded unchanged to every ``block(...)`` constructor

        Raises:
            ValueError: if ``inverted_residual_setting`` is empty or its first
                row does not have 5 elements
        """
        super(MobileNetV2, self).__init__()
        input_channel = 32
        last_channel = 1280
        h, w = in_size

        if inverted_residual_setting is None:
            inverted_residual_setting = [
                # t, c, n, s, tile
                # t: expand ratio, c: output channels (before width_mult),
                # n: number of repeats, s: stride of the first repeat,
                # tile: passed to the block as ``eta``
                [1, 16, 1, 1, 16],
                [6, 24, 2, 2, 8],
                [6, 32, 3, 2, 4],
                [6, 64, 4, 2, 2],
                [6, 96, 3, 1, 2],
                [6, 160, 3, 2, 2],
                [6, 320, 1, 1, 2],
            ]
        # only check the first element, assuming user knows t,c,n,s are required
        if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 5:
            raise ValueError("inverted_residual_setting should be non-empty "
                             "or a 5-element list, got {}".format(inverted_residual_setting))
        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        # the last channel count is never shrunk below 1280 (max(1.0, width_mult))
        self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU_1st(3, input_channel, stride=2)]
        # h, w now describe the stem's output resolution
        h = conv2d_out_dim(h, kernel_size=3, stride=2, padding=1)
        w = conv2d_out_dim(w, kernel_size=3, stride=2, padding=1)
        # FLOPs value recorded for the stem convolution (reported as 'layer0')
        self.flops_conv1 = torch.Tensor([3 * h * w * 3 * input_channel])
        # building inverted residual blocks
        for t, c, n, s, tile in inverted_residual_setting:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                # only the first repeat of a stage may downsample
                stride = s if i == 0 else 1
                # the block receives the INPUT resolution (h, w); the output
                # resolution is recomputed right after for the next block
                features.append(block(input_channel, output_channel, stride,
                                      expand_ratio=t, h=h, w=w, eta=tile, **kwargs))
                h = conv2d_out_dim(h, kernel_size=3, stride=stride, padding=1)
                w = conv2d_out_dim(w, kernel_size=3, stride=stride, padding=1)
                input_channel = output_channel
        # building last several layers
        features.append(ConvBNReLU_1st(input_channel, self.last_channel, kernel_size=1))
        # flops_fc first holds the final 1x1 convolution; the classifier's
        # contribution is added once the classifier is built below
        self.flops_fc = torch.Tensor([input_channel * self.last_channel *h*w])
        # make it nn.Sequential
        self.features = nn.Sequential(*features)
        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )
        self.flops_fc = self.flops_fc + torch.Tensor([num_classes * self.last_channel])
        # criterion
        self.criterion = None

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x, label, den_target, lbda, gamma, p):
        """Run the network and compute the losses through ``self.criterion``.

        Args:
            x: input batch; its first dimension is taken as the batch size
            label: targets handed to the criterion
            den_target, lbda, gamma, p: forwarded unchanged to the criterion

        Returns:
            dict with keys ``closs``, ``rloss``, ``bloss`` (the three values
            returned by the criterion), ``out`` (classifier output),
            ``flops_real``, ``flops_mask`` and ``flops_ori``.
        """
        batch_num, _, _, _ = x.shape
        # Placeholder first rows. norm tensors have batch_num + 1 columns
        # (per-sample values, then the target); flops has batch_num + 2
        # columns (per-sample FLOPs, mask FLOPs, full FLOPs).
        norm1 = torch.zeros(1, batch_num+1).to(x.device)
        norm2 = torch.zeros(1, batch_num+1).to(x.device)
        flops = torch.zeros(1, batch_num+2).to(x.device)
        x, norm1, norm2, flops = self.features((x, norm1, norm2, flops))
        # global average pooling over the two spatial dimensions
        x = x.mean([2, 3])
        x = self.classifier(x)
        # norm and flops
        # drop placeholder row 0 and transpose to (batch, blocks)
        norm_s = norm1[1:, 0:batch_num].permute(1, 0).contiguous()
        norm_c = norm2[1:, 0:batch_num].permute(1, 0).contiguous()
        # last column holds each block's target norm -> shape (1, blocks)
        norm_s_t = norm1[1:, -1].unsqueeze(0)
        norm_c_t = norm2[1:, -1].unsqueeze(0)
        # flops_real is a 3-item list: per-sample block FLOPs, stem FLOPs,
        # and the combined last-conv + classifier FLOPs
        flops_real = [flops[1:, 0:batch_num].permute(1, 0).contiguous(),
                      self.flops_conv1.to(x.device), self.flops_fc.to(x.device)]
        flops_mask = flops[1:, -2].unsqueeze(0)
        flops_ori = flops[1:, -1].unsqueeze(0)
        # get outputs
        outputs = {}
        outputs["closs"], outputs["rloss"], outputs["bloss"] = self.get_loss(
            x, label, batch_num, den_target, lbda, gamma, p,
            norm_s, norm_c, norm_s_t, norm_c_t,
            flops_real, flops_mask, flops_ori)
        outputs["out"] = x
        outputs["flops_real"] = flops_real
        outputs["flops_mask"] = flops_mask
        outputs["flops_ori"] = flops_ori
        return outputs

    def set_criterion(self, criterion):
        """Install the callable used by ``get_loss`` to compute the losses."""
        self.criterion = criterion
        return

    def get_loss(self, output, label, batch_size, den_target, lbda, gamma, p,
                 mask_norm_s, mask_norm_c, norm_s_t, norm_c_t,
                 flops_real, flops_mask, flops_ori):
        """Call ``self.criterion`` and return its ``(closs, rloss, bloss)`` triple.

        Note the argument order expected by the criterion differs from this
        method's own parameter order (``gamma`` and ``p`` go last).
        """
        closs, rloss, bloss = self.criterion(output, label, flops_real, flops_mask,
                                flops_ori, batch_size, den_target, lbda, mask_norm_s, mask_norm_c,
                                norm_s_t, norm_c_t, gamma, p)
        return closs, rloss, bloss

    def record_flops(self, flops_conv, flops_mask, flops_ori, flops_conv1, flops_fc):
        """Log a per-layer FLOPs table through ``logging.info``.

        ``flops_conv``, ``flops_mask`` and ``flops_ori`` are indexed once per
        ``InvertedResidual`` module and once more for the 'Total' row.
        NOTE(review): this presumably expects one extra trailing entry holding
        the sum over blocks (as produced by ``utils.analyse_flops``) - confirm
        against the caller.
        """
        i = 0
        table = PrettyTable(['Layer', 'Conv FLOPs', 'Conv %', 'Mask FLOPs', 'Total FLOPs', 'Total %', 'Original FLOPs'])
        table.add_row(['layer0'] + ['{flops:.2f}K'.format(flops=flops_conv1/1024)] + [' ' for _ in range(5)])
        for name, m in self.named_modules():
            if isinstance(m, InvertedResidual):
                table.add_row([name] + ['{flops:.2f}K'.format(flops=flops_conv[i]/1024)] + ['{per_f:.2f}%'.format(
                    per_f=flops_conv[i]/flops_ori[i]*100)] + ['{mask:.2f}K'.format(mask=flops_mask[i]/1024)] +
                    ['{total:.2f}K'.format(total=(flops_conv[i]+flops_mask[i])/1024)] + ['{per_t:.2f}%'.format(
                    per_t=(flops_conv[i]+flops_mask[i])/flops_ori[i]*100)] +
                    ['{ori:.2f}K'.format(ori=flops_ori[i]/1024)])
                i+=1
        table.add_row(['fc'] + ['{flops:.2f}K'.format(flops=flops_fc/1024)] + [' ' for _ in range(5)])
        # after the loop, i is the index right past the last block entry
        table.add_row(['Total'] + ['{flops:.2f}K'.format(flops=(flops_conv[i]+flops_conv1+flops_fc)/1024)] +
            ['{per_f:.2f}%'.format(per_f=(flops_conv[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100)] +
            ['{mask:.2f}K'.format(mask=flops_mask[i]/1024)] + ['{total:.2f}K'.format(
            total=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/1024)] + ['{per_t:.2f}%'.format(
            per_t=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100)] +
            ['{ori:.2f}K'.format(ori=(flops_ori[i]+flops_conv1+flops_fc)/1024)])
        logging.info('\n{}'.format(table))
177 |
178 |
def mobilenet_v2_dg(**kwargs):
    """Build a ``MobileNetV2`` that uses ``InvertedResidual`` blocks.

    Every keyword argument is handed straight to the ``MobileNetV2`` constructor.
    """
    model = MobileNetV2(block=InvertedResidual, **kwargs)
    return model
181 |
--------------------------------------------------------------------------------
/models/mobilenet_v2/mobilenet_v2_dg_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from torch import nn
4 | from ..mask import Mask_s, Mask_c
5 |
6 |
def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False):
    """Compute the size of one spatial axis after a convolution.

    The unrounded value is
    ``(dim + 2*padding - dilation*(kernel_size-1) - 1) / stride + 1``;
    ``ceil_mode`` chooses between rounding up and rounding down.
    """
    unrounded = (dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1
    if ceil_mode:
        return int(math.ceil(unrounded))
    return int(math.floor(unrounded))
12 |
class ConvBNReLU(nn.Sequential):
    """Bias-free Conv2d followed by BatchNorm2d and an in-place ReLU6.

    The convolution is padded by ``(kernel_size - 1) // 2`` on each side.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad, groups=groups, bias=False)
        norm = nn.BatchNorm2d(out_planes)
        activation = nn.ReLU6(inplace=True)
        super().__init__(conv, norm, activation)
21 |
22 |
class ConvBNReLU_1st(nn.Sequential):
    """Conv2d/BatchNorm2d/ReLU6 stack that works on the 4-tuple pipeline input.

    Only the first tuple element is transformed; the three bookkeeping
    entries are returned untouched.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        pad = (kernel_size - 1) // 2
        conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad, groups=groups, bias=False)
        norm = nn.BatchNorm2d(out_planes)
        activation = nn.ReLU6(inplace=True)
        super().__init__(conv, norm, activation)

    def forward(self, input):
        features, norm_1, norm_2, flops = input
        # run the plain Sequential chain on the feature map only
        features = super().forward(features)
        return features, norm_1, norm_2, flops
36 |
37 |
class Sequential_DG(nn.Sequential):
    """Sequential container that multiplies masks into the activations.

    Training mode: the input of the second-to-last module is multiplied by
    ``mask_c``; the spatial masks are not used here.
    Eval mode: the output of the first module is multiplied by ``mask_s1``,
    and the input of the second-to-last module by ``mask_c * mask_s2``
    (when the first module is also the second-to-last one, only the
    ``mask_s1`` rule applies).
    """

    def __init__(self, layers):
        super().__init__(*layers)
        self._module_num = len(layers)

    def forward(self, input):
        x, mask_c, mask_s1, mask_s2 = input
        gated_index = self._module_num - 2
        for index, layer in enumerate(self._modules.values()):
            if self.training:
                if index == gated_index:
                    x = x * mask_c
                x = layer(x)
                continue
            # inference path
            if index == 0:
                x = layer(x) * mask_s1
            elif index == gated_index:
                x = layer(x * mask_c * mask_s2)
            else:
                x = layer(x)
        return x
61 |
62 |
class InvertedResidual(nn.Module):
    """Inverted-residual block that also reports mask norms and FLOPs.

    ``forward`` takes and returns the 4-tuple ``(x, norm_1, norm_2, flops)``
    and appends exactly one row to each of the three bookkeeping tensors.
    Only blocks with a residual connection (``stride == 1`` and
    ``inp == oup``) own ``mask_c`` / ``mask_s`` modules; all other blocks run
    a plain ``nn.Sequential`` and report their full, unmasked cost.
    """
    def __init__(self, inp, oup, stride, expand_ratio, h, w, eta, **kwargs):
        """
        Args:
            inp: number of input channels
            oup: number of output channels
            stride: 1 or 2 (asserted)
            expand_ratio: hidden width is ``int(round(inp * expand_ratio))``;
                a ratio of 1 skips the 1x1 expansion layer
            h, w: input spatial size of this block
            eta: passed twice to ``Mask_s``
                (NOTE(review): presumably the tile/grid size - confirm in mask.py)
            **kwargs: forwarded to ``Mask_c`` and ``Mask_s``
        """
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        # output resolution of the block
        self.height = conv2d_out_dim(h, kernel_size=3, stride=stride, padding=1)
        self.width = conv2d_out_dim(w, kernel_size=3, stride=stride, padding=1)
        self.spatial = self.height * self.width
        # NOTE: despite its name, this flag is True when there is NO expansion
        # layer (expand_ratio == 1); get_flops tests ``not self.expand``
        self.expand = expand_ratio == 1
        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw
            ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
            # pw-linear
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        if self.use_res_connect:
            # gated path: Sequential_DG multiplies the masks in at fixed positions
            self.conv = Sequential_DG(layers)
            # channel mask
            self.mask_c = Mask_c(inp, hidden_dim, **kwargs)
            flops_mkc = self.mask_c.get_flops()
            # spatial mask
            self.mask_s = Mask_s(self.height, self.width, inp, eta, eta, **kwargs)
            self.upsample = nn.Upsample(size=(h, w), mode='nearest')
            flops_mks = self.mask_s.get_flops()
        else:
            self.conv = nn.Sequential(*layers)
            flops_mkc, flops_mks = 0, 0
            # without masks the reported norms are the full counts; these are
            # the matching targets appended in forward
            self.norm_c_t = torch.Tensor([hidden_dim])
            self.norm_s_t = torch.Tensor([self.spatial])
        # misc
        self.inp, self.oup = inp, oup
        self.hidden_dim = hidden_dim
        # flops
        flops_dw_full = torch.Tensor([9 * self.spatial * hidden_dim])
        flops_pw_full = torch.Tensor([self.spatial * hidden_dim * oup])
        self.flops_full = flops_dw_full + flops_pw_full
        if expand_ratio != 1:
            # the expansion 1x1 conv runs at the input resolution (h, w)
            self.flops_full = self.flops_full + torch.Tensor([h * w * hidden_dim * inp])
        self.upsample1 = nn.Upsample(size=(h, w), mode='nearest')
        # mask flops
        self.flops_mask = torch.Tensor([flops_mks + flops_mkc])

    def forward(self, input):
        """Process ``(x, norm_1, norm_2, flops)`` and append this block's row.

        Rows appended: ``norm_1`` gets the spatial norms plus their target,
        ``norm_2`` the channel norms plus their target, and ``flops`` the
        per-sample FLOPs followed by the mask FLOPs and the full FLOPs.
        """
        if not self.use_res_connect:
            x, norm_1, norm_2, flops = input
            x = self.conv(x)
            # no gating here: each sample's norm is simply the full count
            norm_s = torch.ones((x.shape[0], self.spatial), device=x.device).sum(1)
            norm_c = torch.ones((x.shape[0], self.hidden_dim), device=x.device).sum(1)
            norm_1 = torch.cat((norm_1, torch.cat((norm_s, self.norm_s_t.to(x.device))).unsqueeze(0)))
            norm_2 = torch.cat((norm_2, torch.cat((norm_c, self.norm_c_t.to(x.device))).unsqueeze(0)))
            # every sample costs flops_full; flops_mask is 0 for this branch
            flops_blk = torch.cat((torch.ones(x.shape[0])*self.flops_full, self.flops_mask, self.flops_full)).to(flops.device)
            flops = torch.cat((flops, flops_blk.unsqueeze(0)))
            return (x, norm_1, norm_2, flops)
        else:
            x_in, norm_1, norm_2, flops = input
            # channel mask
            mask_c, norm_c, norm_c_t = self.mask_c(x_in) # [N, C_out, 1, 1]
            # spatial mask
            mask_s_m, norm_s, norm_s_t = self.mask_s(x_in) # [N, 1, h, w]
            # both upsamplers target the block's input size (h, w)
            mask_s1 = self.upsample1(mask_s_m) # [N, 1, H1, W1]
            mask_s = self.upsample(mask_s_m) # [N, 1, H, W]
            x = self.conv((x_in, mask_c, mask_s1, mask_s))
            # mask the block output spatially before the residual add
            x = x * mask_s
            # norm
            norm_1 = torch.cat((norm_1, torch.cat((norm_s, norm_s_t)).unsqueeze(0)))
            norm_2 = torch.cat((norm_2, torch.cat((norm_c, norm_c_t)).unsqueeze(0)))
            # flops
            flops_blk = self.get_flops(mask_c, mask_s)
            flops = torch.cat((flops, flops_blk.unsqueeze(0)))
            return (x+x_in, norm_1, norm_2, flops)

    def get_flops(self, mask_c, mask_s_up):
        """Return this block's FLOPs row for the given masks.

        The result concatenates the per-sample FLOPs (computed from the mask
        sums over dims 1-3) with ``self.flops_mask`` and ``self.flops_full``.
        """
        s_sum = mask_s_up.sum((1,2,3))
        c_sum = mask_c.sum((1,2,3))
        # convdw
        flops_dw = 9 * s_sum * c_sum
        # convpw
        flops_pw = s_sum * c_sum * self.oup
        # conv1x1
        flops = flops_dw + flops_pw
        # ``not self.expand`` means expand_ratio != 1, i.e. the 1x1 expansion
        # layer exists and its cost is added
        if not self.expand:
            mask_s_1 = self.upsample1(mask_s_up)
            flops = flops + mask_s_1.sum((1,2,3)) * c_sum * self.inp
        # total
        return torch.cat((flops, self.flops_mask.to(flops.device), self.flops_full.to(flops.device)))
156 |
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import models
3 |
4 |
# Command-line options shared by training and evaluation.
# NOTE: the parser is only built here; parsing happens in the importing module.
parser = argparse.ArgumentParser(description='PyTorch Training')
# Datasets
parser.add_argument('-d', '--data', default='path to dataset', type=str)
parser.add_argument('-dset', '--dataset', default='dataset', type=str)
# fixed: the help text used to advertise a default of 4 while the default is 8
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 8)')
# Architecture
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=models.ALL_MODEL_NAMES,
                    help='model architecture: ' + ' | '.join(models.ALL_MODEL_NAMES) + ' (default: resnet50)')
# Optimization options
parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR',
                    help='initial learning rate (default: 0.001 | for inception recommend 0.0256)')
parser.add_argument('--lr-decay', default=0.1, type=float, metavar='LD',
                    help='every lr-decay-step epochs learning rate decays by LD (default:0.1 | for inception recommend 0.16)')
parser.add_argument('--lr-mode', default='step', type=str, help='learning rate mode')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum (default: 0.9)')
parser.add_argument('--weight-decay', '-wd', default=1e-4, type=float, metavar='WD', help='weight decay for sgd (default: 1e-4)')
parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], help='Decrease learning rate at these epochs.')
parser.add_argument('--den-target', default=0.5, type=float, help='target density of the mask.')
parser.add_argument('--lbda', default=5, type=float, help='penalty factor of the L2 loss for mask.')
parser.add_argument('--gamma', default=1, type=float, help='penalty factor of the L2 loss for balance gate.')
parser.add_argument('--alpha', default=5e-2, type=float, help='alpha in exp annealing.')
# Training
parser.add_argument('--epochs', default=300, type=int, metavar='EPOCHS', help='number of total iteration to run.')
# Device options
parser.add_argument('--gpu-id', default='-1', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')
# Miscs
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set')
parser.add_argument('--pretrained', default='', type=str, metavar='PATH',
                    help='use pre-trained model: ''pytorch: use pytorch official | path to self-trained model')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to store the checkpoint and log checkpoint path = ./checkpoints/PATH, log path = ./logs/PATH')
parser.add_argument('--bias', default=2, type=float, help='initial value of the bias in the last fc layer of mask module.')
40 |
--------------------------------------------------------------------------------
/regularization.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
class spar_loss(nn.Module):
    """Squared penalty on the gap between the FLOPs ratio and a target density."""

    def __init__(self):
        super().__init__()

    def forward(self, flops_real, flops_mask, flops_ori, batch_size, den_target, lbda):
        """Return ``lbda * (real_flops / original_flops - den_target) ** 2``.

        ``flops_real`` is indexed as ``[block_flops, flops_conv1, flops_fc]``;
        only the first ``batch_size`` rows of ``block_flops`` are averaged.
        """
        block_flops, stem_flops, head_flops = flops_real[0], flops_real[1], flops_real[2]
        # mean over rows, then total over the remaining axis
        conv_total = block_flops[0:batch_size, :].mean(0).sum()
        mask_total = flops_mask.mean(0).sum()
        original_total = flops_ori.mean(0).sum() + stem_flops.mean() + head_flops.mean()
        executed_total = conv_total + mask_total + stem_flops.mean() + head_flops.mean()
        density_gap = executed_total / original_total - den_target
        return lbda * density_gap**2
22 |
23 |
class blance_loss(nn.Module):
    """Bound loss on the spatial and channel mask densities."""

    def __init__(self):
        super().__init__()

    def forward(self, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size,
                den_target, gamma, p):
        """Sum the bound losses of the spatial and the channel densities.

        Each density is the mean mask norm over the first ``batch_size`` rows
        divided by its target norm; both are bounded around
        ``sqrt(den_target)`` by ``get_bloss_basic``.
        """
        spatial_target = norm_s_t.mean(0)
        channel_target = norm_c_t.mean(0)
        spatial_density = mask_norm_s[0:batch_size, :].mean(0) / spatial_target
        channel_density = mask_norm_c[0:batch_size, :].mean(0) / channel_target
        per_mask_target = math.sqrt(den_target)
        spatial_term = get_bloss_basic(spatial_density, per_mask_target, batch_size, gamma, p)
        channel_term = get_bloss_basic(channel_density, per_mask_target, batch_size, gamma, p)
        return spatial_term + channel_term
41 |
42 |
def get_bloss_basic(spar, spar_tar, batch_size, gamma, p):
    """Penalise densities that leave the band set by ``p`` and ``spar_tar``.

    The lower bound is ``p * spar_tar`` and the upper bound is
    ``1 - p + p * spar_tar``; squared violations are averaged and scaled by
    ``gamma``. ``batch_size`` is accepted but not used.
    """
    # bound
    below = F.relu(p*spar_tar-spar)
    above = F.relu(spar-1+p-p*spar_tar)
    lower_term = (below**2).mean()
    upper_term = (above**2).mean()
    return gamma * (lower_term + upper_term)
49 |
50 |
class Loss(nn.Module):
    """Bundle of the task (cross-entropy), sparsity and balance losses."""

    def __init__(self):
        super().__init__()
        self.task_loss = nn.CrossEntropyLoss()
        self.spar_loss = spar_loss()
        self.balance_loss = blance_loss()

    def forward(self, output, targets, flops_real, flops_mask, flops_ori, batch_size,
                den_target, lbda, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t,
                gamma, p):
        """Return the three losses separately as ``(closs, sloss, bloss)``."""
        task_term = self.task_loss(output, targets)
        sparsity_term = self.spar_loss(
            flops_real, flops_mask, flops_ori, batch_size, den_target, lbda)
        balance_term = self.balance_loss(
            mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size,
            den_target, gamma, p)
        return task_term, sparsity_term, balance_term
66 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | prettytable
2 | matplotlib
3 | pretrainedmodels
4 | numpy
5 |
--------------------------------------------------------------------------------
/scripts/cifar_e.sh:
--------------------------------------------------------------------------------
# Evaluate a dynamic dual gating model on CIFAR-10.
# Positional arguments, as forwarded to main.py below:
#   $1 architecture (-a)            $2 dataset path (-d)
#   $3 GPU id(s) (--gpu-id)         $4 checkpoint/log name (--checkpoint)
#   $5 pretrained model (--pretrained)
# Every expansion is quoted so paths containing spaces or glob characters
# reach main.py as a single argument.
echo "Dynamic dual gating model $1 for cifar-10."

time python main.py -d "$2" -dset cifar10 -j 2 -a "$1" -b 128 \
    --checkpoint "$4" --gpu-id "$3" --pretrained "$5" -e
5 |
--------------------------------------------------------------------------------
/scripts/cifar_t.sh:
--------------------------------------------------------------------------------
# Train a dynamic dual gating model on CIFAR-10 three times in a row.
# Positional arguments:
#   $1 architecture (-a)            $2 dataset path (-d)
#   $3 target density (--den-target) $4 GPU id(s) (--gpu-id)
#   $5 checkpoint/log name prefix   $6 pretrained model (--pretrained)
# Every expansion is quoted so values containing spaces or glob characters
# are passed through as a single argument.
echo "Dynamic dual gating model $1 for cifar-10."

# One training run; takes the same six arguments as the script, with $5
# being the full checkpoint name of that run.
dgnet_cifar10(){
    time python main.py -d "$2" -dset cifar10 -j 2 -a "$1" -b 128 -lr 0.1 \
        --weight-decay 5e-4 --schedule 150 225 --checkpoint "$5" \
        --gpu-id "$4" --den-target "$3" --alpha 2e-2 --pretrained "$6"
}

# The three runs differ only in their checkpoint directory name.
checkpoint1="${5}_varience1"
checkpoint2="${5}_varience2"
checkpoint3="${5}_varience3"

dgnet_cifar10 "$1" "$2" "$3" "$4" "$checkpoint1" "$6"
dgnet_cifar10 "$1" "$2" "$3" "$4" "$checkpoint2" "$6"
dgnet_cifar10 "$1" "$2" "$3" "$4" "$checkpoint3" "$6"
16 |
--------------------------------------------------------------------------------
/scripts/imagenet_e.sh:
--------------------------------------------------------------------------------
# Evaluate a dynamic dual gating model on ImageNet.
# Positional arguments, as forwarded to main.py below:
#   $1 architecture (-a)            $2 dataset path (-d)
#   $3 GPU id(s) (--gpu-id)         $4 checkpoint/log name (--checkpoint)
#   $5 pretrained model (--pretrained)
# Expansions are quoted so each value reaches main.py as one argument.
echo "Dynamic dual gating model $1 for ImageNet."

time python main.py -d "$2" -dset imagenet -a "$1" \
    --checkpoint "$4" --gpu-id "$3" --pretrained "$5" -e
--------------------------------------------------------------------------------
/scripts/imagenet_t.sh:
--------------------------------------------------------------------------------
# Train a dynamic dual gating model on ImageNet (cosine schedule, 100 epochs).
# Positional arguments, as forwarded to main.py below:
#   $1 architecture (-a)             $2 dataset path (-d)
#   $3 target density (--den-target) $4 GPU id(s) (--gpu-id)
#   $5 checkpoint/log name (--checkpoint)
# Expansions are quoted so each value reaches main.py as one argument.
echo "Dynamic dual gating model $1 for ImageNet."

time python main.py -d "$2" -dset imagenet -a "$1" -lr 0.05 \
    --weight-decay 1e-4 --epochs 100 --checkpoint "$5" --gpu-id "$4" \
    --den-target "$3" --pretrained pytorch --lr-mode cosine
6 |
--------------------------------------------------------------------------------
/scripts/mobilenet_v2_e.sh:
--------------------------------------------------------------------------------
# Evaluate a dynamic dual gating MobileNetV2 model on ImageNet.
# Positional arguments, as forwarded to main.py below:
#   $1 architecture (-a)            $2 dataset path (-d)
#   $3 GPU id(s) (--gpu-id)         $4 checkpoint/log name (--checkpoint)
#   $5 pretrained model (--pretrained)
# Expansions are quoted so each value reaches main.py as one argument.
echo "Dynamic dual gating model $1 for ImageNet."

time python main.py -d "$2" -dset imagenet -a "$1" \
    --checkpoint "$4" --gpu-id "$3" --pretrained "$5" -e
5 |
--------------------------------------------------------------------------------
/scripts/mobilenet_v2_t.sh:
--------------------------------------------------------------------------------
# Train a dynamic dual gating MobileNetV2 model on ImageNet
# (cosine schedule, 200 epochs).
# Positional arguments, as forwarded to main.py below:
#   $1 architecture (-a)             $2 dataset path (-d)
#   $3 target density (--den-target) $4 GPU id(s) (--gpu-id)
#   $5 checkpoint/log name (--checkpoint)
# Expansions are quoted so each value reaches main.py as one argument.
echo "Dynamic dual gating model $1 for ImageNet."

time python main.py -d "$2" -dset imagenet -a "$1" -lr 0.05 \
    --weight-decay 4e-5 --epochs 200 --checkpoint "$5" \
    --gpu-id "$4" --den-target "$3" --pretrained pytorch \
    --lr-mode cosine
7 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import get_loggers
2 | from .misc import AverageMeter, accuracy
3 | from .misc import analyse_flops, ExpAnnealing
4 |
5 | # progress bar
6 | import os, sys
7 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress"))
8 | from progress.bar import Bar as Bar
9 |
10 | __all__ = ['AverageMeter', 'Bar', 'accuracy', 'get_loggers',
11 | 'analyse_flops', 'ExpAnnealing']
12 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import matplotlib.pyplot as plt
3 | import os
4 | import datetime
5 | import torch
6 | import logging
7 | import errno
8 | import numpy as np
9 | from logging import handlers
10 |
11 | __all__ = ['Logger', 'LoggerMonitor', 'savefig', 'get_loggers']
12 |
13 |
def savefig(fname, dpi=None):
    """Save the current matplotlib figure to ``fname`` (150 dpi unless given)."""
    if dpi is None:
        dpi = 150
    plt.savefig(fname, dpi=dpi)
17 |
18 |
def plot_overlap(logger, names=None):
    """Plot the selected columns of ``logger`` on the current axes.

    ``names`` defaults to all of the logger's columns. Returns one legend
    label per name, formed as ``<logger.title>(<name>)``.
    """
    if names is None:
        names = logger.names
    series = logger.numbers
    for column in names:
        values = series[column]
        plt.plot(np.arange(len(values)), np.asarray(values))
    return [''.join((logger.title, '(', column, ')')) for column in names]
26 |
27 |
def get_loggers(args):
    """
    Create the run directory and set up all loggers for one experiment.

    Args:
        - args : config information; ``args.checkpoint`` and ``args.arch``
                 choose the directory name and every entry of ``vars(args)``
                 is written to config.log

    Returns:
        (train_log, test_log, checkpoint_dir, log_dir), where both directory
        values are the same ``logs/<name>`` path
    """
    # directory name: the given checkpoint name, else "<arch>_<MMDD_HHMM>"
    checkpoint_name, arch_name = args.checkpoint, args.arch
    if checkpoint_name == '':
        stamp = datetime.datetime.now().strftime('%m%d_%H%M')
        run_name = arch_name + '_' + stamp
    else:
        run_name = checkpoint_name
    log_dir = os.path.join('logs', run_name)
    checkpoint_dir = log_dir
    print('\n--------------------------------------------------------')
    if os.path.isdir(checkpoint_dir):
        print("=> directory '{}' exists".format(log_dir))
    else:
        mkdir_p(log_dir)
        mkdir_p(checkpoint_dir)
        print("=> make directory '{}'".format(log_dir))

    # the three log files are opened (and truncated) in this order
    train_log, test_log, config_log = (
        Logger(os.path.join(log_dir, file_name))
        for file_name in ('train.log', 'test.log', 'config.log'))
    tb_dir = os.path.join(log_dir, 'tb')
    if not os.path.isdir(tb_dir):
        os.makedirs(tb_dir)

    # msg logger: detailed records go to message.log ...
    logging.basicConfig(
        level=logging.INFO,
        filename=os.path.join(log_dir, 'message.log'),
        filemode='w',
        format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
    # ... while the console gets the bare message through an extra handler
    # attached to the root logger
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter('%(message)s'))
    logging.getLogger('').addHandler(console)

    # Save the config info
    for key, value in vars(args).items():
        config_log.write(content="{k} : {v}".format(k=key, v=value),
                         wrap=True,
                         flush=True)
    config_log.close()

    # header rows of the result logs
    test_log.write(content="epoch\ttop1\ttop5\tloss\tcloss\trloss\tbloss\tdensity\tflops_per\tflops\t",
                   wrap=True,
                   flush=True)
    train_log.write(content="epoch\ttop1\ttop5\tloss\tcloss\trloss\tbloss",
                    wrap=True,
                    flush=True)
    return train_log, test_log, checkpoint_dir, log_dir
89 |
90 |
def has_children(module):
    """Return True if ``module.children()`` yields at least one item."""
    exhausted = object()
    first_child = next(module.children(), exhausted)
    return first_child is not exhausted
97 |
98 |
class Logger(object):
    '''Save training process to log file with simple plot function.

    The file is tab-separated: an optional header row of column names
    (see ``set_names``) followed by one row per ``append`` call.
    Instances can be used as context managers; the file is closed on exit.
    '''
    def __init__(self, fpath, title=None, resume=False):
        '''
        :param fpath: path of the log file; ``None`` creates a logger without a file
        :param title: prefix used in plot legends, default=''
        :param resume: bool, when True the existing file is parsed (header row
                       into ``names``, every other row into ``numbers`` as
                       strings) and then reopened for appending; when False
                       the file is truncated
        '''
        self.file = None
        self.resume = resume
        self.title = '' if title is None else title
        if fpath is None:
            return
        if resume:
            # Read the previous contents with a context manager so the read
            # handle is released even if a malformed row raises while parsing.
            with open(fpath, 'r') as previous:
                self.names = previous.readline().rstrip().split('\t')
                self.numbers = {name: [] for name in self.names}
                for row in previous:
                    for index, value in enumerate(row.rstrip().split('\t')):
                        self.numbers[self.names[index]].append(value)
            self.file = open(fpath, 'a')
        else:
            self.file = open(fpath, 'w')

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def set_names(self, names):
        '''Write the header row and reset the stored numbers for ``names``.'''
        # initialize numbers as empty list
        self.numbers = {}
        self.names = names
        for name in self.names:
            self.file.write(name)
            self.file.write('\t')
            self.numbers[name] = []
        self.file.write('\n')
        self.file.flush()

    def append(self, numbers):
        '''Write one row of values (formatted with 6 decimals) and store them.

        ``numbers`` must have exactly one value per column name.
        '''
        assert len(self.names) == len(numbers), 'Numbers do not match names'
        for index, num in enumerate(numbers):
            self.file.write("{0:.6f}".format(num))
            self.file.write('\t')
            self.numbers[self.names[index]].append(num)
        self.file.write('\n')
        self.file.flush()

    def plot(self, names=None):
        '''Plot the stored columns (all by default) on the current axes.'''
        names = self.names if names is None else names
        numbers = self.numbers
        for name in names:
            x = np.arange(len(numbers[name]))
            plt.plot(x, np.asarray(numbers[name]))
        plt.legend([self.title + '(' + name + ')' for name in names])
        plt.grid(True)

    def close(self):
        '''Close the underlying file, if there is one.'''
        if self.file is not None:
            self.file.close()

    def write(self, content, wrap=True, flush=False, verbose=False):
        """
        write file and flush buffer to the disk
        :param content: str
        :param wrap: bool, whether to add '\n' at the end of the content
        :param flush: bool, whether to flush buffer to the disk, default=False
        :param verbose: bool, whether to print the content, default=False
        :return:
            void
        """
        if verbose:
            print(content)
        if wrap:
            content += "\n"
        self.file.write(content)
        if flush:
            self.file.flush()
            # force the OS to commit the data as well
            os.fsync(self.file)
176 |
177 |
class LoggerMonitor(object):
    '''Load several existing log files and plot them together.'''
    def __init__(self, paths):
        '''paths is a dictionary with {name:filepath} pair'''
        self.loggers = [Logger(path, title=title, resume=True)
                        for title, path in paths.items()]

    def plot(self, names=None):
        '''Draw the chosen columns of every loaded log in a new figure.'''
        plt.figure()
        plt.subplot(121)
        labels = []
        for logger in self.loggers:
            labels.extend(plot_overlap(logger, names))
        # legend placed outside the axes, to the right
        plt.legend(labels,
                   bbox_to_anchor=(1.05, 1),
                   loc=2,
                   borderaxespad=0.)
        plt.grid(True)
198 |
199 |
def mkdir_p(path):
    '''Create ``path`` and missing parents; an existing directory is fine.

    Any other failure, including ``path`` existing as a non-directory,
    raises ``OSError``.
    '''
    os.makedirs(path, exist_ok=True)
209 |
210 |
def size_to_str(torch_size):
    """Format a torch.Size, tuple or list of ints as ``'(a, b, c)'``."""
    assert isinstance(torch_size, (torch.Size, tuple, list))
    dims = ['%d' % extent for extent in torch_size]
    return '(' + ', '.join(dims) + ')'
215 |
216 |
def to_np(var):
    """Return the data of ``var`` as a NumPy array, moved to the CPU first."""
    host_tensor = var.data.cpu()
    return host_tensor.numpy()
219 |
220 |
def norm_filters(weights, p=1):
    """Return the p-norm of every filter in a convolution weight tensor.

    Args:
        weights - a 4D convolution weights tensor; the first dimension
                  indexes the filters, the rest is flattened per filter
        p - the exponent value in the norm formulation
    """
    assert weights.dim() == 4
    filter_count = weights.size(0)
    flattened = weights.view(filter_count, -1)
    return flattened.norm(p=p, dim=1)
231 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 |
4 |
5 | __all__ = ['AverageMeter', 'accuracy', 'analyse_flops', 'ExpAnnealing']
6 |
7 |
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: score tensor with one row per sample and one column per class
        target: 1-D tensor of class indices, one per sample
        topk: iterable of k values to evaluate

    Returns:
        list with one tensor per k, each holding the percentage (0-100) of
        samples whose target is among the k highest-scoring classes
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk largest scores per sample, best first
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout of `pred`,
        # so `.view(-1)` raises for k > 1; `.reshape(-1)` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
22 |
23 |
def analyse_flops(flops_real, flops_mask, flops_ori, batch_size):
    """Average the FLOPs records and append a grand total to each.

    ``flops_real`` is indexed as ``[block_flops, flops_conv1, flops_fc]``.
    Returns ``(flops_conv, flops_mask, flops_ori, conv1_mean, fc_mean)``;
    the first three are 1-D tensors whose last element is the sum of the
    preceding ones.
    """
    def with_total(values):
        total = values.sum().unsqueeze(0)
        return torch.cat([values, total])

    block_flops, flops_conv1, flops_fc = flops_real[0], flops_real[1], flops_real[2]
    mask_mean = flops_mask.mean(0)
    # block flops: only the first batch_size rows are averaged
    conv_mean = block_flops[0:batch_size, :].mean(0)
    conv_with_total = with_total(conv_mean)
    mask_with_total = with_total(mask_mean)
    ori_with_total = with_total(flops_ori.mean(0))
    return conv_with_total, mask_with_total, ori_with_total, flops_conv1.mean(), flops_fc.mean()
36 |
37 |
class AverageMeter(object):
    r"""Track the latest value and the running weighted average of a metric.

    Imported from
    https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, weighted sum and count."""
        for field in ('val', 'avg', 'sum', 'count'):
            setattr(self, field, 0)

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
57 |
58 |
class ExpAnnealing(object):
    r"""Exponential schedule that moves a value from ``eta_ini`` to ``eta_final``.

    Args:
        T_ini (int): epochs below this return ``eta_ini`` unchanged.
        eta_ini (float): Initial value. Default: 1.
        eta_final (float): Value approached as the epoch grows. Default: 0.
        up (bool): selects which of the two formulas in ``get_lr`` is used.
            Default: False.
        alpha (float): rate inside the exponential. Default: 1.
    """

    def __init__(self, T_ini, eta_ini=1, eta_final=0, up=False, alpha=1):
        self.T_ini = T_ini
        self.eta_ini = eta_ini
        self.eta_final = eta_final
        self.up = up
        self.alpha = alpha
        self.last_epoch = 0

    def get_lr(self, epoch):
        """Return the scheduled value at ``epoch``."""
        if epoch < self.T_ini:
            return self.eta_ini
        decay = math.exp(-self.alpha*(epoch-self.T_ini))
        if self.up:
            return self.eta_ini + (self.eta_final-self.eta_ini) * (1-decay)
        return self.eta_final + (self.eta_ini-self.eta_final) * decay

    def step(self):
        """Advance the internal epoch counter by one and return its value."""
        self.last_epoch += 1
        return self.get_lr(self.last_epoch)
88 |
--------------------------------------------------------------------------------
/utils/progress/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.egg-info
3 | build/
4 | dist/
5 |
--------------------------------------------------------------------------------
/utils/progress/LICENSE:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2012 Giorgos Verigakis
2 | #
3 | # Permission to use, copy, modify, and distribute this software for any
4 | # purpose with or without fee is hereby granted, provided that the above
5 | # copyright notice and this permission notice appear in all copies.
6 | #
7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 |
--------------------------------------------------------------------------------
/utils/progress/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include README.rst LICENSE
2 |
--------------------------------------------------------------------------------
/utils/progress/README.rst:
--------------------------------------------------------------------------------
1 | Easy progress reporting for Python
2 | ==================================
3 |
4 | |pypi|
5 |
6 | |demo|
7 |
8 | .. |pypi| image:: https://img.shields.io/pypi/v/progress.svg
9 | .. |demo| image:: https://raw.github.com/verigak/progress/master/demo.gif
10 | :alt: Demo
11 |
12 | Bars
13 | ----
14 |
15 | There are 7 progress bars to choose from:
16 |
17 | - ``Bar``
18 | - ``ChargingBar``
19 | - ``FillingSquaresBar``
20 | - ``FillingCirclesBar``
21 | - ``IncrementalBar``
22 | - ``PixelBar``
23 | - ``ShadyBar``
24 |
25 | To use them, just call ``next`` to advance and ``finish`` to finish:
26 |
27 | .. code-block:: python
28 |
29 | from progress.bar import Bar
30 |
31 | bar = Bar('Processing', max=20)
32 | for i in range(20):
33 | # Do some work
34 | bar.next()
35 | bar.finish()
36 |
37 | The result will be a bar like the following: ::
38 |
39 | Processing |############# | 42/100
40 |
41 | To simplify the common case where the work is done in an iterator, you can
42 | use the ``iter`` method:
43 |
44 | .. code-block:: python
45 |
46 | for i in Bar('Processing').iter(it):
47 | # Do some work
48 |
49 | Progress bars are very customizable; you can change their width, their fill
50 | character, their suffix and more:
51 |
52 | .. code-block:: python
53 |
54 | bar = Bar('Loading', fill='@', suffix='%(percent)d%%')
55 |
56 | This will produce a bar like the following: ::
57 |
58 | Loading |@@@@@@@@@@@@@ | 42%
59 |
60 | You can use a number of template arguments in ``message`` and ``suffix``:
61 |
62 | ========== ================================
63 | Name Value
64 | ========== ================================
65 | index current value
66 | max maximum value
67 | remaining max - index
68 | progress index / max
69 | percent progress * 100
70 | avg simple moving average time per item (in seconds)
71 | elapsed elapsed time in seconds
72 | elapsed_td elapsed as a timedelta (useful for printing as a string)
73 | eta avg * remaining
74 | eta_td eta as a timedelta (useful for printing as a string)
75 | ========== ================================
76 |
77 | Instead of passing all configuration options on instantiation, you can create
78 | your custom subclass:
79 |
80 | .. code-block:: python
81 |
82 | class FancyBar(Bar):
83 | message = 'Loading'
84 | fill = '*'
85 | suffix = '%(percent).1f%% - %(eta)ds'
86 |
87 | You can also override any of the arguments or create your own:
88 |
89 | .. code-block:: python
90 |
91 | class SlowBar(Bar):
92 | suffix = '%(remaining_hours)d hours remaining'
93 | @property
94 | def remaining_hours(self):
95 | return self.eta // 3600
96 |
97 |
98 | Spinners
99 | ========
100 |
101 | For actions with an unknown number of steps you can use a spinner:
102 |
103 | .. code-block:: python
104 |
105 | from progress.spinner import Spinner
106 |
107 | spinner = Spinner('Loading ')
108 | while state != 'FINISHED':
109 | # Do some work
110 | spinner.next()
111 |
112 | There are 5 predefined spinners:
113 |
114 | - ``Spinner``
115 | - ``PieSpinner``
116 | - ``MoonSpinner``
117 | - ``LineSpinner``
118 | - ``PixelSpinner``
119 |
120 |
121 | Other
122 | =====
123 |
124 | There are a number of other classes available too, please check the source or
125 | subclass one of them to create your own.
126 |
127 |
128 | License
129 | =======
130 |
131 | progress is licensed under ISC
132 |
--------------------------------------------------------------------------------
/utils/progress/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CAS-CLab/DGNet/6b709a388c463d7468fbad953ad0112bc3abe66d/utils/progress/demo.gif
--------------------------------------------------------------------------------
/utils/progress/progress/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2012 Giorgos Verigakis
2 | #
3 | # Permission to use, copy, modify, and distribute this software for any
4 | # purpose with or without fee is hereby granted, provided that the above
5 | # copyright notice and this permission notice appear in all copies.
6 | #
7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 |
15 | from __future__ import division
16 |
17 | from collections import deque
18 | from datetime import timedelta
19 | from math import ceil
20 | from sys import stderr
21 | from time import time
22 |
23 |
24 | __version__ = '1.3'
25 |
26 |
class Infinite(object):
    """Base class for indicators that count steps without a known total.

    Keeps the current ``index``, the start timestamp and a simple moving
    average (``avg``) of the time spent per step. Subclasses override
    ``update``/``start``/``finish`` to actually render something.
    """

    file = stderr
    sma_window = 10  # Simple Moving Average window

    def __init__(self, *args, **kwargs):
        self.index = 0
        self.start_ts = time()
        self.avg = 0
        self._ts = self.start_ts
        # Only the most recent ``sma_window`` per-item durations are kept.
        self._xput = deque(maxlen=self.sma_window)
        # Every keyword argument becomes an instance attribute.
        for name, value in kwargs.items():
            setattr(self, name, value)

    def __getitem__(self, key):
        # Mapping-style access lets instances be used as the right-hand side
        # of ``%`` formatting; private names are never exposed that way.
        return None if key.startswith('_') else getattr(self, key, None)

    @property
    def elapsed(self):
        """Whole seconds since the instance was created."""
        seconds = time() - self.start_ts
        return int(seconds)

    @property
    def elapsed_td(self):
        """``elapsed`` expressed as a ``timedelta``."""
        return timedelta(seconds=self.elapsed)

    def update_avg(self, n, dt):
        """Fold ``dt`` seconds spent on ``n`` items into the moving average."""
        if not n > 0:
            return
        self._xput.append(dt / n)
        self.avg = sum(self._xput) / len(self._xput)

    def update(self):
        """Render the current state; no-op here."""
        pass

    def start(self):
        """Hook called before progress begins; no-op here."""
        pass

    def finish(self):
        """Hook called when progress ends; no-op here."""
        pass

    def next(self, n=1):
        """Advance by ``n`` steps, refresh timing statistics and redraw."""
        stamp = time()
        self.update_avg(n, stamp - self._ts)
        self._ts = stamp
        self.index = self.index + n
        self.update()

    def iter(self, it):
        """Yield the items of ``it``, advancing once after each one.

        ``finish`` is always called when the generator ends or is closed.
        """
        try:
            for item in it:
                yield item
                self.next()
        finally:
            self.finish()
82 |
83 |
class Progress(Infinite):
    """Indicator for work with a known number of steps (``max``).

    Adds completion ratio, remaining-step and ETA properties on top of
    ``Infinite``.

    Args:
        max (int, keyword): total number of steps. Default: 100.
    """

    def __init__(self, *args, **kwargs):
        super(Progress, self).__init__(*args, **kwargs)
        self.max = kwargs.get('max', 100)

    @property
    def eta(self):
        """Estimated whole seconds left: average step time x remaining steps."""
        return int(ceil(self.avg * self.remaining))

    @property
    def eta_td(self):
        """``eta`` expressed as a ``timedelta``."""
        return timedelta(seconds=self.eta)

    @property
    def percent(self):
        """Completion ratio scaled to the 0-100 range."""
        return self.progress * 100

    @property
    def progress(self):
        """Completion ratio, capped at 1.

        Returns 0 when ``max`` is 0 (for instance after ``iter`` was given an
        empty sized iterable) instead of raising ZeroDivisionError.
        """
        if self.max == 0:
            return 0
        return min(1, self.index / self.max)

    @property
    def remaining(self):
        """Steps left until ``max``; never negative."""
        return max(self.max - self.index, 0)

    def start(self):
        """Draw the initial state."""
        self.update()

    def goto(self, index):
        """Jump to the absolute position ``index``."""
        incr = index - self.index
        self.next(incr)

    def iter(self, it):
        """Yield the items of ``it``, advancing once after each one.

        When ``it`` supports ``len``, ``max`` is set to its length first.
        ``finish`` is always called when the generator ends or is closed.
        """
        try:
            self.max = len(it)
        except TypeError:
            # Unsized iterables (e.g. generators) keep the configured max.
            pass

        try:
            for x in it:
                yield x
                self.next()
        finally:
            self.finish()
128 |
--------------------------------------------------------------------------------
/utils/progress/progress/bar.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2012 Giorgos Verigakis
4 | #
5 | # Permission to use, copy, modify, and distribute this software for any
6 | # purpose with or without fee is hereby granted, provided that the above
7 | # copyright notice and this permission notice appear in all copies.
8 | #
9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 |
17 | from __future__ import unicode_literals
18 | from . import Progress
19 | from .helpers import WritelnMixin
20 |
21 |
class Bar(WritelnMixin, Progress):
    """Fixed-width, single-line text progress bar.

    The rendered line is ``message + bar_prefix + filled cells + empty cells
    + bar_suffix + suffix``; ``message`` and ``suffix`` are ``%``-templates
    formatted against the bar instance itself.
    """

    width = 32
    message = ''
    suffix = '%(index)d/%(max)d'
    bar_prefix = ' |'
    bar_suffix = '| '
    empty_fill = ' '
    fill = '#'
    hide_cursor = True

    def update(self):
        """Redraw the bar for the current progress."""
        n_filled = int(self.width * self.progress)
        n_empty = self.width - n_filled

        pieces = [
            self.message % self,
            self.bar_prefix,
            self.fill * n_filled,
            self.empty_fill * n_empty,
            self.bar_suffix,
            self.suffix % self,
        ]
        self.writeln(''.join(pieces))
43 |
44 |
class ChargingBar(Bar):
    """Bar drawn with a solid-block fill, a dotted background and a
    percentage suffix."""

    fill = '█'
    empty_fill = '∙'
    bar_prefix = ' '
    bar_suffix = ' '
    suffix = '%(percent)d%%'
51 |
52 |
class FillingSquaresBar(ChargingBar):
    """ChargingBar variant whose cells are hollow/filled squares."""

    fill = '▣'
    empty_fill = '▢'
56 |
57 |
class FillingCirclesBar(ChargingBar):
    """ChargingBar variant whose cells are hollow/filled circles."""

    fill = '◉'
    empty_fill = '◯'
61 |
62 |
class IncrementalBar(Bar):
    """Bar whose leading cell is drawn with a partial-fill glyph.

    ``phases`` lists the glyphs from emptiest to fullest; the last entry is
    used for every completely filled cell.
    """

    phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█')

    def update(self):
        """Redraw the bar, including the partially filled leading cell."""
        exact_fill = self.width * self.progress
        whole_cells = int(exact_fill)  # cells that are completely filled
        # Fractional part of the fill mapped onto an index into ``phases``.
        phase_idx = int((exact_fill - whole_cells) * len(self.phases))
        free_cells = self.width - whole_cells

        message = self.message % self
        full_part = self.phases[-1] * whole_cells
        # Phase 0 means "nothing extra to draw" rather than phases[0].
        partial = self.phases[phase_idx] if phase_idx > 0 else ''
        padding = self.empty_fill * max(0, free_cells - len(partial))
        suffix = self.suffix % self

        self.writeln(''.join([message, self.bar_prefix, full_part, partial,
                              padding, self.bar_suffix, suffix]))
81 |
82 |
class PixelBar(IncrementalBar):
    """IncrementalBar drawn with braille-dot glyphs."""

    # One glyph per phase, from emptiest to fullest.
    phases = tuple('⡀⡄⡆⡇⣇⣧⣷⣿')
85 |
86 |
class ShadyBar(IncrementalBar):
    """IncrementalBar drawn with shaded-block glyphs."""

    phases = (
        ' ',
        '░',
        '▒',
        '▓',
        '█',
    )
89 |
--------------------------------------------------------------------------------
/utils/progress/progress/counter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2012 Giorgos Verigakis
4 | #
5 | # Permission to use, copy, modify, and distribute this software for any
6 | # purpose with or without fee is hereby granted, provided that the above
7 | # copyright notice and this permission notice appear in all copies.
8 | #
9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 |
17 | from __future__ import unicode_literals
18 | from . import Infinite, Progress
19 | from .helpers import WriteMixin
20 |
21 |
class Counter(WriteMixin, Infinite):
    """Indicator that writes the current step index as a number."""

    hide_cursor = True
    message = ''

    def update(self):
        """Write the current index."""
        index_text = str(self.index)
        self.write(index_text)
28 |
29 |
class Countdown(WriteMixin, Progress):
    """Indicator that writes how many steps are left until ``max``."""

    hide_cursor = True

    def update(self):
        """Write the number of remaining steps."""
        remaining_text = str(self.remaining)
        self.write(remaining_text)
35 |
36 |
class Stack(WriteMixin, Progress):
    """Single-glyph indicator that picks a glyph from ``phases`` by progress."""

    hide_cursor = True
    phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')

    def update(self):
        """Write the glyph matching the current completion ratio."""
        count = len(self.phases)
        idx = int(self.progress * count)
        # At 100 % the scaled index equals ``count``; clamp to the last glyph.
        last = count - 1
        if idx > last:
            idx = last
        self.write(self.phases[idx])
45 |
46 |
class Pie(Stack):
    """Stack variant drawn with progressively filled circle glyphs."""

    # One glyph per phase, from empty to full.
    phases = tuple('○◔◑◕●')
49 |
--------------------------------------------------------------------------------
/utils/progress/progress/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2012 Giorgos Verigakis
2 | #
3 | # Permission to use, copy, modify, and distribute this software for any
4 | # purpose with or without fee is hereby granted, provided that the above
5 | # copyright notice and this permission notice appear in all copies.
6 | #
7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 |
15 | from __future__ import print_function
16 |
17 |
18 | HIDE_CURSOR = '\x1b[?25l'
19 | SHOW_CURSOR = '\x1b[?25h'
20 |
21 |
class WriteMixin(object):
    """Mixin that redraws a short value in place using backspaces.

    The host class must provide a ``file`` attribute (a writable stream with
    ``isatty``/``flush``). Nothing is written unless that stream is a TTY.

    Args:
        message (str, optional): text printed once before the value; when
            given (and non-empty) it is also stored as ``self.message``.
        **kwargs: forwarded to the next ``__init__`` in the MRO.
    """

    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super(WriteMixin, self).__init__(**kwargs)
        # Widest string written so far; used to erase the previous output.
        self._width = 0
        if message:
            self.message = message

        if self.file.isatty():
            if self.hide_cursor:
                print(HIDE_CURSOR, end='', file=self.file)
            # Not every host class defines ``message``; fall back to an empty
            # prefix instead of raising AttributeError at construction time.
            print(getattr(self, 'message', ''), end='', file=self.file)
            self.file.flush()

    def write(self, s):
        """Overwrite the previously written value with ``s``."""
        if self.file.isatty():
            # Step back over the old value, then pad the new one so that any
            # leftover characters of a longer previous value are blanked.
            b = '\b' * self._width
            c = s.ljust(self._width)
            print(b + c, end='', file=self.file)
            self._width = max(self._width, len(s))
            self.file.flush()

    def finish(self):
        """Make the cursor visible again if this instance hid it."""
        if self.file.isatty() and self.hide_cursor:
            print(SHOW_CURSOR, end='', file=self.file)
48 |
49 |
class WritelnMixin(object):
    """Mixin that redraws a whole terminal line in place.

    The host class must provide a ``file`` attribute (a writable stream with
    ``isatty``/``flush``). Nothing is written unless that stream is a TTY.
    """

    hide_cursor = False

    def __init__(self, message=None, **kwargs):
        super(WritelnMixin, self).__init__(**kwargs)
        if message:
            self.message = message

        wants_hidden_cursor = self.file.isatty() and self.hide_cursor
        if wants_hidden_cursor:
            print(HIDE_CURSOR, end='', file=self.file)

    def clearln(self):
        """Return to column 0 and erase to the end of the line."""
        if not self.file.isatty():
            return
        print('\r\x1b[K', end='', file=self.file)

    def writeln(self, line):
        """Replace the current terminal line with ``line``."""
        if not self.file.isatty():
            return
        self.clearln()
        print(line, end='', file=self.file)
        self.file.flush()

    def finish(self):
        """Terminate the line and restore the cursor if it was hidden."""
        if not self.file.isatty():
            return
        print(file=self.file)
        if self.hide_cursor:
            print(SHOW_CURSOR, end='', file=self.file)
76 |
77 |
78 | from signal import signal, SIGINT
79 | from sys import exit
80 |
81 |
class SigIntMixin(object):
    """Mixin that installs a SIGINT handler which calls ``finish`` and exits."""

    def __init__(self, *args, **kwargs):
        super(SigIntMixin, self).__init__(*args, **kwargs)
        # Replaces whatever SIGINT handler was registered before.
        signal(SIGINT, self._sigint_handler)

    def _sigint_handler(self, signum, frame):
        # Let the indicator clean up its output, then leave with status 0.
        self.finish()
        raise SystemExit(0)
92 |
--------------------------------------------------------------------------------
/utils/progress/progress/spinner.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Copyright (c) 2012 Giorgos Verigakis
4 | #
5 | # Permission to use, copy, modify, and distribute this software for any
6 | # purpose with or without fee is hereby granted, provided that the above
7 | # copyright notice and this permission notice appear in all copies.
8 | #
9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 |
17 | from __future__ import unicode_literals
18 | from . import Infinite
19 | from .helpers import WriteMixin
20 |
21 |
class Spinner(WriteMixin, Infinite):
    """Indicator that cycles through the glyphs in ``phases`` on every step."""

    hide_cursor = True
    message = ''
    phases = ('-', '\\', '|', '/')

    def update(self):
        """Write the glyph for the current index, wrapping around."""
        glyph = self.phases[self.index % len(self.phases)]
        self.write(glyph)
30 |
31 |
class PieSpinner(Spinner):
    """Spinner drawn with quarter-circle glyphs."""

    # One list entry per glyph, in display order.
    phases = list('◷◶◵◴')
34 |
35 |
class MoonSpinner(Spinner):
    """Spinner drawn with half-filled circle glyphs."""

    phases = [
        '◑',
        '◒',
        '◐',
        '◓',
    ]
38 |
39 |
class LineSpinner(Spinner):
    """Spinner drawn with a horizontal line that moves down and back up."""

    # One list entry per glyph, in display order.
    phases = list('⎺⎻⎼⎽⎼⎻')
42 |
class PixelSpinner(Spinner):
    """Spinner drawn with braille-dot glyphs."""

    phases = [
        '⣾',
        '⣷',
        '⣯',
        '⣟',
        '⡿',
        '⢿',
        '⣻',
        '⣽',
    ]
45 |
--------------------------------------------------------------------------------
/utils/progress/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import setup
4 |
5 | import progress
6 |
7 |
# Read the long description through a context manager so the README file
# handle is closed deterministically rather than left for garbage collection.
with open('README.rst') as readme_file:
    long_description = readme_file.read()


setup(
    name='progress',
    version=progress.__version__,
    description='Easy to use progress bars',
    long_description=long_description,
    author='Giorgos Verigakis',
    author_email='verigak@gmail.com',
    url='http://github.com/verigak/progress/',
    license='ISC',
    packages=['progress'],
    classifiers=[
        'Environment :: Console',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: ISC License (ISCL)',
        'Programming Language :: Python :: 2.6',
        'Programming Language :: Python :: 2.7',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
    ]
)
30 |
--------------------------------------------------------------------------------