├── .gitignore
├── LICENSE
├── README.md
├── dataloader.py
├── img
│   └── DGNet.png
├── main.py
├── models
│   ├── __init__.py
│   ├── cifar
│   │   ├── __init__.py
│   │   └── resdg.py
│   ├── imagenet
│   │   ├── __init__.py
│   │   └── resdg.py
│   ├── mask.py
│   └── mobilenet_v2
│       ├── __init__.py
│       ├── mobilenet_v2_dg.py
│       └── mobilenet_v2_dg_util.py
├── options.py
├── regularization.py
├── requirements.txt
├── scripts
│   ├── cifar_e.sh
│   ├── cifar_t.sh
│   ├── imagenet_e.sh
│   ├── imagenet_t.sh
│   ├── mobilenet_v2_e.sh
│   └── mobilenet_v2_t.sh
└── utils
    ├── __init__.py
    ├── logger.py
    ├── misc.py
    └── progress
        ├── .gitignore
        ├── LICENSE
        ├── MANIFEST.in
        ├── README.rst
        ├── demo.gif
        ├── progress
        │   ├── __init__.py
        │   ├── bar.py
        │   ├── counter.py
        │   ├── helpers.py
        │   └── spinner.py
        └── setup.py
/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 | 
 6 | # dataset
 7 | data/
 8 | 
 9 | # log
10 | logs/
11 | 
12 | # jupyter notebook
13 | *.ipynb
14 | 
15 | # IDE
16 | .vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2021 anonymous-9800
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # Dynamic Dual Gating Neural Networks
 2 | 
 3 | This repository contains the PyTorch implementation for
 4 | 
 5 | > **Dynamic Dual Gating Neural Networks**
 6 | > Fanrong Li, Gang Li, Xiangyu He, Jian Cheng
 7 | > ICCV 2021 Oral
 8 | 
 9 | ![image](img/DGNet.png)
10 | 
11 | ## Getting Started
12 | 
13 | ### Requirements
14 | 
15 | The main requirements of this work are:
16 | 
17 | - Python 3.7
18 | - PyTorch == 1.5.0
19 | - Torchvision == 0.6.0
20 | - CUDA 10.2
21 | 
22 | We recommend using a conda env to set up the experimental environment.
23 | 
24 | 
25 | ```shell script
26 | # Create environment
27 | conda create -n DGNet python=3.7
28 | conda activate DGNet
29 | 
30 | # Install PyTorch & Torchvision
31 | pip install torch==1.5.0 torchvision==0.6.0
32 | 
33 | # Clone repo
34 | git clone https://github.com/anonymous-9800/DGNet.git ./DGNet
35 | cd ./DGNet
36 | 
37 | # Install other requirements
38 | pip install -r requirements.txt
39 | ```
40 | 
41 | ### Trained models
42 | Our trained models can be found here: [Google Drive](https://drive.google.com/file/d/1_-G5eHm3PUrrorjzp8w17W7ogZZoTElk/view?usp=sharing). The pretrained CIFAR-10 models can be found here: [Google Drive](https://drive.google.com/file/d/15sM2W2ADqtq5Gr8RTdaFalPK7qIw0VXF/view?usp=sharing). Unzip and place them into the DGNet folder.
43 | 
44 | ### Evaluate a trained DGNet
45 | 
46 | ```shell script
47 | # CIFAR-10
48 | sh ./scripts/cifar_e.sh [ARCH] [PATH-TO-DATASET] [GPU-IDs] [PATH-TO-SAVE] [PATH-TO-TRAINED-MODEL]
49 | 
50 | # ResNet on ImageNet
51 | sh ./scripts/imagenet_e.sh [ARCH] [PATH-TO-DATASET] [GPU-IDs] [PATH-TO-SAVE] [PATH-TO-TRAINED-MODEL]
52 | 
53 | # Example
54 | sh ./scripts/imagenet_e.sh resdg34 [PATH-TO-DATASET] 0 imagenet/resdg34-04-e ./trained_models_cls/imagenet_results/resdg34/sparse06/resdg34_04.pth.tar
55 | ```
56 | 
57 | ### Train a DGNet
58 | ```shell script
59 | # CIFAR-10
60 | sh ./scripts/cifar_t.sh [ARCH] [PATH-TO-DATASET] [TARGET-DENSITY] [GPU-IDs] [PATH-TO-SAVE] [PATH-TO-PRETRAINED-MODEL]
61 | 
62 | # ResNet on ImageNet
63 | sh ./scripts/imagenet_t.sh [ARCH] [PATH-TO-DATASET] [TARGET-DENSITY] [GPU-IDs] [PATH-TO-SAVE]
64 | 
65 | # Example
66 | sh ./scripts/imagenet_t.sh resdg34 [PATH-TO-DATASET] 0.4 0,1 imagenet/resdg34-04
67 | ```
68 | 
69 | ## Main results
70 | 
71 | 
| Model | Method | Top-1 (%) | Top-5 (%) | FLOPs | Google Drive |
| :--- | :--- | :---: | :---: | :---: | :---: |
| ResNet-18 | DGNet (50%) | 70.12 | 89.22 | 9.54E8 | Link |
| ResNet-18 | DGNet (60%) | 69.38 | 88.94 | 7.88E8 | Link |
| ResNet-34 | DGNet (60%) | 73.01 | 90.99 | 1.50E9 | Link |
| ResNet-34 | DGNet (70%) | 71.95 | 90.46 | 1.21E9 | Link |
| ResNet-50 | DGNet (60%) | 76.41 | 93.05 | 1.65E9 | Link |
| ResNet-50 | DGNet (70%) | 75.12 | 92.34 | 1.31E9 | Link |
| MobileNet-V2 | DGNet (50%) | 71.62 | 90.05 | 1.60E8 | Link |
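
The trained checkpoints follow the format written by `save_checkpoint` in `main.py`: a dict with the keys `epoch`, `state_dict`, `acc`, `best_acc`, and `optimizer`. A minimal sketch for inspecting one offline (the path below is taken from the evaluation example above; adjust it to your local copy):

```python
# Minimal sketch: inspect a released DGNet checkpoint (path assumed from the
# evaluation example above; adjust to your download location).
import torch

ckpt = torch.load('./trained_models_cls/imagenet_results/resdg34/sparse06/resdg34_04.pth.tar',
                  map_location='cpu')
print(ckpt['epoch'], ckpt['best_acc'])  # saved epoch and best top-1 accuracy
state_dict = ckpt['state_dict']         # weights for model.load_state_dict(...)
```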
134 | 135 | ## Citation 136 | 137 | If you find this project useful for your research, please use the following BibTeX entry. 138 | 139 | @inproceedings{dgnet, 140 | title={Dynamic Dual Gating Neural Networks}, 141 | author={Li, Fanrong and Li, Gang and He, Xiangyu and Cheng, Jian}, 142 | booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, 143 | year={2021} 144 | } 145 | 146 | ## Contact 147 | For any questions, feel free to contact: 148 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import torchvision.datasets as datasets 5 | import torchvision.transforms as transforms 6 | 7 | 8 | def _getCifarLoader(data, dataset, batch_size, workers): 9 | traindir = os.path.join(data) 10 | valdir = os.path.join(data) 11 | if dataset == 'cifar10': 12 | normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), 13 | std=(0.2023, 0.1994, 0.2010)) 14 | else: 15 | normalize = transforms.Normalize(mean=(0.507, 0.487, 0.441), 16 | std=(0.267, 0.256, 0.276)) 17 | 18 | logging.info('=> Preparing dataset %s' % dataset) 19 | transform_train = transforms.Compose([ 20 | transforms.RandomCrop(32, padding=4), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | normalize, 24 | ]) 25 | 26 | transform_test = transforms.Compose([ 27 | transforms.ToTensor(), 28 | normalize, 29 | ]) 30 | 31 | if dataset == 'cifar10': 32 | dataloader = datasets.CIFAR10 33 | num_classes = 10 34 | else: 35 | dataloader = datasets.CIFAR100 36 | num_classes = 100 37 | 38 | trainset = dataloader(root=traindir, 39 | train=True, 40 | download=False, 41 | transform=transform_train) 42 | trainloader = torch.utils.data.DataLoader(trainset, 43 | batch_size=batch_size, 44 | shuffle=True, 45 | num_workers=workers) 46 | 47 | testset = dataloader(root=valdir, 48 | train=False, 49 | download=False, 50 | transform=transform_test) 51 | testloader = torch.utils.data.DataLoader(testset, 52 | batch_size=batch_size, 53 | shuffle=False, 54 | num_workers=workers) 55 | return trainloader, testloader 56 | 57 | 58 | def _getImageNetLoader(data, batch_size, workers): 59 | traindir = os.path.join(data, 'train') 60 | valdir = os.path.join(data, 'val') 61 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 62 | std=[0.229, 0.224, 0.225]) 63 | 64 | train_loader = torch.utils.data.DataLoader( 65 | datasets.ImageFolder(traindir, transforms.Compose([ 66 | transforms.RandomResizedCrop(224, scale=(0.2, 1.0)), 67 | transforms.RandomHorizontalFlip(), 68 | transforms.ToTensor(), 69 | normalize, 70 | ])), 71 | batch_size=batch_size, shuffle=True, 72 | num_workers=workers, pin_memory=True, 73 | drop_last=True) 74 | 75 | val_loader = torch.utils.data.DataLoader( 76 | datasets.ImageFolder(valdir, transforms.Compose([ 77 | transforms.Resize(256), 78 | transforms.CenterCrop(224), 79 | transforms.ToTensor(), 80 | normalize, 81 | ])), 82 | batch_size=batch_size, shuffle=False, 83 | num_workers=workers, pin_memory=True, 84 | drop_last=True) 85 | return train_loader, val_loader 86 | 87 | 88 | def getDataLoader(data, dataset, batch_size, workers): 89 | if dataset == 'imagenet': 90 | return _getImageNetLoader(data, batch_size, workers) 91 | else: 92 | return _getCifarLoader(data, dataset, batch_size, workers) 93 | -------------------------------------------------------------------------------- /img/DGNet.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CAS-CLab/DGNet/6b709a388c463d7468fbad953ad0112bc3abe66d/img/DGNet.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import random 5 | import shutil 6 | import logging 7 | import torch 8 | import torch.nn as nn 9 | import models 10 | import numpy as np 11 | from options import parser 12 | from collections import OrderedDict 13 | from dataloader import getDataLoader 14 | from utils import * 15 | from regularization import * 16 | 17 | args = parser.parse_args() 18 | state = {k: v for k, v in args._get_kwargs()} 19 | print('Parameters:') 20 | for key, value in state.items(): 21 | print(' {key} : {value}'.format(key=key, value=value)) 22 | 23 | # Use CUDA 24 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 25 | use_cuda = torch.cuda.is_available() 26 | 27 | # Random seed 28 | if args.manualSeed is None: 29 | args.manualSeed = random.randint(1, 10000) 30 | random.seed(args.manualSeed) 31 | torch.manual_seed(args.manualSeed) 32 | np.random.seed(args.manualSeed) 33 | if use_cuda: 34 | torch.cuda.manual_seed_all(args.manualSeed) 35 | torch.backends.cudnn.deterministic = True 36 | 37 | best_acc = 0 # best test accuracy 38 | 39 | # Get loggers and save the config information 40 | train_log, test_log, checkpoint_dir, log_dir = get_loggers(args) 41 | 42 | def main(): 43 | global best_acc, train_log, test_log, checkpoint_dir, log_dir 44 | # create model 45 | logging.info("=" * 89) 46 | logging.info("=> creating model '{}'".format(args.arch)) 47 | model = models.get_model(pretrained=args.pretrained, dataset = args.dataset, 48 | arch = args.arch, bias=args.bias) 49 | # define loss function (criterion) and optimizer 50 | criterion = Loss() 51 | model.set_criterion(criterion) 52 | # Data loader 53 | trainloader, testloader = getDataLoader(args.data, args.dataset, args.batch_size, 54 | args.workers) 55 | # to cuda 56 | if torch.cuda.is_available() and args.gpu_id != -1: 57 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 58 | model = torch.nn.DataParallel(model).cuda() 59 | logging.info('=> running the model on gpu{}.'.format(args.gpu_id)) 60 | else: 61 | logging.info('=> running the model on cpu.') 62 | # define optimizer 63 | param_dict = dict(model.named_parameters()) 64 | params = [] 65 | BN_name_pool = [] 66 | for m_name, m in model.named_modules(): 67 | if isinstance(m, nn.BatchNorm2d): 68 | BN_name_pool.append(m_name + '.weight') 69 | BN_name_pool.append(m_name + '.bias') 70 | for key, value in param_dict.items(): 71 | if (key in BN_name_pool and 'mobilenet' in args.arch) or 'mask' in key: 72 | params += [{'params': [value], 'lr': args.learning_rate, 'weight_decay': 0.}] 73 | else: 74 | params += [{'params':[value]}] 75 | optimizer = torch.optim.SGD(params, lr=args.learning_rate,weight_decay=args.weight_decay, 76 | momentum=args.momentum, nesterov=True) 77 | p_anneal = ExpAnnealing(0, 1, 0, alpha=args.alpha) 78 | # ready 79 | logging.info("=" * 89) 80 | # Evaluate 81 | if args.evaluate: 82 | logging.info('Evaluate model') 83 | top1, top5 = validate(testloader, model, criterion, 0, use_cuda, 84 | (args.lbda, 0), args.den_target) 85 | logging.info('Test Acc (Top-1): %.2f, Test Acc (Top-5): %.2f' % (top1, top5)) 86 | return 87 | # training 88 | logging.info('\n Train for {} epochs'.format(args.epochs)) 89 | 
train_process(model, args.epochs, testloader, trainloader, criterion, optimizer, 90 | use_cuda, args.lbda, args.gamma, p_anneal, checkpoint_dir, args.den_target) 91 | train_log.close() 92 | test_log.close() 93 | logging.info('Best acc: {}'.format(best_acc)) 94 | return 95 | 96 | 97 | def train_process(model, total_epochs, testloader, trainloader, criterion, optimizer, 98 | use_cuda, lbda, gamma, p_anneal, checkpoint_dir, den_target): 99 | global best_acc 100 | for epoch in range(total_epochs): 101 | p = p_anneal.get_lr(epoch) 102 | # get target density 103 | state['den_target'] = den_target 104 | # update lr 105 | adjust_learning_rate(optimizer, epoch=epoch) 106 | # Training 107 | train(trainloader, model, criterion, optimizer, epoch, use_cuda, (lbda, gamma), 108 | den_target, p) 109 | test_acc, _ = validate(testloader, model, criterion, epoch, use_cuda, 110 | (lbda, gamma), den_target, p=p) 111 | # save checkpoint 112 | if checkpoint_dir is not None: 113 | is_best = test_acc > best_acc 114 | best_acc = max(test_acc, best_acc) 115 | model_dict = model.module.state_dict() if use_cuda else model.state_dict() 116 | save_checkpoint( 117 | { 118 | 'epoch': epoch + 1, 119 | 'state_dict': model_dict, 120 | 'acc': test_acc, 121 | 'best_acc': best_acc, 122 | 'optimizer': optimizer.state_dict() 123 | }, 124 | is_best=is_best, 125 | checkpoint_dir=checkpoint_dir) 126 | return 127 | 128 | 129 | def train(train_loader, model, criterion, optimizer, epoch, use_cuda, param, 130 | den_target, p): 131 | lbda, gamma = param 132 | # switch to train mode 133 | model.train() 134 | logging.info("=" * 89) 135 | 136 | batch_time, data_time, closses, rlosses, blosses, losses, top1, top5 = getAvgMeter(8) 137 | 138 | end = time.time() 139 | bar = Bar('Processing', max=len(train_loader)) 140 | for batch_idx, (x, targets) in enumerate(train_loader): 141 | # measure data loading time 142 | data_time.update(time.time() - end) 143 | # get inputs 144 | if use_cuda: 145 | x, targets = x.cuda(), targets.cuda() 146 | x, targets = torch.autograd.Variable(x), torch.autograd.Variable(targets) 147 | batch_size = x.size(0) 148 | # inference 149 | inputs = {"x": x, "label": targets, "den_target": den_target, "lbda": lbda, 150 | "gamma": gamma, "p": p} 151 | outputs= model(**inputs) 152 | loss = outputs["closs"].mean() + outputs["rloss"].mean() + outputs["bloss"].mean() 153 | # measure accuracy and record loss 154 | prec1, prec5 = accuracy(outputs["out"].data, targets.data, topk=(1, 5)) 155 | closses.update(outputs["closs"].mean().item(), batch_size) 156 | rlosses.update(outputs["rloss"].mean().item(), batch_size) 157 | blosses.update(outputs["bloss"].mean().item(), batch_size) 158 | losses.update(loss.item(), batch_size) 159 | top1.update(prec1.item(), batch_size) 160 | top5.update(prec5.item(), batch_size) 161 | # compute gradient and do SGD step 162 | optimizer.zero_grad() 163 | loss.backward() 164 | optimizer.step() 165 | # measure elapsed time 166 | batch_time.update(time.time() - end) 167 | end = time.time() 168 | # plot progress 169 | bar.suffix = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | '.format( 170 | batch=batch_idx+1, size=len(train_loader), data=data_time.val, bt=batch_time.val, 171 | )+'Total: {total:} | (C,R,B)Loss: {closs:.2f}, {rloss:.2f}, {bloss:.2f}'.format( 172 | total=bar.elapsed_td, closs=closses.avg, rloss=rlosses.avg, bloss=blosses.avg, 173 | )+' | Loss: {loss:.2f} | top1: {top1:.2f} | top5: {top5:.2f}'.format(top1=top1.avg, 174 | top5=top5.avg, loss=losses.avg) 175 | bar.next() 176 | 
bar.finish() 177 | train_log.write(content="{epoch}\t{top1.avg:.4f}\t{top5.avg:.4f}\t{loss.avg:.4f}\t" 178 | "{closs.avg:.4f}\t{rloss.avg:.4f}\t{bloss.avg:.4f}".format( 179 | epoch=epoch, top1=top1, top5=top5,loss=losses, closs=closses, 180 | rloss=rlosses, bloss=blosses), 181 | wrap=True, flush=True) 182 | return 183 | 184 | 185 | def validate(val_loader, model, criterion, epoch, use_cuda, param, den_target, p=0): 186 | global log_dir 187 | lbda, gamma = param 188 | # switch to evaluate mode 189 | model.eval() 190 | logging.info("=" * 89) 191 | 192 | (batch_time, data_time, closses, rlosses, blosses, losses, 193 | top1, top5, block_flops)= getAvgMeter(9) 194 | 195 | with torch.no_grad(): 196 | end = time.time() 197 | bar = Bar('Processing', max=len(val_loader)) 198 | for batch_idx, (x, targets) in enumerate(val_loader): 199 | # measure data loading time 200 | data_time.update(time.time() - end) 201 | # get inputs 202 | if use_cuda: 203 | x, targets = x.cuda(), targets.cuda(non_blocking=True) 204 | x, targets = torch.autograd.Variable(x), torch.autograd.Variable(targets) 205 | batch_size = x.size(0) 206 | # inference 207 | inputs = {"x": x, "label": targets, "den_target": den_target, "lbda": lbda, 208 | "gamma": gamma, "p": p} 209 | outputs= model(**inputs) 210 | loss = outputs["closs"].mean() + outputs["rloss"].mean() + outputs["bloss"].mean() 211 | # measure accuracy and record loss 212 | prec1, prec5 = accuracy(outputs["out"].data, targets.data, topk=(1, 5)) 213 | closses.update(outputs["closs"].mean().item(), batch_size) 214 | rlosses.update(outputs["rloss"].mean().item(), batch_size) 215 | blosses.update(outputs["bloss"].mean().item(), batch_size) 216 | losses.update(loss.item(), batch_size) 217 | top1.update(prec1.item(), batch_size) 218 | top5.update(prec5.item(), batch_size) 219 | # measure elapsed time 220 | batch_time.update(time.time() - end) 221 | end = time.time() 222 | # get flops 223 | flops_real = outputs["flops_real"] 224 | flops_mask = outputs["flops_mask"] 225 | flops_ori = outputs["flops_ori"] 226 | flops_conv, flops_mask, flops_ori, flops_conv1, flops_fc = analyse_flops( 227 | flops_real, flops_mask, flops_ori, batch_size) 228 | block_flops.update(flops_conv, batch_size) 229 | # plot progress 230 | bar.suffix = '({batch}/{size}) Batch: {bt:.3f}s | Total: {total:}'.format( 231 | batch=batch_idx+1, size=len(val_loader), bt=batch_time.avg, total=bar.elapsed_td 232 | )+' | (C,R,B)Loss: {closs:.2f}, {rloss:.2f}, {bloss:.2f}'.format( 233 | closs=closses.avg, rloss=rlosses.avg, bloss=blosses.avg, 234 | )+' | Loss: {loss:.2f} | top1: {top1:.2f} | top5: {top5:.2f}'.format( 235 | top1=top1.avg, top5=top5.avg, loss=losses.avg) 236 | bar.next() 237 | bar.finish() 238 | # log 239 | if use_cuda: 240 | model.module.record_flops(block_flops.avg, flops_mask, flops_ori, flops_conv1, flops_fc) 241 | else: 242 | model.record_flops(block_flops.avg, flops_mask, flops_ori, flops_conv1, flops_fc) 243 | flops = (block_flops.avg[-1]+flops_mask[-1]+flops_conv1.mean()+flops_fc.mean())/1024 244 | flops_per = (block_flops.avg[-1]+flops_mask[-1]+flops_conv1.mean()+flops_fc.mean())/( 245 | flops_ori[-1]+flops_conv1.mean()+flops_fc.mean())*100 246 | test_log.write(content="{epoch}\t{top1.avg:.4f}\t{top5.avg:.4f}\t{loss.avg:.4f}\t" 247 | "{closs.avg:.4f}\t{rloss.avg:.4f}\t{bloss.avg:.4f}\t" 248 | "{flops_per:.2f}%\t{flops:.2f}K\t".format(epoch=epoch, top1=top1, 249 | top5=top5, loss=losses, closs=closses, rloss=rlosses, 250 | bloss=blosses, flops_per=flops_per, flops=flops), 251 | wrap=True, 
flush=True)
252 |     return (top1.avg, top5.avg)
253 | 
254 | 
255 | def getAvgMeter(num):
256 |     return [AverageMeter() for _ in range(num)]
257 | 
258 | 
259 | def adjust_learning_rate(optimizer, epoch):
260 |     global state
261 |     if args.lr_mode == 'cosine':
262 |         lr = 0.5*args.learning_rate*(1+math.cos(math.pi*float(epoch)/float(args.epochs)))
263 |         state['learning_rate'] = lr
264 |         for param_group in optimizer.param_groups:
265 |             param_group['lr'] = lr
266 |     elif args.lr_mode == 'step':
267 |         if epoch in args.schedule:
268 |             state['learning_rate'] *= args.lr_decay
269 |             for param_group in optimizer.param_groups:
270 |                 param_group['lr'] = state['learning_rate']
271 |     else:
272 |         raise NotImplementedError('lr mode {} is not supported'.format(args.lr_mode))
273 |     logging.info("\nEpoch: {epoch:3d} | learning rate = {lr:.6f}".format(
274 |         epoch=epoch, lr=state['learning_rate']))
275 | 
276 | 
277 | def save_checkpoint(state,
278 |                     is_best,
279 |                     filename='checkpoint.pth.tar',
280 |                     checkpoint_dir='.'):
281 |     filename = os.path.join(checkpoint_dir, filename)
282 |     torch.save(state, filename, pickle_protocol=4)
283 |     if is_best:
284 |         shutil.copyfile(filename, os.path.join(checkpoint_dir, 'model_best.pth.tar'))
285 | 
286 | if __name__ == "__main__":
287 |     main()
288 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import copy
 3 | import torch
 4 | import logging
 5 | import pretrainedmodels
 6 | import torchvision.models as torch_models
 7 | import torch.backends.cudnn as cudnn
 8 | from torchvision.models.utils import load_state_dict_from_url
 9 | from . import cifar as cifar_models
10 | from . import imagenet as imagenet_extra_models
11 | from . import mobilenet_v2 as mobilenet_models
12 | 
13 | 
14 | SUPPORTED_DATASETS = ('imagenet', 'cifar10')
15 | 
16 | TORCHVISION_MODEL_NAMES = sorted(name for name in torch_models.__dict__
17 |                                  if not name.startswith("__")
18 |                                  and callable(torch_models.__dict__[name]))
19 | 
20 | IMAGENET_MODEL_NAMES = copy.deepcopy(TORCHVISION_MODEL_NAMES)
21 | IMAGENET_MODEL_NAMES.extend(sorted(name for name in imagenet_extra_models.__dict__
22 |                                    if name.islower() and not name.startswith("__")
23 |                                    and callable(imagenet_extra_models.__dict__[name])))
24 | 
25 | CIFAR_MODEL_NAMES = sorted(name for name in cifar_models.__dict__
26 |                            if name.islower() and not name.startswith("__")
27 |                            and callable(cifar_models.__dict__[name]))
28 | 
29 | MOBILENET_MODEL_NAMES = sorted(name for name in mobilenet_models.__dict__
30 |                                if name.islower() and not name.startswith("__")
31 |                                and callable(mobilenet_models.__dict__[name]))
32 | 
33 | ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR_MODEL_NAMES
34 |                                                       + MOBILENET_MODEL_NAMES)))
35 | 
36 | model_urls = {
37 |     'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
38 | }
39 | 
40 | 
41 | def get_model(pretrained, dataset, arch, **kwargs):
42 |     """Create a PyTorch model for the given architecture and dataset.
43 | 
44 |     Args:
45 |         pretrained: 'pytorch' to initialize from the torchvision pretrained weights,
46 |             or a path to a local checkpoint. Some models do not have a pretrained version.
47 |         dataset: dataset name ('imagenet' and 'cifar10' are supported)
48 |         arch: architecture name
49 |     """
50 |     dataset = dataset.lower()
51 |     if dataset not in SUPPORTED_DATASETS:
52 |         raise ValueError('Dataset {} is not supported'.format(dataset))
53 | 
54 |     model = None
55 |     cadene = False
56 |     try:
57 |         if dataset == 'imagenet':
58 |             if 'mobilenet' in arch:
59 |                 model = _create_mobilenet_model(arch, pretrained, **kwargs)
60 |             else:
61 |                 kwargs['num_classes'] = 1000
62 |                 model = _create_imagenet_model(arch, pretrained, **kwargs)
63 |         elif dataset == 'cifar10':
64 |             kwargs['num_classes'] = 10
65 |             model = _create_cifar10_model(arch, pretrained, **kwargs)
66 |     except ValueError:
67 |         raise ValueError('Could not recognize dataset {} and model {} pair'.format(dataset, arch))
68 | 
69 |     logging.info("=> created a %s%s model with the %s dataset" % ('pretrained ' if pretrained else '',
70 |                  arch, dataset))
71 |     logging.info('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0))
72 |     return model
73 | 
74 | 
75 | def _create_imagenet_model(arch, pretrained, **kwargs):
76 |     dataset = "imagenet"
77 |     model = None
78 |     pretrained_pytorch = pretrained == 'pytorch'
79 |     pretrained_checkpoint = os.path.isfile(pretrained)
80 |     if arch in TORCHVISION_MODEL_NAMES:
81 |         try:
82 |             model = getattr(torch_models, arch)(pretrained=pretrained_pytorch)
83 |         except NotImplementedError:
84 |             # In torchvision 0.3, trying to download a model that has no
85 |             # pretrained weights available will raise NotImplementedError
86 |             if not pretrained_pytorch:
87 |                 raise
88 |     if model is None and (arch in imagenet_extra_models.__dict__):
89 |         model = imagenet_extra_models.__dict__[arch](**kwargs)
90 |         if pretrained_pytorch:
91 |             model_dict = model.state_dict()
92 |             # get pretrained model
93 |             if arch.startswith('resdg'):
94 |                 arch_pretrained = 'resnet' + arch[len('resdg'):]  # e.g. resdg34 -> resnet34 (str.lstrip strips a char set, not a prefix)
95 |             else:
96 |                 raise ValueError("There is no pretrained model for {} in pytorch".format(arch))
97 |             logging.info("=> use a pretrained %s model to initialize."
% (arch_pretrained)) 98 | pretrained_model = getattr(torch_models, arch_pretrained)(pretrained=pretrained) 99 | pretrained_dict = pretrained_model.state_dict() 100 | # filter out unnecessary keys 101 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 102 | # overwrite entries in the existing state dict 103 | model_dict.update(pretrained_dict) 104 | # load the new state dict 105 | model.load_state_dict(model_dict) 106 | elif pretrained_checkpoint: 107 | checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage) 108 | logging.info("=> loaded checkpoint (prec {:.2f})".format(checkpoint['best_acc'])) 109 | model_dict = model.state_dict() 110 | pretrained_dict = checkpoint['state_dict'] 111 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 112 | model_dict.update(pretrained_dict) 113 | model.load_state_dict(model_dict) 114 | 115 | if model is None and (arch in pretrainedmodels.model_names): 116 | model = pretrainedmodels.__dict__[arch]( 117 | num_classes=1000, 118 | pretrained=(dataset if pretrained else None)) 119 | 120 | if model is None: 121 | error_message = '' 122 | if arch not in IMAGENET_MODEL_NAMES: 123 | error_message = "Model {} is not supported for dataset ImageNet".format(arch) 124 | elif pretrained: 125 | error_message = "Model {} (ImageNet) does not have a pretrained model".format(arch) 126 | raise ValueError(error_message or 'Failed to find model {}'.format(arch)) 127 | return model 128 | 129 | 130 | def _create_cifar10_model(arch, pretrained, **kwargs): 131 | try: 132 | model = cifar_models.__dict__[arch](**kwargs) 133 | except KeyError: 134 | raise ValueError("Model {} is not supported for dataset CIFAR10".format(arch)) 135 | pretrained_path = pretrained 136 | pretrained = os.path.isfile(pretrained_path) 137 | # load pretrained model 138 | if pretrained: 139 | checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage) 140 | logging.info("=> loaded checkpoint (prec {:.2f})".format(checkpoint['best_acc'])) 141 | model_dict = model.state_dict() 142 | pretrained_dict = checkpoint['state_dict'] 143 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 144 | model_dict.update(pretrained_dict) 145 | model.load_state_dict(model_dict) 146 | return model 147 | 148 | def _create_mobilenet_model(arch, pretrained, **kwargs): 149 | model = mobilenet_models.__dict__[arch](**kwargs) 150 | pretrained_pytorch = pretrained == 'pytorch' 151 | pretrained_checkpoint = os.path.isfile(pretrained) 152 | if pretrained_pytorch: 153 | model_dict = model.state_dict() 154 | arch_pretrained = arch[:-3] 155 | logging.info("=> use a pretrained %s model to initialize." 
% (arch_pretrained)) 156 | if 'mobilenet_v2' in arch: 157 | pretrained_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], progress=True) 158 | else: 159 | pretrained_model = mobilenet_models.__dict__[arch_pretrained](pretrained=True, **kwargs) 160 | pretrained_dict = pretrained_model.state_dict() 161 | # filter out unnecessary keys 162 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 163 | # overwrite entries in the existing state dict 164 | model_dict.update(pretrained_dict) 165 | # load the new state dict 166 | model.load_state_dict(model_dict) 167 | elif pretrained_checkpoint: 168 | checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage) 169 | logging.info("=> loaded checkpoint (prec {:.2f})".format(checkpoint['best_acc'])) 170 | model_dict = model.state_dict() 171 | pretrained_dict = checkpoint['state_dict'] 172 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 173 | model_dict.update(pretrained_dict) 174 | model.load_state_dict(model_dict) 175 | else: 176 | logging.info("=> no checkpoint found at '{}'".format(pretrained)) 177 | return model 178 | -------------------------------------------------------------------------------- /models/cifar/__init__.py: -------------------------------------------------------------------------------- 1 | from .resdg import * 2 | -------------------------------------------------------------------------------- /models/cifar/resdg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | from prettytable import PrettyTable 6 | from ..mask import Mask_s, Mask_c 7 | 8 | 9 | __all__ = ['resdg20_cifar10', 'resdg32_cifar10', 'resdg56_cifar10', 10 | 'resdg110_cifar10'] 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=dilation, groups=groups, bias=False, dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 22 | 23 | 24 | def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False): 25 | if ceil_mode: 26 | return int(math.ceil((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 27 | else: 28 | return int(math.floor((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | def __init__(self, inplanes, planes, h, w, eta=4, 34 | stride=1, downsample=None, **kwargs): 35 | super(BasicBlock, self).__init__() 36 | # gating modules 37 | self.height = conv2d_out_dim(h, kernel_size=3, stride=stride, padding=1) 38 | self.width = conv2d_out_dim(w, kernel_size=3, stride=stride, padding=1) 39 | self.mask_s = Mask_s(self.height, self.width, inplanes, eta, eta, **kwargs) 40 | self.mask_c = Mask_c(inplanes, planes, **kwargs) 41 | self.upsample = nn.Upsample(size=(self.height, self.width), mode='nearest') 42 | # conv 1 43 | self.conv1 = conv3x3(inplanes, planes, stride) 44 | self.bn1 = nn.BatchNorm2d(planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | # conv 2 47 | self.conv2 = conv3x3(planes, planes) 48 | self.bn2 = nn.BatchNorm2d(planes) 49 | # misc 50 | self.downsample = downsample 51 | self.inplanes, self.planes = inplanes, planes 52 | self.b = eta * eta 53 
| self.b_reduce = (eta-1) * (eta-1) 54 | flops_conv1_full = torch.Tensor([9 * self.height * self.width * planes * inplanes]) 55 | flops_conv2_full = torch.Tensor([9 * self.height * self.width * planes * planes]) 56 | # downsample flops 57 | self.flops_downsample = torch.Tensor([self.height*self.width*planes*inplanes] 58 | )if downsample is not None else torch.Tensor([0]) 59 | # full flops 60 | self.flops_full = flops_conv1_full + flops_conv2_full + self.flops_downsample 61 | # mask flops 62 | flops_mks = self.mask_s.get_flops() 63 | flops_mkc = self.mask_c.get_flops() 64 | self.flops_mask = torch.Tensor([flops_mks + flops_mkc]) 65 | 66 | def forward(self, input): 67 | x, norm_1, norm_2, flops = input 68 | residual = x 69 | # spatial mask 70 | mask_s_m, norm_s, norm_s_t = self.mask_s(x) # [N, 1, h, w] 71 | mask_s = self.upsample(mask_s_m) # [N, 1, H, W] 72 | # conv 1 73 | mask_c, norm_c, norm_c_t = self.mask_c(x) # [N, C_out, 1, 1] 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | if not self.training: 78 | out = out * mask_c * mask_s 79 | else: 80 | out = out * mask_c 81 | # conv 2 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = out * mask_s 85 | # identity 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | out += residual 89 | out = self.relu(out) 90 | # flops 91 | flops_blk = self.get_flops(mask_s_m, mask_s, mask_c) 92 | flops = torch.cat((flops, flops_blk.unsqueeze(0))) 93 | # norm 94 | norm_1 = torch.cat((norm_1, torch.cat((norm_s, norm_s_t)).unsqueeze(0))) 95 | norm_2 = torch.cat((norm_2, torch.cat((norm_c, norm_c_t)).unsqueeze(0))) 96 | return (out, norm_1, norm_2, flops) 97 | 98 | def get_flops(self, mask_s, mask_s_up, mask_c): 99 | s_sum = mask_s.sum((1,2,3)) 100 | c_sum = mask_c.sum((1,2,3)) 101 | # conv1 102 | flops_conv1 = 9 * self.b * s_sum * c_sum * self.inplanes 103 | # conv2 104 | flops_conv2 = 9 * self.b * s_sum * self.planes * c_sum 105 | # total 106 | flops = flops_conv1 + flops_conv2 + self.flops_downsample.to(flops_conv1.device) 107 | return torch.cat((flops, self.flops_mask.to(flops.device), self.flops_full.to(flops.device))) 108 | 109 | 110 | class ResNetCifar10(nn.Module): 111 | def __init__(self, depth, num_classes=10, h=32, w=32, **kwargs): 112 | super(ResNetCifar10, self).__init__() 113 | self.height, self.width = h, w 114 | # Model type specifies number of layers for CIFAR-10 model 115 | n = (depth - 2) // 6 116 | block = BasicBlock 117 | # norm 118 | self._norm_layer = nn.BatchNorm2d 119 | # conv1 120 | self.inplanes = 16 121 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1,bias=False) 122 | self.bn1 = nn.BatchNorm2d(16) 123 | self.relu = nn.ReLU(inplace=True) 124 | # residual blocks 125 | self.layer1, h, w = self._make_layer(block, 16, n, h, w, 4, **kwargs) 126 | self.layer2, h, w = self._make_layer(block, 32, n, h, w, 2, stride=2, **kwargs) 127 | self.layer3, h, w = self._make_layer(block, 64, n, h, w, 2, stride=2, **kwargs) 128 | self.avgpool = nn.AvgPool2d(8) 129 | self.fc = nn.Linear(64 * block.expansion, num_classes) 130 | # flops 131 | self.flops_conv1 = torch.Tensor([9 * self.height * self.width * 16 * 3]) 132 | self.flops_fc = torch.Tensor([64 * block.expansion * num_classes]) 133 | # criterion 134 | self.criterion = None 135 | 136 | for m in self.modules(): 137 | if isinstance(m, nn.Conv2d): 138 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 139 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 140 | elif isinstance(m, nn.BatchNorm2d): 141 | if m.weight is not None and m.bias is not None: 142 | m.weight.data.fill_(1) 143 | m.bias.data.zero_() 144 | 145 | def _make_layer(self, block, planes, blocks, h, w, tile, stride=1, **kwargs): 146 | norm_layer = self._norm_layer 147 | downsample = None 148 | if stride != 1 or self.inplanes != planes * block.expansion: 149 | downsample = nn.Sequential( 150 | nn.Conv2d(self.inplanes, planes * block.expansion, 151 | kernel_size=1, stride=stride, bias=False), 152 | nn.BatchNorm2d(planes * block.expansion), 153 | ) 154 | layers = [] 155 | layers.append(block(self.inplanes, planes, h, w, tile, 156 | stride, downsample, **kwargs)) 157 | h = conv2d_out_dim(h, kernel_size=1, stride=stride, padding=0) 158 | w = conv2d_out_dim(w, kernel_size=1, stride=stride, padding=0) 159 | self.inplanes = planes * block.expansion 160 | for i in range(1, blocks): 161 | layers.append(block(self.inplanes, planes, h, w, tile, **kwargs)) 162 | return nn.Sequential(*layers), h, w 163 | 164 | def forward(self, x, label, den_target, lbda, gamma, p): 165 | batch_num, _, _, _ = x.shape 166 | # conv1 167 | x = self.conv1(x) 168 | x = self.bn1(x) 169 | x = self.relu(x) # 32x32 170 | # residual blocks 171 | norm1 = torch.zeros(1, batch_num+1).to(x.device) 172 | norm2 = torch.zeros(1, batch_num+1).to(x.device) 173 | flops = torch.zeros(1, batch_num+2).to(x.device) 174 | x = self.layer1((x, norm1, norm2, flops)) # 32x32 175 | x = self.layer2(x) # 16x16 176 | x, norm1, norm2, flops = self.layer3(x) # 8x8 177 | # fc layer 178 | x = self.avgpool(x) 179 | x = x.view(x.size(0), -1) 180 | x = self.fc(x) 181 | # flops 182 | flops_real = [flops[1:, 0:batch_num].permute(1, 0).contiguous(), 183 | self.flops_conv1.to(x.device), self.flops_fc.to(x.device)] 184 | flops_mask, flops_ori = flops[1:, -2].unsqueeze(0), flops[1:, -1].unsqueeze(0) 185 | # norm 186 | norm_s = norm1[1:, 0:batch_num].permute(1, 0).contiguous() 187 | norm_c = norm2[1:, 0:batch_num].permute(1, 0).contiguous() 188 | norm_s_t, norm_c_t = norm1[1:, -1].unsqueeze(0), norm2[1:, -1].unsqueeze(0) 189 | # get outputs 190 | outputs = {} 191 | outputs["closs"], outputs["rloss"], outputs["bloss"] = self.get_loss( 192 | x, label, batch_num, den_target, lbda, gamma, p, 193 | norm_s, norm_c, norm_s_t, norm_c_t, 194 | flops_real, flops_mask, flops_ori) 195 | outputs["out"] = x 196 | outputs["flops_real"] = flops_real 197 | outputs["flops_mask"] = flops_mask 198 | outputs["flops_ori"] = flops_ori 199 | return outputs 200 | 201 | def set_criterion(self, criterion): 202 | self.criterion = criterion 203 | return 204 | 205 | def get_loss(self, output, label, batch_size, den_target, lbda, gamma, p, 206 | mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, 207 | flops_real, flops_mask, flops_ori): 208 | closs, rloss, bloss = self.criterion(output, label, flops_real, flops_mask, 209 | flops_ori, batch_size, den_target, lbda, mask_norm_s, mask_norm_c, 210 | norm_s_t, norm_c_t, gamma, p) 211 | return closs, rloss, bloss 212 | 213 | def record_flops(self, flops_conv, flops_mask, flops_ori, flops_conv1, flops_fc): 214 | i = 0 215 | table = PrettyTable(['Layer', 'Conv FLOPs', 'Conv %', 'Mask FLOPs', 'Total FLOPs', 'Total %', 'Original FLOPs']) 216 | table.add_row(['layer0'] + ['{flops:.2f}K'.format(flops=flops_conv1/1024)] + [' ' for _ in range(5)]) 217 | for name, m in self.named_modules(): 218 | if isinstance(m, BasicBlock): 219 | table.add_row([name] + ['{flops:.2f}K'.format(flops=flops_conv[i]/1024)] + ['{per_f:.2f}%'.format( 220 | 
per_f=flops_conv[i]/flops_ori[i]*100)] + ['{mask:.2f}K'.format(mask=flops_mask[i]/1024)] + 221 | ['{total:.2f}K'.format(total=(flops_conv[i]+flops_mask[i])/1024)] + ['{per_t:.2f}%'.format( 222 | per_t=(flops_conv[i]+flops_mask[i])/flops_ori[i]*100)] + 223 | ['{ori:.2f}K'.format(ori=flops_ori[i]/1024)]) 224 | i+=1 225 | table.add_row(['fc'] + ['{flops:.2f}K'.format(flops=flops_fc/1024)] + [' ' for _ in range(5)]) 226 | table.add_row(['Total'] + ['{flops:.2f}K'.format(flops=(flops_conv[i]+flops_conv1+flops_fc)/1024)] + 227 | ['{per_f:.2f}%'.format(per_f=(flops_conv[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100)] + 228 | ['{mask:.2f}K'.format(mask=flops_mask[i]/1024)] + ['{total:.2f}K'.format( 229 | total=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/1024)] + ['{per_t:.2f}%'.format( 230 | per_t=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100)] + 231 | ['{ori:.2f}K'.format(ori=(flops_ori[i]+flops_conv1+flops_fc)/1024)]) 232 | logging.info('\n{}'.format(table)) 233 | 234 | 235 | def resdg20_cifar10(**kwargs): 236 | """ 237 | return a ResNet 20 object for cifar-10. 238 | """ 239 | return ResNetCifar10(20, **kwargs) 240 | 241 | 242 | def resdg32_cifar10(**kwargs): 243 | """ 244 | return a ResNet 32 object for cifar-10. 245 | """ 246 | return ResNetCifar10(32, **kwargs) 247 | 248 | 249 | def resdg56_cifar10(**kwargs): 250 | """ 251 | return a ResNet 56 object for cifar-10. 252 | """ 253 | return ResNetCifar10(56, **kwargs) 254 | 255 | 256 | def resdg110_cifar10(**kwargs): 257 | """ 258 | return a ResNet 110 object for cifar-10. 259 | """ 260 | return ResNetCifar10(110, **kwargs) 261 | -------------------------------------------------------------------------------- /models/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .resdg import * 2 | -------------------------------------------------------------------------------- /models/imagenet/resdg.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | from prettytable import PrettyTable 6 | from ..mask import Mask_s, Mask_c 7 | 8 | __all__ = ['resdg18', 'resdg34', 'resdg50'] 9 | 10 | 11 | def conv1x1(in_planes, out_planes, stride=1): 12 | """1x1 convolution""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 14 | 15 | 16 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=dilation, groups=groups, bias=False, dilation=dilation) 20 | 21 | 22 | def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False): 23 | if ceil_mode: 24 | return int(math.ceil((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 25 | else: 26 | return int(math.floor((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | def __init__(self, inplanes, planes, h, w, eta=8, stride=1, 32 | downsample=None, groups=1, base_width=64, dilation=1, 33 | norm_layer=None, **kwargs): 34 | super(BasicBlock, self).__init__() 35 | if norm_layer is None: 36 | norm_layer = nn.BatchNorm2d 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise 
NotImplementedError("Dilation > 1 not supported in BasicBlock") 41 | # gating modules 42 | self.height = conv2d_out_dim(h, kernel_size=3, stride=stride, padding=1) 43 | self.width = conv2d_out_dim(w, kernel_size=3, stride=stride, padding=1) 44 | self.mask_s = Mask_s(self.height, self.width, inplanes, eta, eta, **kwargs) 45 | self.mask_c = Mask_c(inplanes, planes, **kwargs) 46 | self.upsample = nn.Upsample(size=(self.height, self.width), mode='nearest') 47 | # conv 1 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | # conv 2 52 | self.conv2 = conv3x3(planes, planes) 53 | self.bn2 = norm_layer(planes) 54 | # misc 55 | self.downsample = downsample 56 | self.inplanes, self.planes = inplanes, planes 57 | # flops 58 | flops_conv1_full = torch.Tensor([9 * self.height * self.width * planes * inplanes]) 59 | flops_conv2_full = torch.Tensor([9 * self.height * self.width * planes * planes]) 60 | self.flops_downsample = torch.Tensor([self.height*self.width*planes*inplanes] 61 | )if downsample is not None else torch.Tensor([0]) 62 | self.flops_full = flops_conv1_full + flops_conv2_full + self.flops_downsample 63 | # mask flops 64 | flops_mks = self.mask_s.get_flops() 65 | flops_mkc = self.mask_c.get_flops() 66 | self.flops_mask = torch.Tensor([flops_mks + flops_mkc]) 67 | 68 | def forward(self, input): 69 | x, norm_1, norm_2, flops = input 70 | residual = x 71 | mask_s_m, norm_s, norm_s_t = self.mask_s(x) # [N, 1, h, w] 72 | mask_c, norm_c, norm_c_t = self.mask_c(x) # [N, C_out, 1, 1] 73 | mask_s = self.upsample(mask_s_m) # [N, 1, H, W] 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | out = out * mask_c * mask_s if not self.training else out * mask_c 78 | # conv 2 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = out * mask_s 82 | # identity 83 | if self.downsample is not None: 84 | residual = self.downsample(x) 85 | out += residual 86 | out = self.relu(out) 87 | # norm 88 | norm_1 = torch.cat((norm_1, torch.cat((norm_s, norm_s_t)).unsqueeze(0))) 89 | norm_2 = torch.cat((norm_2, torch.cat((norm_c, norm_c_t)).unsqueeze(0))) 90 | # flops 91 | flops_blk = self.get_flops(mask_s, mask_c) 92 | flops = torch.cat((flops, flops_blk.unsqueeze(0))) 93 | return (out, norm_1, norm_2, flops) 94 | 95 | def get_flops(self, mask_s_up, mask_c): 96 | s_sum = mask_s_up.sum((1,2,3)) 97 | c_sum = mask_c.sum((1,2,3)) 98 | # conv1 99 | flops_conv1 = 9 * s_sum * c_sum * self.inplanes 100 | # conv2 101 | flops_conv2 = 9 * s_sum * c_sum * self.planes 102 | # total 103 | flops = flops_conv1 + flops_conv2 + self.flops_downsample.to(flops_conv1.device) 104 | return torch.cat((flops, self.flops_mask.to(flops.device), self.flops_full.to(flops.device))) 105 | 106 | 107 | class Bottleneck(nn.Module): 108 | expansion = 4 109 | __constants__ = ['downsample'] 110 | def __init__(self, inplanes, planes, h, w, eta=8, stride=1, 111 | downsample=None, groups=1, base_width=64, dilation=1, 112 | norm_layer=None, **kwargs): 113 | super(Bottleneck, self).__init__() 114 | if norm_layer is None: 115 | norm_layer = nn.BatchNorm2d 116 | width = int(planes * (base_width / 64.)) * groups 117 | # spatial gating module 118 | self.height_1, self.width_1 = h, w 119 | self.height_2 = conv2d_out_dim(h, 3, dilation, stride, dilation) 120 | self.width_2 = conv2d_out_dim(w, 3, dilation, stride, dilation) 121 | self.mask_s = Mask_s(self.height_2, self.width_2, inplanes, eta, eta, **kwargs) 122 | self.upsample_1 = 
nn.Upsample(size=(self.height_1, self.width_1), mode='nearest') 123 | self.upsample_2 = nn.Upsample(size=(self.height_2, self.width_2), mode='nearest') 124 | # conv 1 125 | self.conv1 = conv1x1(inplanes, width) 126 | self.bn1 = norm_layer(width) 127 | self.mask_c1 = Mask_c(inplanes, width, **kwargs) 128 | # conv 2 129 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 130 | self.bn2 = norm_layer(width) 131 | self.mask_c2 = Mask_c(width, width, **kwargs) 132 | # conv 3 133 | self.conv3 = conv1x1(width, planes * self.expansion) 134 | self.bn3 = norm_layer(planes * self.expansion) 135 | # misc 136 | self.relu = nn.ReLU(inplace=True) 137 | self.downsample = downsample 138 | self.inplanes, self.width, self.planes = inplanes, width, planes * self.expansion 139 | # flops 140 | flops_conv1_full = torch.Tensor([self.height_1 * self.width_1 * width * inplanes]) 141 | flops_conv2_full = torch.Tensor([9 * self.height_2 * self.width_2 * width * width]) 142 | flops_conv3_full = torch.Tensor([self.height_2 * self.width_2 * width * planes*self.expansion]) 143 | self.flops_downsample = torch.Tensor([self.height_2*self.width_2*planes*self.expansion*inplanes] 144 | ) if self.downsample is not None else torch.Tensor([0]) 145 | self.flops_full = flops_conv1_full+flops_conv2_full+flops_conv3_full+self.flops_downsample 146 | # mask flops 147 | flops_mask_s = self.mask_s.get_flops() 148 | flops_mask_c1 = self.mask_c1.get_flops() 149 | flops_mask_c2 = self.mask_c2.get_flops() 150 | self.flops_mask = torch.Tensor([flops_mask_s + flops_mask_c1 + flops_mask_c2]) 151 | 152 | def forward(self, input): 153 | x, norm_1, norm_2, flops = input 154 | identity = x 155 | # spatial mask 156 | mask_s_m, norm_s, norm_s_t = self.mask_s(x) # [N, 1, h, w] 157 | mask_c1, norm_c1, norm_c1_t = self.mask_c1(x) 158 | mask_s1 = self.upsample_1(mask_s_m) # [N, 1, H1, W1] 159 | mask_s = self.upsample_2(mask_s_m) # [N, 1, H2, W2] 160 | # conv 1 161 | out = self.conv1(x) 162 | out = self.bn1(out) 163 | out = self.relu(out) 164 | out = out * mask_c1 * mask_s1 if not self.training else out * mask_c1 165 | # conv 2 166 | mask_c2, norm_c2, norm_c2_t = self.mask_c2(out) 167 | out = self.conv2(out) 168 | out = self.bn2(out) 169 | out = self.relu(out) 170 | out = out * mask_c2 * mask_s if not self.training else out * mask_c2 171 | # conv 3 172 | out = self.conv3(out) 173 | out = self.bn3(out) 174 | out = out * mask_s 175 | # identity 176 | if self.downsample is not None: 177 | identity = self.downsample(x) 178 | out += identity 179 | out = self.relu(out) 180 | # norm 181 | norm_1 = torch.cat((norm_1, torch.cat((norm_s, norm_s_t)).unsqueeze(0))) 182 | norm_2 = torch.cat((norm_2, torch.cat((norm_c1, norm_c1_t)).unsqueeze(0))) 183 | norm_2 = torch.cat((norm_2, torch.cat((norm_c2, norm_c2_t)).unsqueeze(0))) 184 | # flops 185 | flops_blk = self.get_flops(mask_s, mask_s1, mask_c1, mask_c2) 186 | flops = torch.cat((flops, flops_blk.unsqueeze(0))) 187 | return (out, norm_1, norm_2, flops) 188 | 189 | def get_flops(self, mask_s, mask_s1, mask_c1, mask_c2): 190 | s_sum = mask_s.sum((1,2,3)) 191 | c1_sum, c2_sum = mask_c1.sum((1,2,3)), mask_c2.sum((1,2,3)) 192 | # conv 193 | s_sum_1 = mask_s1.sum((1,2,3)) 194 | flops_conv1 = s_sum_1 * c1_sum * self.inplanes 195 | flops_conv2 = 9 * s_sum * c2_sum * c1_sum 196 | flops_conv3 = s_sum * self.planes * c2_sum 197 | # total 198 | flops = flops_conv1+flops_conv2+flops_conv3+self.flops_downsample.to(flops_conv1.device) 199 | return torch.cat((flops, self.flops_mask.to(flops.device), 
self.flops_full.to(flops.device))) 200 | 201 | 202 | class ResDG(nn.Module): 203 | 204 | def __init__(self, block, layers, h=224, w=224, num_classes=1000, 205 | zero_init_residual=False, groups=1, width_per_group=64, 206 | replace_stride_with_dilation=None, norm_layer=None, **kwargs): 207 | super(ResDG, self).__init__() 208 | # block 209 | self.height, self.width = h, w 210 | # norm layer 211 | if norm_layer is None: 212 | norm_layer = nn.BatchNorm2d 213 | self._norm_layer = norm_layer 214 | 215 | self.inplanes = 64 216 | self.dilation = 1 217 | if replace_stride_with_dilation is None: 218 | # each element in the tuple indicates if we should replace 219 | # the 2x2 stride with a dilated convolution instead 220 | replace_stride_with_dilation = [False, False, False] 221 | if len(replace_stride_with_dilation) != 3: 222 | raise ValueError("replace_stride_with_dilation should be None " 223 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 224 | self.groups = groups 225 | self.base_width = width_per_group 226 | # conv1 227 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 228 | self.bn1 = norm_layer(self.inplanes) 229 | self.relu = nn.ReLU(inplace=True) 230 | h = conv2d_out_dim(h, kernel_size=7, stride=2, padding=3) 231 | w = conv2d_out_dim(w, kernel_size=7, stride=2, padding=3) 232 | self.flops_conv1 = torch.Tensor([49 * h * w * self.inplanes * 3]) 233 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 234 | h = conv2d_out_dim(h, kernel_size=3, stride=2, padding=1) 235 | w = conv2d_out_dim(w, kernel_size=3, stride=2, padding=1) 236 | # residual blocks 237 | self.layer1, h, w = self._make_layer(block, 64, layers[0], h, w, 8, **kwargs) 238 | self.layer2, h, w = self._make_layer(block, 128, layers[1], h, w, 4, stride=2, 239 | dilate=replace_stride_with_dilation[0], **kwargs) 240 | self.layer3, h, w = self._make_layer(block, 256, layers[2], h, w, 2, stride=2, 241 | dilate=replace_stride_with_dilation[1], **kwargs) 242 | self.layer4, h, w = self._make_layer(block, 512, layers[3], h, w, 1, stride=2, 243 | dilate=replace_stride_with_dilation[2], **kwargs) 244 | # fc layer 245 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 246 | self.fc = nn.Linear(512 * block.expansion, num_classes) 247 | self.flops_fc = torch.Tensor([512 * block.expansion * num_classes]) 248 | # criterion 249 | self.criterion = None 250 | 251 | for m in self.modules(): 252 | if isinstance(m, nn.Conv2d): 253 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 254 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 255 | nn.init.constant_(m.weight, 1) 256 | nn.init.constant_(m.bias, 0) 257 | 258 | # Zero-initialize the last BN in each residual branch, 259 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
260 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 261 | if zero_init_residual: 262 | for m in self.modules(): 263 | if isinstance(m, Bottleneck): 264 | nn.init.constant_(m.bn3.weight, 0) 265 | elif isinstance(m, BasicBlock): 266 | nn.init.constant_(m.bn2.weight, 0) 267 | 268 | def _make_layer(self, block, planes, blocks, h, w, tile, stride=1, dilate=False, **kwargs): 269 | norm_layer, downsample, previous_dilation = self._norm_layer, None, self.dilation 270 | mask_s = torch.ones(blocks) 271 | if dilate: 272 | self.dilation *= stride 273 | stride = 1 274 | if stride != 1 or self.inplanes != planes * block.expansion: 275 | downsample = nn.Sequential( 276 | conv1x1(self.inplanes, planes * block.expansion, stride), 277 | norm_layer(planes * block.expansion), 278 | ) 279 | layers = [] 280 | layers.append(block(self.inplanes, planes, h, w, tile, stride, downsample, 281 | self.groups, self.base_width, previous_dilation, norm_layer, **kwargs)) 282 | h = conv2d_out_dim(h, kernel_size=1, stride=stride, padding=0) 283 | w = conv2d_out_dim(w, kernel_size=1, stride=stride, padding=0) 284 | self.inplanes = planes * block.expansion 285 | for i in range(1, blocks): 286 | layers.append(block(self.inplanes, planes, h, w, tile, groups=self.groups, 287 | base_width=self.base_width, dilation=self.dilation, 288 | norm_layer=norm_layer,**kwargs)) 289 | return nn.Sequential(*layers), h, w 290 | 291 | def forward(self, x, label, den_target, lbda, gamma, p): 292 | # See note [TorchScript super()] 293 | batch_num, _, _, _ = x.shape 294 | # conv1 295 | x = self.conv1(x) 296 | x = self.bn1(x) 297 | x = self.relu(x) 298 | x = self.maxpool(x) 299 | # residual modules 300 | norm1 = torch.zeros(1, batch_num+1).to(x.device) 301 | norm2 = torch.zeros(1, batch_num+1).to(x.device) 302 | flops = torch.zeros(1, batch_num+2).to(x.device) 303 | x = self.layer1((x, norm1, norm2, flops)) 304 | x = self.layer2(x) 305 | x = self.layer3(x) 306 | x, norm1, norm2, flops = self.layer4(x) 307 | # fc layer 308 | x = self.avgpool(x) 309 | x = torch.flatten(x, 1) 310 | x = self.fc(x) 311 | # norm and flops 312 | norm_s = norm1[1:, 0:batch_num].permute(1, 0).contiguous() 313 | norm_c = norm2[1:, 0:batch_num].permute(1, 0).contiguous() 314 | norm_s_t = norm1[1:, -1].unsqueeze(0) 315 | norm_c_t = norm2[1:, -1].unsqueeze(0) 316 | flops_real = [flops[1:, 0:batch_num].permute(1, 0).contiguous(), 317 | self.flops_conv1.to(x.device), self.flops_fc.to(x.device)] 318 | flops_mask = flops[1:, -2].unsqueeze(0) 319 | flops_ori = flops[1:, -1].unsqueeze(0) 320 | # get outputs 321 | outputs = {} 322 | outputs["closs"], outputs["rloss"], outputs["bloss"] = self.get_loss( 323 | x, label, batch_num, den_target, lbda, gamma, p, 324 | norm_s, norm_c, norm_s_t, norm_c_t, 325 | flops_real, flops_mask, flops_ori) 326 | outputs["out"] = x 327 | outputs["flops_real"] = flops_real 328 | outputs["flops_mask"] = flops_mask 329 | outputs["flops_ori"] = flops_ori 330 | return outputs 331 | 332 | def set_criterion(self, criterion): 333 | self.criterion = criterion 334 | return 335 | 336 | def get_loss(self, output, label, batch_size, den_target, lbda, gamma, p, 337 | mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, 338 | flops_real, flops_mask, flops_ori): 339 | closs, rloss, bloss = self.criterion(output, label, flops_real, flops_mask, 340 | flops_ori, batch_size, den_target, lbda, mask_norm_s, mask_norm_c, 341 | norm_s_t, norm_c_t, gamma, p) 342 | return closs, rloss, bloss 343 | 344 | def record_flops(self, flops_conv, flops_mask, 
flops_ori, flops_conv1, flops_fc): 345 | i = 0 346 | table = PrettyTable(['Layer', 'Conv FLOPs', 'Conv %', 'Mask FLOPs', 'Total FLOPs', 'Total %', 'Original FLOPs']) 347 | table.add_row(['layer0'] + ['{flops:.2f}K'.format(flops=flops_conv1/1024)] + [' ' for _ in range(5)]) 348 | for name, m in self.named_modules(): 349 | if isinstance(m, (BasicBlock, Bottleneck)): 350 | table.add_row([name] + ['{flops:.2f}K'.format(flops=flops_conv[i]/1024)] + ['{per_f:.2f}%'.format( 351 | per_f=flops_conv[i]/flops_ori[i]*100)] + ['{mask:.2f}K'.format(mask=flops_mask[i]/1024)] + 352 | ['{total:.2f}K'.format(total=(flops_conv[i]+flops_mask[i])/1024)] + ['{per_t:.2f}%'.format( 353 | per_t=(flops_conv[i]+flops_mask[i])/flops_ori[i]*100)] + 354 | ['{ori:.2f}K'.format(ori=flops_ori[i]/1024)]) 355 | i+=1 356 | table.add_row(['fc'] + ['{flops:.2f}K'.format(flops=flops_fc/1024)] + [' ' for _ in range(5)]) 357 | table.add_row(['Total'] + ['{flops:.2f}K'.format(flops=(flops_conv[i]+flops_conv1+flops_fc)/1024)] + 358 | ['{per_f:.2f}%'.format(per_f=(flops_conv[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100)] + 359 | ['{mask:.2f}K'.format(mask=flops_mask[i]/1024)] + ['{total:.2f}K'.format( 360 | total=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/1024)] + ['{per_t:.2f}%'.format( 361 | per_t=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100)] + 362 | ['{ori:.2f}K'.format(ori=(flops_ori[i]+flops_conv1+flops_fc)/1024)]) 363 | logging.info('\n{}'.format(table)) 364 | 365 | 366 | def _resdg(arch, block, layers, **kwargs): 367 | model = ResDG(block, layers, **kwargs) 368 | return model 369 | 370 | 371 | def resdg18(**kwargs): 372 | r"""ResNet-18 model from 373 | `"Deep Residual Learning for Image Recognition" `_ 374 | """ 375 | return _resdg('resdg18', BasicBlock, [2, 2, 2, 2], **kwargs) 376 | 377 | 378 | def resdg34(**kwargs): 379 | r"""ResNet-34 model from 380 | `"Deep Residual Learning for Image Recognition" `_ 381 | """ 382 | return _resdg('resdg34', BasicBlock, [3, 4, 6, 3], **kwargs) 383 | 384 | 385 | def resdg50(**kwargs): 386 | r"""ResNet-50 model from 387 | `"Deep Residual Learning for Image Recognition" `_ 388 | """ 389 | return _resdg('resdg50', Bottleneck, [3, 4, 6, 3], **kwargs) 390 | -------------------------------------------------------------------------------- /models/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | 7 | 8 | class GumbelSoftmax(nn.Module): 9 | ''' 10 | gumbel softmax gate. 
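    During training it adds Gumbel noise to the logits and squashes them with a
    sigmoid of temperature `eps`, then binarizes via a straight-through estimator
    (hard 0/1 in the forward pass, soft gradients in the backward pass). At
    inference it simply thresholds the logits at zero.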
11 | ''' 12 | def __init__(self, eps=1): 13 | super(GumbelSoftmax, self).__init__() 14 | self.eps = eps 15 | self.sigmoid = nn.Sigmoid() 16 | 17 | def gumbel_sample(self, template_tensor, eps=1e-8): 18 | uniform_samples_tensor = template_tensor.clone().uniform_() 19 | gumbel_samples_tensor = torch.log(uniform_samples_tensor+eps)-torch.log( 20 | 1-uniform_samples_tensor+eps) 21 | return gumbel_samples_tensor 22 | 23 | def gumbel_softmax(self, logits): 24 | """ Draw a sample from the Gumbel-Softmax distribution""" 25 | gsamples = self.gumbel_sample(logits.data) 26 | logits = logits + Variable(gsamples) 27 | soft_samples = self.sigmoid(logits / self.eps) 28 | return soft_samples, logits 29 | 30 | def forward(self, logits): 31 | if not self.training: 32 | out_hard = (logits>=0).float() 33 | return out_hard 34 | out_soft, prob_soft = self.gumbel_softmax(logits) 35 | out_hard = ((out_soft >= 0.5).float() - out_soft).detach() + out_soft 36 | return out_hard 37 | 38 | 39 | class Mask_s(nn.Module): 40 | ''' 41 | Spatial attention mask. 42 | ''' 43 | def __init__(self, h, w, planes, block_w, block_h, eps=0.66667, 44 | bias=-1, **kwargs): 45 | super(Mask_s, self).__init__() 46 | # Parameter 47 | self.width, self.height, self.channel = w, h, planes 48 | self.mask_h, self.mask_w = int(np.ceil(h / block_h)), int(np.ceil(w / block_w)) 49 | self.eleNum_s = torch.Tensor([self.mask_h*self.mask_w]) 50 | # spatial attention 51 | self.atten_s = nn.Conv2d(planes, 1, kernel_size=3, stride=1, bias=bias>=0, padding=1) 52 | if bias>=0: 53 | nn.init.constant_(self.atten_s.bias, bias) 54 | # Gate 55 | self.gate_s = GumbelSoftmax(eps=eps) 56 | # Norm 57 | self.norm = lambda x: torch.norm(x, p=1, dim=(1,2,3)) 58 | 59 | def forward(self, x): 60 | batch, channel, height, width = x.size() 61 | # Pooling 62 | input_ds = F.adaptive_avg_pool2d(input=x, output_size=(self.mask_h, self.mask_w)) 63 | # spatial attention 64 | s_in = self.atten_s(input_ds) # [N, 1, h, w] 65 | # spatial gate 66 | mask_s = self.gate_s(s_in) # [N, 1, h, w] 67 | # norm 68 | norm = self.norm(mask_s) 69 | norm_t = self.eleNum_s.to(x.device) 70 | return mask_s, norm, norm_t 71 | 72 | def get_flops(self): 73 | flops = self.mask_h * self.mask_w * self.channel * 9 74 | return flops 75 | 76 | 77 | class Mask_c(nn.Module): 78 | ''' 79 | Channel attention mask.
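    Squeezes the input with global average pooling, transforms it with a
    two-layer 1x1-conv bottleneck, and gates the result with a Gumbel
    softmax, yielding a binary [N, C_out, 1, 1] channel mask together
    with its L1 norm and the corresponding upper bound (the channel count).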
80 | ''' 81 | def __init__(self, inplanes, outplanes, fc_reduction=4, eps=0.66667, bias=-1, **kwargs): 82 | super(Mask_c, self).__init__() 83 | # Parameter 84 | self.bottleneck = inplanes // fc_reduction 85 | self.inplanes, self.outplanes = inplanes, outplanes 86 | self.eleNum_c = torch.Tensor([outplanes]) 87 | # channel attention 88 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 89 | self.atten_c = nn.Sequential( 90 | nn.Conv2d(inplanes, self.bottleneck, kernel_size=1, stride=1, bias=False), 91 | nn.BatchNorm2d(self.bottleneck), 92 | nn.ReLU(inplace=True), 93 | nn.Conv2d(self.bottleneck, outplanes, kernel_size=1, stride=1, bias=bias>=0), 94 | ) 95 | if bias>=0: 96 | nn.init.constant_(self.atten_c[3].bias, bias) 97 | # Gate 98 | self.gate_c = GumbelSoftmax(eps=eps) 99 | # Norm 100 | self.norm = lambda x: torch.norm(x, p=1, dim=(1,2,3)) 101 | 102 | def forward(self, x): 103 | batch, channel, _, _ = x.size() 104 | context = self.avg_pool(x) # [N, C, 1, 1] 105 | # transform 106 | c_in = self.atten_c(context) # [N, C_out, 1, 1] 107 | # channel gate 108 | mask_c = self.gate_c(c_in) # [N, C_out, 1, 1] 109 | # norm 110 | norm = self.norm(mask_c) 111 | norm_t = self.eleNum_c.to(x.device) 112 | return mask_c, norm, norm_t 113 | 114 | def get_flops(self): 115 | flops = self.inplanes * self.bottleneck + self.bottleneck * self.outplanes 116 | return flops 117 | -------------------------------------------------------------------------------- /models/mobilenet_v2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .mobilenet_v2_dg import * 4 | -------------------------------------------------------------------------------- /models/mobilenet_v2/mobilenet_v2_dg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import logging 4 | from torch import nn 5 | from prettytable import PrettyTable 6 | from .mobilenet_v2_dg_util import ConvBNReLU_1st, InvertedResidual 7 | 8 | 9 | __all__ = ['mobilenet_v2_dg'] 10 | 11 | 12 | def conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False): 13 | if ceil_mode: 14 | return int(math.ceil((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 15 | else: 16 | return int(math.floor((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 17 | 18 | 19 | def _make_divisible(v, divisor, min_value=None): 20 | """ 21 | This function is taken from the original tf repo. 22 | It ensures that all layers have a channel number that is divisible by 8 23 | It can be seen here: 24 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 25 | :param v: 26 | :param divisor: 27 | :param min_value: 28 | :return: 29 | """ 30 | if min_value is None: 31 | min_value = divisor 32 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 33 | # Make sure that round down does not go down by more than 10%. 
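    # Illustrative (added) example: v = 10 with divisor = 8 rounds to 8,
    # which is below 0.9 * 10 = 9, so the check below bumps it up to 16;
    # v = 12 rounds to 16 directly and needs no correction.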
34 | if new_v < 0.9 * v: 35 | new_v += divisor 36 | return new_v 37 | 38 | 39 | class MobileNetV2(nn.Module): 40 | def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, 41 | round_nearest=8, in_size=(224, 224), block = InvertedResidual, **kwargs): 42 | """ 43 | MobileNet V2 main class 44 | 45 | Args: 46 | num_classes (int): Number of classes 47 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 48 | inverted_residual_setting: Network structure 49 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 50 | Set to 1 to turn off rounding 51 | """ 52 | super(MobileNetV2, self).__init__() 53 | input_channel = 32 54 | last_channel = 1280 55 | h, w = in_size 56 | 57 | if inverted_residual_setting is None: 58 | inverted_residual_setting = [ 59 | # t, c, n, s, tile 60 | [1, 16, 1, 1, 16], 61 | [6, 24, 2, 2, 8], 62 | [6, 32, 3, 2, 4], 63 | [6, 64, 4, 2, 2], 64 | [6, 96, 3, 1, 2], 65 | [6, 160, 3, 2, 2], 66 | [6, 320, 1, 1, 2], 67 | ] 68 | # only check the first element, assuming user knows t,c,n,s are required 69 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 5: 70 | raise ValueError("inverted_residual_setting should be non-empty " 71 | "or a 5-element list, got {}".format(inverted_residual_setting)) 72 | # building first layer 73 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 74 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 75 | features = [ConvBNReLU_1st(3, input_channel, stride=2)] 76 | h = conv2d_out_dim(h, kernel_size=3, stride=2, padding=1) 77 | w = conv2d_out_dim(w, kernel_size=3, stride=2, padding=1) 78 | self.flops_conv1 = torch.Tensor([3 * h * w * 3 * input_channel]) 79 | # building inverted residual blocks 80 | for t, c, n, s, tile in inverted_residual_setting: 81 | output_channel = _make_divisible(c * width_mult, round_nearest) 82 | for i in range(n): 83 | stride = s if i == 0 else 1 84 | features.append(block(input_channel, output_channel, stride, 85 | expand_ratio=t, h=h, w=w, eta=tile, **kwargs)) 86 | h = conv2d_out_dim(h, kernel_size=3, stride=stride, padding=1) 87 | w = conv2d_out_dim(w, kernel_size=3, stride=stride, padding=1) 88 | input_channel = output_channel 89 | # building last several layers 90 | features.append(ConvBNReLU_1st(input_channel, self.last_channel, kernel_size=1)) 91 | self.flops_fc = torch.Tensor([input_channel * self.last_channel *h*w]) 92 | # make it nn.Sequential 93 | self.features = nn.Sequential(*features) 94 | # building classifier 95 | self.classifier = nn.Sequential( 96 | nn.Dropout(0.2), 97 | nn.Linear(self.last_channel, num_classes), 98 | ) 99 | self.flops_fc = self.flops_fc + torch.Tensor([num_classes * self.last_channel]) 100 | # criterion 101 | self.criterion = None 102 | 103 | # weight initialization 104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 107 | if m.bias is not None: 108 | nn.init.zeros_(m.bias) 109 | elif isinstance(m, nn.BatchNorm2d): 110 | nn.init.ones_(m.weight) 111 | nn.init.zeros_(m.bias) 112 | elif isinstance(m, nn.Linear): 113 | nn.init.normal_(m.weight, 0, 0.01) 114 | nn.init.zeros_(m.bias) 115 | 116 | def forward(self, x, label, den_target, lbda, gamma, p): 117 | batch_num, _, _, _ = x.shape 118 | norm1 = torch.zeros(1, batch_num+1).to(x.device) 119 | norm2 = torch.zeros(1, batch_num+1).to(x.device) 120 | flops = torch.zeros(1, 
batch_num+2).to(x.device) 121 | x, norm1, norm2, flops = self.features((x, norm1, norm2, flops)) 122 | x = x.mean([2, 3]) 123 | x = self.classifier(x) 124 | # norm and flops 125 | norm_s = norm1[1:, 0:batch_num].permute(1, 0).contiguous() 126 | norm_c = norm2[1:, 0:batch_num].permute(1, 0).contiguous() 127 | norm_s_t = norm1[1:, -1].unsqueeze(0) 128 | norm_c_t = norm2[1:, -1].unsqueeze(0) 129 | flops_real = [flops[1:, 0:batch_num].permute(1, 0).contiguous(), 130 | self.flops_conv1.to(x.device), self.flops_fc.to(x.device)] 131 | flops_mask = flops[1:, -2].unsqueeze(0) 132 | flops_ori = flops[1:, -1].unsqueeze(0) 133 | # get outputs 134 | outputs = {} 135 | outputs["closs"], outputs["rloss"], outputs["bloss"] = self.get_loss( 136 | x, label, batch_num, den_target, lbda, gamma, p, 137 | norm_s, norm_c, norm_s_t, norm_c_t, 138 | flops_real, flops_mask, flops_ori) 139 | outputs["out"] = x 140 | outputs["flops_real"] = flops_real 141 | outputs["flops_mask"] = flops_mask 142 | outputs["flops_ori"] = flops_ori 143 | return outputs 144 | 145 | def set_criterion(self, criterion): 146 | self.criterion = criterion 147 | return 148 | 149 | def get_loss(self, output, label, batch_size, den_target, lbda, gamma, p, 150 | mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, 151 | flops_real, flops_mask, flops_ori): 152 | closs, rloss, bloss = self.criterion(output, label, flops_real, flops_mask, 153 | flops_ori, batch_size, den_target, lbda, mask_norm_s, mask_norm_c, 154 | norm_s_t, norm_c_t, gamma, p) 155 | return closs, rloss, bloss 156 | 157 | def record_flops(self, flops_conv, flops_mask, flops_ori, flops_conv1, flops_fc): 158 | i = 0 159 | table = PrettyTable(['Layer', 'Conv FLOPs', 'Conv %', 'Mask FLOPs', 'Total FLOPs', 'Total %', 'Original FLOPs']) 160 | table.add_row(['layer0'] + ['{flops:.2f}K'.format(flops=flops_conv1/1024)] + [' ' for _ in range(5)]) 161 | for name, m in self.named_modules(): 162 | if isinstance(m, InvertedResidual): 163 | table.add_row([name] + ['{flops:.2f}K'.format(flops=flops_conv[i]/1024)] + ['{per_f:.2f}%'.format( 164 | per_f=flops_conv[i]/flops_ori[i]*100)] + ['{mask:.2f}K'.format(mask=flops_mask[i]/1024)] + 165 | ['{total:.2f}K'.format(total=(flops_conv[i]+flops_mask[i])/1024)] + ['{per_t:.2f}%'.format( 166 | per_t=(flops_conv[i]+flops_mask[i])/flops_ori[i]*100)] + 167 | ['{ori:.2f}K'.format(ori=flops_ori[i]/1024)]) 168 | i+=1 169 | table.add_row(['fc'] + ['{flops:.2f}K'.format(flops=flops_fc/1024)] + [' ' for _ in range(5)]) 170 | table.add_row(['Total'] + ['{flops:.2f}K'.format(flops=(flops_conv[i]+flops_conv1+flops_fc)/1024)] + 171 | ['{per_f:.2f}%'.format(per_f=(flops_conv[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100)] + 172 | ['{mask:.2f}K'.format(mask=flops_mask[i]/1024)] + ['{total:.2f}K'.format( 173 | total=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/1024)] + ['{per_t:.2f}%'.format( 174 | per_t=(flops_conv[i]+flops_mask[i]+flops_conv1+flops_fc)/(flops_ori[i]+flops_conv1+flops_fc)*100)] + 175 | ['{ori:.2f}K'.format(ori=(flops_ori[i]+flops_conv1+flops_fc)/1024)]) 176 | logging.info('\n{}'.format(table)) 177 | 178 | 179 | def mobilenet_v2_dg(**kwargs): 180 | return MobileNetV2(block=InvertedResidual, **kwargs) 181 | -------------------------------------------------------------------------------- /models/mobilenet_v2/mobilenet_v2_dg_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch import nn 4 | from ..mask import Mask_s, Mask_c 5 | 6 | 7 | def 
conv2d_out_dim(dim, kernel_size, padding=0, stride=1, dilation=1, ceil_mode=False): 8 | if ceil_mode: 9 | return int(math.ceil((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 10 | else: 11 | return int(math.floor((dim + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)) 12 | 13 | class ConvBNReLU(nn.Sequential): 14 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 15 | padding = (kernel_size - 1) // 2 16 | super(ConvBNReLU, self).__init__( 17 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 18 | nn.BatchNorm2d(out_planes), 19 | nn.ReLU6(inplace=True) 20 | ) 21 | 22 | 23 | class ConvBNReLU_1st(nn.Sequential): 24 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 25 | padding = (kernel_size - 1) // 2 26 | super(ConvBNReLU_1st, self).__init__( 27 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 28 | nn.BatchNorm2d(out_planes), 29 | nn.ReLU6(inplace=True) 30 | ) 31 | 32 | def forward(self, input): 33 | x, norm_1, norm_2, flops = input 34 | x = super(ConvBNReLU_1st, self).forward(x) 35 | return x, norm_1, norm_2, flops 36 | 37 | 38 | class Sequential_DG(nn.Sequential): 39 | def __init__(self, layers): 40 | super(Sequential_DG, self).__init__(*layers) 41 | self._module_num = len(layers) 42 | 43 | def forward(self, input): 44 | x, mask_c, mask_s1, mask_s2 = input 45 | i = 0 46 | for module in self._modules.values(): 47 | if self.training: 48 | if i == self._module_num-2: 49 | x = x * mask_c 50 | x = module(x) 51 | else: 52 | if i == 0: 53 | x = module(x) * mask_s1 54 | elif i == self._module_num-2: 55 | x = x * mask_c * mask_s2 56 | x = module(x) 57 | else: 58 | x = module(x) 59 | i += 1 60 | return x 61 | 62 | 63 | class InvertedResidual(nn.Module): 64 | def __init__(self, inp, oup, stride, expand_ratio, h, w, eta, **kwargs): 65 | super(InvertedResidual, self).__init__() 66 | self.stride = stride 67 | assert stride in [1, 2] 68 | 69 | self.height = conv2d_out_dim(h, kernel_size=3, stride=stride, padding=1) 70 | self.width = conv2d_out_dim(w, kernel_size=3, stride=stride, padding=1) 71 | self.spatial = self.height * self.width 72 | self.expand = expand_ratio == 1 73 | hidden_dim = int(round(inp * expand_ratio)) 74 | self.use_res_connect = self.stride == 1 and inp == oup 75 | 76 | layers = [] 77 | if expand_ratio != 1: 78 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 79 | layers.extend([ 80 | # dw 81 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 82 | # pw-linear 83 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 84 | nn.BatchNorm2d(oup), 85 | ]) 86 | if self.use_res_connect: 87 | self.conv = Sequential_DG(layers) 88 | # channel mask 89 | self.mask_c = Mask_c(inp, hidden_dim, **kwargs) 90 | flops_mkc = self.mask_c.get_flops() 91 | # spatial mask 92 | self.mask_s = Mask_s(self.height, self.width, inp, eta, eta, **kwargs) 93 | self.upsample = nn.Upsample(size=(h, w), mode='nearest') 94 | flops_mks = self.mask_s.get_flops() 95 | else: 96 | self.conv = nn.Sequential(*layers) 97 | flops_mkc, flops_mks = 0, 0 98 | self.norm_c_t = torch.Tensor([hidden_dim]) 99 | self.norm_s_t = torch.Tensor([self.spatial]) 100 | # misc 101 | self.inp, self.oup = inp, oup 102 | self.hidden_dim = hidden_dim 103 | # flops 104 | flops_dw_full = torch.Tensor([9 * self.spatial * hidden_dim]) 105 | flops_pw_full = torch.Tensor([self.spatial * hidden_dim * oup]) 106 | self.flops_full = flops_dw_full + 
flops_pw_full 107 | if expand_ratio != 1: 108 | self.flops_full = self.flops_full + torch.Tensor([h * w * hidden_dim * inp]) 109 | self.upsample1 = nn.Upsample(size=(h, w), mode='nearest') 110 | # mask flops 111 | self.flops_mask = torch.Tensor([flops_mks + flops_mkc]) 112 | 113 | def forward(self, input): 114 | if not self.use_res_connect: 115 | x, norm_1, norm_2, flops = input 116 | x = self.conv(x) 117 | norm_s = torch.ones((x.shape[0], self.spatial), device=x.device).sum(1) 118 | norm_c = torch.ones((x.shape[0], self.hidden_dim), device=x.device).sum(1) 119 | norm_1 = torch.cat((norm_1, torch.cat((norm_s, self.norm_s_t.to(x.device))).unsqueeze(0))) 120 | norm_2 = torch.cat((norm_2, torch.cat((norm_c, self.norm_c_t.to(x.device))).unsqueeze(0))) 121 | flops_blk = torch.cat((torch.ones(x.shape[0])*self.flops_full, self.flops_mask, self.flops_full)).to(flops.device) 122 | flops = torch.cat((flops, flops_blk.unsqueeze(0))) 123 | return (x, norm_1, norm_2, flops) 124 | else: 125 | x_in, norm_1, norm_2, flops = input 126 | # channel mask 127 | mask_c, norm_c, norm_c_t = self.mask_c(x_in) # [N, C_out, 1, 1] 128 | # spatial mask 129 | mask_s_m, norm_s, norm_s_t = self.mask_s(x_in) # [N, 1, h, w] 130 | mask_s1 = self.upsample1(mask_s_m) # [N, 1, H1, W1] 131 | mask_s = self.upsample(mask_s_m) # [N, 1, H, W] 132 | x = self.conv((x_in, mask_c, mask_s1, mask_s)) 133 | x = x * mask_s 134 | # norm 135 | norm_1 = torch.cat((norm_1, torch.cat((norm_s, norm_s_t)).unsqueeze(0))) 136 | norm_2 = torch.cat((norm_2, torch.cat((norm_c, norm_c_t)).unsqueeze(0))) 137 | # flops 138 | flops_blk = self.get_flops(mask_c, mask_s) 139 | flops = torch.cat((flops, flops_blk.unsqueeze(0))) 140 | return (x+x_in, norm_1, norm_2, flops) 141 | 142 | def get_flops(self, mask_c, mask_s_up): 143 | s_sum = mask_s_up.sum((1,2,3)) 144 | c_sum = mask_c.sum((1,2,3)) 145 | # convdw 146 | flops_dw = 9 * s_sum * c_sum 147 | # convpw 148 | flops_pw = s_sum * c_sum * self.oup 149 | # conv1x1 150 | flops = flops_dw + flops_pw 151 | if not self.expand: 152 | mask_s_1 = self.upsample1(mask_s_up) 153 | flops = flops + mask_s_1.sum((1,2,3)) * c_sum * self.inp 154 | # total 155 | return torch.cat((flops, self.flops_mask.to(flops.device), self.flops_full.to(flops.device))) 156 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import models 3 | 4 | 5 | # Parse arguments 6 | parser = argparse.ArgumentParser(description='PyTorch Training') 7 | # Datasets 8 | parser.add_argument('-d', '--data', default='path to dataset', type=str) 9 | parser.add_argument('-dset', '--dataset', default='dataset', type=str) 10 | parser.add_argument('-j', '--workers', default=8, type=int, metavar='N', help='number of data loading workers (default: 4)') 11 | # Architecture 12 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=models.ALL_MODEL_NAMES, 13 | help='model architecture: ' + ' | '.join(models.ALL_MODEL_NAMES) + ' (default: resnet50)') 14 | # Optimization options 15 | parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') 16 | parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', 17 | help='initial learning rate (default: 0.001 | for inception recommend 0.0256)') 18 | parser.add_argument('--lr-decay', default=0.1, type=float, metavar='LD', 19 | help='every lr-decay-step epochs learning rate 
decays by LD (default: 0.1 | for inception recommend 0.16)') 20 | parser.add_argument('--lr-mode', default='step', type=str, help='learning rate mode') 21 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum (default: 0.9)') 22 | parser.add_argument('--weight-decay', '-wd', default=1e-4, type=float, metavar='WD', help='weight decay for sgd (default: 1e-4)') 23 | parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], help='Decrease learning rate at these epochs.') 24 | parser.add_argument('--den-target', default=0.5, type=float, help='target density of the mask.') 25 | parser.add_argument('--lbda', default=5, type=float, help='penalty factor of the L2 loss for mask.') 26 | parser.add_argument('--gamma', default=1, type=float, help='penalty factor of the L2 loss for balance gate.') 27 | parser.add_argument('--alpha', default=5e-2, type=float, help='alpha in exp annealing.') 28 | # Training 29 | parser.add_argument('--epochs', default=300, type=int, metavar='EPOCHS', help='number of total epochs to run.') 30 | # Device options 31 | parser.add_argument('--gpu-id', default='-1', type=str, help='id(s) for CUDA_VISIBLE_DEVICES') 32 | # Miscs 33 | parser.add_argument('--manualSeed', type=int, help='manual seed') 34 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', help='evaluate model on validation set') 35 | parser.add_argument('--pretrained', default='', type=str, metavar='PATH', 36 | help='use pre-trained model: ''pytorch: use pytorch official | path to self-trained model') 37 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 38 | help='path to store the checkpoint and log checkpoint path = ./checkpoints/PATH, log path = ./logs/PATH') 39 | parser.add_argument('--bias', default=2, type=float, help='initial value of the bias in the last fc layer of mask module.') 40 | -------------------------------------------------------------------------------- /regularization.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class spar_loss(nn.Module): 8 | def __init__(self): 9 | super(spar_loss, self).__init__() 10 | 11 | def forward(self, flops_real, flops_mask, flops_ori, batch_size, den_target, lbda): 12 | # total sparsity 13 | flops_tensor, flops_conv1, flops_fc = flops_real[0], flops_real[1], flops_real[2] 14 | # block flops 15 | flops_conv = flops_tensor[0:batch_size,:].mean(0).sum() 16 | flops_mask = flops_mask.mean(0).sum() 17 | flops_ori = flops_ori.mean(0).sum() + flops_conv1.mean() + flops_fc.mean() 18 | flops_real = flops_conv + flops_mask + flops_conv1.mean() + flops_fc.mean() 19 | # loss 20 | rloss = lbda * (flops_real / flops_ori - den_target)**2 21 | return rloss 22 | 23 | 24 | class blance_loss(nn.Module): 25 | def __init__(self): 26 | super(blance_loss, self).__init__() 27 | 28 | def forward(self, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size, 29 | den_target, gamma, p): 30 | norm_s = mask_norm_s 31 | norm_s_t = norm_s_t.mean(0) 32 | norm_c = mask_norm_c 33 | norm_c_t = norm_c_t.mean(0) 34 | den_s = norm_s[0:batch_size,:].mean(0) / norm_s_t 35 | den_c = norm_c[0:batch_size,:].mean(0) / norm_c_t 36 | den_tar = math.sqrt(den_target) 37 | bloss_s = get_bloss_basic(den_s, den_tar, batch_size, gamma, p) 38 | bloss_c = get_bloss_basic(den_c, den_tar, batch_size, gamma, p) 39 | bloss = bloss_s + bloss_c 40 | return bloss 41 | 42 | 43 |
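# Intuition for the bound loss below (illustrative numbers): with
# den_target = 0.5, each mask is pushed toward a per-layer density of
# sqrt(0.5) ~= 0.707, so that the product of the spatial and channel
# densities matches the overall target. get_bloss_basic then leaves a
# layer unpenalized while its density stays within
# [p * 0.707, 1 - p + p * 0.707] and penalizes it quadratically outside.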
def get_bloss_basic(spar, spar_tar, batch_size, gamma, p): 44 | # bound 45 | bloss_l = (F.relu(p*spar_tar-spar)**2).mean() 46 | bloss_u = (F.relu(spar-1+p-p*spar_tar)**2).mean() 47 | bloss = gamma * (bloss_l + bloss_u) 48 | return bloss 49 | 50 | 51 | class Loss(nn.Module): 52 | def __init__(self): 53 | super(Loss, self).__init__() 54 | self.task_loss = nn.CrossEntropyLoss() 55 | self.spar_loss = spar_loss() 56 | self.balance_loss = blance_loss() 57 | 58 | def forward(self, output, targets, flops_real, flops_mask, flops_ori, batch_size, 59 | den_target, lbda, mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, 60 | gamma, p): 61 | closs = self.task_loss(output, targets) 62 | sloss = self.spar_loss(flops_real, flops_mask, flops_ori, batch_size, den_target, lbda) 63 | bloss = self.balance_loss(mask_norm_s, mask_norm_c, norm_s_t, norm_c_t, batch_size, 64 | den_target, gamma, p) 65 | return closs, sloss, bloss 66 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | prettytable 2 | matplotlib 3 | pretrainedmodels 4 | numpy 5 | -------------------------------------------------------------------------------- /scripts/cifar_e.sh: -------------------------------------------------------------------------------- 1 | echo "Dynamic dual gating model $1 for cifar-10." 2 | 3 | time python main.py -d $2 -dset cifar10 -j 2 -a $1 -b 128 \ 4 | --checkpoint $4 --gpu-id $3 --pretrained $5 -e 5 | -------------------------------------------------------------------------------- /scripts/cifar_t.sh: -------------------------------------------------------------------------------- 1 | echo "Dynamic dual gating model $1 for cifar-10." 2 | 3 | dgnet_cifar10(){ 4 | time python main.py -d $2 -dset cifar10 -j 2 -a $1 -b 128 -lr 0.1 \ 5 | --weight-decay 5e-4 --schedule 150 225 --checkpoint $5 \ 6 | --gpu-id $4 --den-target $3 --alpha 2e-2 --pretrained $6 7 | } 8 | 9 | checkpoint1="$5_varience1" 10 | checkpoint2="$5_varience2" 11 | checkpoint3="$5_varience3" 12 | 13 | dgnet_cifar10 $1 $2 $3 $4 $checkpoint1 $6 14 | dgnet_cifar10 $1 $2 $3 $4 $checkpoint2 $6 15 | dgnet_cifar10 $1 $2 $3 $4 $checkpoint3 $6 16 | -------------------------------------------------------------------------------- /scripts/imagenet_e.sh: -------------------------------------------------------------------------------- 1 | echo "Dynamic dual gating model $1 for ImageNet." 2 | 3 | time python main.py -d $2 -dset imagenet -a $1 \ 4 | --checkpoint $4 --gpu-id $3 --pretrained $5 -e -------------------------------------------------------------------------------- /scripts/imagenet_t.sh: -------------------------------------------------------------------------------- 1 | echo "Dynamic dual gating model $1 for ImageNet." 2 | 3 | time python main.py -d $2 -dset imagenet -a $1 -lr 0.05 \ 4 | --weight-decay 1e-4 --epochs 100 --checkpoint $5 --gpu-id $4 \ 5 | --den-target $3 --pretrained pytorch --lr-mode cosine 6 | -------------------------------------------------------------------------------- /scripts/mobilenet_v2_e.sh: -------------------------------------------------------------------------------- 1 | echo "Dynamic dual gating model $1 for ImageNet." 
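# Assumed positional arguments, mirroring the call below:
#   $1 = ARCH, $2 = PATH-TO-DATASET, $3 = GPU-IDs,
#   $4 = PATH-TO-SAVE, $5 = PATH-TO-TRAINED-MODEL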
2 | 3 | time python main.py -d $2 -dset imagenet -a $1 \ 4 | --checkpoint $4 --gpu-id $3 --pretrained $5 -e 5 | -------------------------------------------------------------------------------- /scripts/mobilenet_v2_t.sh: -------------------------------------------------------------------------------- 1 | echo "Dynamic dual gating model $1 for ImageNet." 2 | 3 | time python main.py -d $2 -dset imagenet -a $1 -lr 0.05 \ 4 | --weight-decay 4e-5 --epochs 200 --checkpoint $5 \ 5 | --gpu-id $4 --den-target $3 --pretrained pytorch \ 6 | --lr-mode cosine 7 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import get_loggers 2 | from .misc import AverageMeter, accuracy 3 | from .misc import analyse_flops, ExpAnnealing 4 | 5 | # progress bar 6 | import os, sys 7 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 8 | from progress.bar import Bar as Bar 9 | 10 | __all__ = ['AverageMeter', 'Bar', 'accuracy', 'get_loggers', 11 | 'analyse_flops', 'ExpAnnealing'] 12 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import matplotlib.pyplot as plt 3 | import os 4 | import datetime 5 | import torch 6 | import logging 7 | import errno 8 | import numpy as np 9 | from logging import handlers 10 | 11 | __all__ = ['Logger', 'LoggerMonitor', 'savefig', 'get_loggers'] 12 | 13 | 14 | def savefig(fname, dpi=None): 15 | dpi = 150 if dpi is None else dpi 16 | plt.savefig(fname, dpi=dpi) 17 | 18 | 19 | def plot_overlap(logger, names=None): 20 | names = logger.names if names is None else names 21 | numbers = logger.numbers 22 | for _, name in enumerate(names): 23 | x = np.arange(len(numbers[name])) 24 | plt.plot(x, np.asarray(numbers[name])) 25 | return [logger.title + '(' + name + ')' for name in names] 26 | 27 | 28 | def get_loggers(args): 29 | """ 30 | Generate loggers 31 | 32 | Args: 33 | - args : config information 34 | """ 35 | # log file and checkpoint file 36 | checkpoint = args.checkpoint 37 | arch = args.arch 38 | if(checkpoint == ''): 39 | dir_name = arch + '_' + datetime.datetime.now().strftime('%m%d_%H%M') 40 | else: 41 | dir_name = checkpoint 42 | log_dir = os.path.join('logs', dir_name) 43 | checkpoint_dir = log_dir 44 | print('\n--------------------------------------------------------') 45 | if not os.path.isdir(checkpoint_dir): 46 | mkdir_p(log_dir) 47 | mkdir_p(checkpoint_dir) 48 | print("=> make directory '{}'".format(log_dir)) 49 | else: 50 | print("=> directory '{}' exists".format(log_dir)) 51 | 52 | train_log = Logger(os.path.join(log_dir, 'train.log')) 53 | test_log = Logger(os.path.join(log_dir, 'test.log')) 54 | config_log = Logger(os.path.join(log_dir, 'config.log')) 55 | if not os.path.isdir(os.path.join(log_dir, 'tb')): 56 | os.makedirs(os.path.join(log_dir, 'tb')) 57 | 58 | # msg logger 59 | log_level = logging.INFO 60 | fmt = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' 61 | logging.basicConfig(level=log_level, 62 | filename=os.path.join(log_dir, 'message.log'), 63 | filemode='w', 64 | format=fmt) 65 | console = logging.StreamHandler() 66 | console.setLevel(logging.INFO) 67 | # set a format which is simpler for console use 68 | formatter = logging.Formatter('%(message)s') 69 | # tell the handler to use this format 70 
| console.setFormatter(formatter) 71 | # add the handler to the root logger 72 | logging.getLogger('').addHandler(console) 73 | 74 | # Save the config info 75 | for k, v in vars(args).items(): 76 | config_log.write(content="{k} : {v}".format(k=k, v=v), 77 | wrap=True, 78 | flush=True) 79 | config_log.close() 80 | 81 | # logger initialization 82 | test_log.write(content="epoch\ttop1\ttop5\tloss\tcloss\trloss\tbloss\tdensity\tflops_per\tflops\t", 83 | wrap=True, 84 | flush=True) 85 | train_log.write(content="epoch\ttop1\ttop5\tloss\tcloss\trloss\tbloss", 86 | wrap=True, 87 | flush=True) 88 | return train_log, test_log, checkpoint_dir, log_dir 89 | 90 | 91 | def has_children(module): 92 | try: 93 | next(module.children()) 94 | return True 95 | except StopIteration: 96 | return False 97 | 98 | 99 | class Logger(object): 100 | '''Save training process to log file with simple plot function.''' 101 | def __init__(self, fpath, title=None, resume=False): 102 | self.file = None 103 | self.resume = resume 104 | self.title = '' if title is None else title 105 | if fpath is not None: 106 | if resume: 107 | self.file = open(fpath, 'r') 108 | name = self.file.readline() 109 | self.names = name.rstrip().split('\t') 110 | self.numbers = {} 111 | for _, name in enumerate(self.names): 112 | self.numbers[name] = [] 113 | 114 | for numbers in self.file: 115 | numbers = numbers.rstrip().split('\t') 116 | for i in range(0, len(numbers)): 117 | self.numbers[self.names[i]].append(numbers[i]) 118 | self.file.close() 119 | self.file = open(fpath, 'a') 120 | else: 121 | self.file = open(fpath, 'w') 122 | 123 | def set_names(self, names): 124 | if self.resume: 125 | pass 126 | # initialize numbers as empty list 127 | self.numbers = {} 128 | self.names = names 129 | for _, name in enumerate(self.names): 130 | self.file.write(name) 131 | self.file.write('\t') 132 | self.numbers[name] = [] 133 | self.file.write('\n') 134 | self.file.flush() 135 | 136 | def append(self, numbers): 137 | assert len(self.names) == len(numbers), 'Numbers do not match names' 138 | for index, num in enumerate(numbers): 139 | self.file.write("{0:.6f}".format(num)) 140 | self.file.write('\t') 141 | self.numbers[self.names[index]].append(num) 142 | self.file.write('\n') 143 | self.file.flush() 144 | 145 | def plot(self, names=None): 146 | names = self.names if names is None else names 147 | numbers = self.numbers 148 | for _, name in enumerate(names): 149 | x = np.arange(len(numbers[name])) 150 | plt.plot(x, np.asarray(numbers[name])) 151 | plt.legend([self.title + '(' + name + ')' for name in names]) 152 | plt.grid(True) 153 | 154 | def close(self): 155 | if self.file is not None: 156 | self.file.close() 157 | 158 | def write(self, content, wrap=True, flush=False, verbose=False): 159 | """ 160 | write file and flush buffer to the disk 161 | :param content: str 162 | :param wrap: bool, whether to add '\n' at the end of the content 163 | :param flush: bool, whether to flush buffer to the disk, default=False 164 | :param verbose: bool, whether to print the content, default=False 165 | :return: 166 | void 167 | """ 168 | if verbose: 169 | print(content) 170 | if wrap: 171 | content += "\n" 172 | self.file.write(content) 173 | if flush: 174 | self.file.flush() 175 | os.fsync(self.file) 176 | 177 | 178 | class LoggerMonitor(object): 179 | '''Load and visualize multiple logs.''' 180 | def __init__(self, paths): 181 | '''paths is a distionary with {name:filepath} pair''' 182 | self.loggers = [] 183 | for title, path in paths.items(): 184 | logger = 
Logger(path, title=title, resume=True) 185 | self.loggers.append(logger) 186 | 187 | def plot(self, names=None): 188 | plt.figure() 189 | plt.subplot(121) 190 | legend_text = [] 191 | for logger in self.loggers: 192 | legend_text += plot_overlap(logger, names) 193 | plt.legend(legend_text, 194 | bbox_to_anchor=(1.05, 1), 195 | loc=2, 196 | borderaxespad=0.) 197 | plt.grid(True) 198 | 199 | 200 | def mkdir_p(path): 201 | '''make dir if not exist''' 202 | try: 203 | os.makedirs(path) 204 | except OSError as exc: # Python >2.5 205 | if exc.errno == errno.EEXIST and os.path.isdir(path): 206 | pass 207 | else: 208 | raise 209 | 210 | 211 | def size_to_str(torch_size): 212 | """Convert a pytorch Size object to a string""" 213 | assert isinstance(torch_size, (torch.Size, tuple, list)) 214 | return '(' + (', ').join(['%d' % v for v in torch_size]) + ')' 215 | 216 | 217 | def to_np(var): 218 | return var.data.cpu().numpy() 219 | 220 | 221 | def norm_filters(weights, p=1): 222 | """Compute the p-norm of convolution filters. 223 | 224 | Args: 225 | weights - a 4D convolution weights tensor. 226 | Has shape = (#filters, #channels, k_w, k_h) 227 | p - the exponent value in the norm formulation 228 | """ 229 | assert weights.dim() == 4 230 | return weights.view(weights.size(0), -1).norm(p=p, dim=1) 231 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | __all__ = ['AverageMeter', 'accuracy', 'analyse_flops', 'ExpAnnealing'] 6 | 7 | 8 | def accuracy(output, target, topk=(1,)): 9 | """Computes the precision@k for the specified values of k""" 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | res = [] 18 | for k in topk: 19 | correct_k = correct[:k].view(-1).float().sum(0) 20 | res.append(correct_k.mul_(100.0 / batch_size)) 21 | return res 22 | 23 | 24 | def analyse_flops(flops_real, flops_mask, flops_ori, batch_size): 25 | def add_sum(data): 26 | s = data.sum().unsqueeze(0) 27 | out = torch.cat([data, s]) 28 | return out 29 | block_flops, flops_conv1, flops_fc = flops_real[0], flops_real[1], flops_real[2] 30 | flops_mask = flops_mask.mean(0) 31 | # block flops 32 | flops_conv = add_sum(block_flops[0:batch_size,:].mean(0)) 33 | flops_mask = add_sum(flops_mask) 34 | flops_ori = add_sum(flops_ori.mean(0)) 35 | return flops_conv, flops_mask, flops_ori, flops_conv1.mean(), flops_fc.mean() 36 | 37 | 38 | class AverageMeter(object): 39 | r"""Computes and stores the average and current value 40 | Imported from 41 | https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 42 | """ 43 | def __init__(self): 44 | self.reset() 45 | 46 | def reset(self): 47 | self.val = 0 48 | self.avg = 0 49 | self.sum = 0 50 | self.count = 0 51 | 52 | def update(self, val, n=1): 53 | self.val = val 54 | self.sum += val * n 55 | self.count += n 56 | self.avg = self.sum / self.count 57 | 58 | 59 | class ExpAnnealing(object): 60 | r""" 61 | Args: 62 | T_ini (int): Number of initial epochs during which the density is held at eta_ini. 63 | eta_ini (float): Initial density. Default: 1. 64 | eta_final (float): Final density. Default: 0.
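        up (bool): If True, anneal upward from eta_ini toward eta_final. Default: False.
        alpha (float): Rate of the exponential schedule. Default: 1.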
65 | """ 66 | 67 | def __init__(self, T_ini, eta_ini=1, eta_final=0, up=False, alpha=1): 68 | self.T_ini = T_ini 69 | self.eta_final = eta_final 70 | self.eta_ini = eta_ini 71 | self.up = up 72 | self.last_epoch = 0 73 | self.alpha = alpha 74 | 75 | def get_lr(self, epoch): 76 | if epoch < self.T_ini: 77 | return self.eta_ini 78 | elif self.up: 79 | return self.eta_ini + (self.eta_final-self.eta_ini) * (1- 80 | math.exp(-self.alpha*(epoch-self.T_ini))) 81 | else: 82 | return self.eta_final + (self.eta_ini-self.eta_final) * math.exp( 83 | -self.alpha*(epoch-self.T_ini)) 84 | 85 | def step(self): 86 | self.last_epoch += 1 87 | return self.get_lr(self.last_epoch) 88 | -------------------------------------------------------------------------------- /utils/progress/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | build/ 4 | dist/ 5 | -------------------------------------------------------------------------------- /utils/progress/LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | -------------------------------------------------------------------------------- /utils/progress/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst LICENSE 2 | -------------------------------------------------------------------------------- /utils/progress/README.rst: -------------------------------------------------------------------------------- 1 | Easy progress reporting for Python 2 | ================================== 3 | 4 | |pypi| 5 | 6 | |demo| 7 | 8 | .. |pypi| image:: https://img.shields.io/pypi/v/progress.svg 9 | .. |demo| image:: https://raw.github.com/verigak/progress/master/demo.gif 10 | :alt: Demo 11 | 12 | Bars 13 | ---- 14 | 15 | There are 7 progress bars to choose from: 16 | 17 | - ``Bar`` 18 | - ``ChargingBar`` 19 | - ``FillingSquaresBar`` 20 | - ``FillingCirclesBar`` 21 | - ``IncrementalBar`` 22 | - ``PixelBar`` 23 | - ``ShadyBar`` 24 | 25 | To use them, just call ``next`` to advance and ``finish`` to finish: 26 | 27 | .. code-block:: python 28 | 29 | from progress.bar import Bar 30 | 31 | bar = Bar('Processing', max=20) 32 | for i in range(20): 33 | # Do some work 34 | bar.next() 35 | bar.finish() 36 | 37 | The result will be a bar like the following: :: 38 | 39 | Processing |############# | 42/100 40 | 41 | To simplify the common case where the work is done in an iterator, you can 42 | use the ``iter`` method: 43 | 44 | .. 
code-block:: python 45 | 46 | for i in Bar('Processing').iter(it): 47 | # Do some work 48 | 49 | Progress bars are very customizable, you can change their width, their fill 50 | character, their suffix and more: 51 | 52 | .. code-block:: python 53 | 54 | bar = Bar('Loading', fill='@', suffix='%(percent)d%%') 55 | 56 | This will produce a bar like the following: :: 57 | 58 | Loading |@@@@@@@@@@@@@ | 42% 59 | 60 | You can use a number of template arguments in ``message`` and ``suffix``: 61 | 62 | ========== ================================ 63 | Name Value 64 | ========== ================================ 65 | index current value 66 | max maximum value 67 | remaining max - index 68 | progress index / max 69 | percent progress * 100 70 | avg simple moving average time per item (in seconds) 71 | elapsed elapsed time in seconds 72 | elapsed_td elapsed as a timedelta (useful for printing as a string) 73 | eta avg * remaining 74 | eta_td eta as a timedelta (useful for printing as a string) 75 | ========== ================================ 76 | 77 | Instead of passing all configuration options on instatiation, you can create 78 | your custom subclass: 79 | 80 | .. code-block:: python 81 | 82 | class FancyBar(Bar): 83 | message = 'Loading' 84 | fill = '*' 85 | suffix = '%(percent).1f%% - %(eta)ds' 86 | 87 | You can also override any of the arguments or create your own: 88 | 89 | .. code-block:: python 90 | 91 | class SlowBar(Bar): 92 | suffix = '%(remaining_hours)d hours remaining' 93 | @property 94 | def remaining_hours(self): 95 | return self.eta // 3600 96 | 97 | 98 | Spinners 99 | ======== 100 | 101 | For actions with an unknown number of steps you can use a spinner: 102 | 103 | .. code-block:: python 104 | 105 | from progress.spinner import Spinner 106 | 107 | spinner = Spinner('Loading ') 108 | while state != 'FINISHED': 109 | # Do some work 110 | spinner.next() 111 | 112 | There are 5 predefined spinners: 113 | 114 | - ``Spinner`` 115 | - ``PieSpinner`` 116 | - ``MoonSpinner`` 117 | - ``LineSpinner`` 118 | - ``PixelSpinner`` 119 | 120 | 121 | Other 122 | ===== 123 | 124 | There are a number of other classes available too, please check the source or 125 | subclass one of them to create your own. 126 | 127 | 128 | License 129 | ======= 130 | 131 | progress is licensed under ISC 132 | -------------------------------------------------------------------------------- /utils/progress/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CAS-CLab/DGNet/6b709a388c463d7468fbad953ad0112bc3abe66d/utils/progress/demo.gif -------------------------------------------------------------------------------- /utils/progress/progress/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | 15 | from __future__ import division 16 | 17 | from collections import deque 18 | from datetime import timedelta 19 | from math import ceil 20 | from sys import stderr 21 | from time import time 22 | 23 | 24 | __version__ = '1.3' 25 | 26 | 27 | class Infinite(object): 28 | file = stderr 29 | sma_window = 10 # Simple Moving Average window 30 | 31 | def __init__(self, *args, **kwargs): 32 | self.index = 0 33 | self.start_ts = time() 34 | self.avg = 0 35 | self._ts = self.start_ts 36 | self._xput = deque(maxlen=self.sma_window) 37 | for key, val in kwargs.items(): 38 | setattr(self, key, val) 39 | 40 | def __getitem__(self, key): 41 | if key.startswith('_'): 42 | return None 43 | return getattr(self, key, None) 44 | 45 | @property 46 | def elapsed(self): 47 | return int(time() - self.start_ts) 48 | 49 | @property 50 | def elapsed_td(self): 51 | return timedelta(seconds=self.elapsed) 52 | 53 | def update_avg(self, n, dt): 54 | if n > 0: 55 | self._xput.append(dt / n) 56 | self.avg = sum(self._xput) / len(self._xput) 57 | 58 | def update(self): 59 | pass 60 | 61 | def start(self): 62 | pass 63 | 64 | def finish(self): 65 | pass 66 | 67 | def next(self, n=1): 68 | now = time() 69 | dt = now - self._ts 70 | self.update_avg(n, dt) 71 | self._ts = now 72 | self.index = self.index + n 73 | self.update() 74 | 75 | def iter(self, it): 76 | try: 77 | for x in it: 78 | yield x 79 | self.next() 80 | finally: 81 | self.finish() 82 | 83 | 84 | class Progress(Infinite): 85 | def __init__(self, *args, **kwargs): 86 | super(Progress, self).__init__(*args, **kwargs) 87 | self.max = kwargs.get('max', 100) 88 | 89 | @property 90 | def eta(self): 91 | return int(ceil(self.avg * self.remaining)) 92 | 93 | @property 94 | def eta_td(self): 95 | return timedelta(seconds=self.eta) 96 | 97 | @property 98 | def percent(self): 99 | return self.progress * 100 100 | 101 | @property 102 | def progress(self): 103 | return min(1, self.index / self.max) 104 | 105 | @property 106 | def remaining(self): 107 | return max(self.max - self.index, 0) 108 | 109 | def start(self): 110 | self.update() 111 | 112 | def goto(self, index): 113 | incr = index - self.index 114 | self.next(incr) 115 | 116 | def iter(self, it): 117 | try: 118 | self.max = len(it) 119 | except TypeError: 120 | pass 121 | 122 | try: 123 | for x in it: 124 | yield x 125 | self.next() 126 | finally: 127 | self.finish() 128 | -------------------------------------------------------------------------------- /utils/progress/progress/bar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . import Progress 19 | from .helpers import WritelnMixin 20 | 21 | 22 | class Bar(WritelnMixin, Progress): 23 | width = 32 24 | message = '' 25 | suffix = '%(index)d/%(max)d' 26 | bar_prefix = ' |' 27 | bar_suffix = '| ' 28 | empty_fill = ' ' 29 | fill = '#' 30 | hide_cursor = True 31 | 32 | def update(self): 33 | filled_length = int(self.width * self.progress) 34 | empty_length = self.width - filled_length 35 | 36 | message = self.message % self 37 | bar = self.fill * filled_length 38 | empty = self.empty_fill * empty_length 39 | suffix = self.suffix % self 40 | line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, 41 | suffix]) 42 | self.writeln(line) 43 | 44 | 45 | class ChargingBar(Bar): 46 | suffix = '%(percent)d%%' 47 | bar_prefix = ' ' 48 | bar_suffix = ' ' 49 | empty_fill = '∙' 50 | fill = '█' 51 | 52 | 53 | class FillingSquaresBar(ChargingBar): 54 | empty_fill = '▢' 55 | fill = '▣' 56 | 57 | 58 | class FillingCirclesBar(ChargingBar): 59 | empty_fill = '◯' 60 | fill = '◉' 61 | 62 | 63 | class IncrementalBar(Bar): 64 | phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') 65 | 66 | def update(self): 67 | nphases = len(self.phases) 68 | filled_len = self.width * self.progress 69 | nfull = int(filled_len) # Number of full chars 70 | phase = int((filled_len - nfull) * nphases) # Phase of last char 71 | nempty = self.width - nfull # Number of empty chars 72 | 73 | message = self.message % self 74 | bar = self.phases[-1] * nfull 75 | current = self.phases[phase] if phase > 0 else '' 76 | empty = self.empty_fill * max(0, nempty - len(current)) 77 | suffix = self.suffix % self 78 | line = ''.join([message, self.bar_prefix, bar, current, empty, 79 | self.bar_suffix, suffix]) 80 | self.writeln(line) 81 | 82 | 83 | class PixelBar(IncrementalBar): 84 | phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') 85 | 86 | 87 | class ShadyBar(IncrementalBar): 88 | phases = (' ', '░', '▒', '▓', '█') 89 | -------------------------------------------------------------------------------- /utils/progress/progress/counter.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Infinite, Progress 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Counter(WriteMixin, Infinite): 23 | message = '' 24 | hide_cursor = True 25 | 26 | def update(self): 27 | self.write(str(self.index)) 28 | 29 | 30 | class Countdown(WriteMixin, Progress): 31 | hide_cursor = True 32 | 33 | def update(self): 34 | self.write(str(self.remaining)) 35 | 36 | 37 | class Stack(WriteMixin, Progress): 38 | phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█') 39 | hide_cursor = True 40 | 41 | def update(self): 42 | nphases = len(self.phases) 43 | i = min(nphases - 1, int(self.progress * nphases)) 44 | self.write(self.phases[i]) 45 | 46 | 47 | class Pie(Stack): 48 | phases = ('○', '◔', '◑', '◕', '●') 49 | -------------------------------------------------------------------------------- /utils/progress/progress/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | 15 | from __future__ import print_function 16 | 17 | 18 | HIDE_CURSOR = '\x1b[?25l' 19 | SHOW_CURSOR = '\x1b[?25h' 20 | 21 | 22 | class WriteMixin(object): 23 | hide_cursor = False 24 | 25 | def __init__(self, message=None, **kwargs): 26 | super(WriteMixin, self).__init__(**kwargs) 27 | self._width = 0 28 | if message: 29 | self.message = message 30 | 31 | if self.file.isatty(): 32 | if self.hide_cursor: 33 | print(HIDE_CURSOR, end='', file=self.file) 34 | print(self.message, end='', file=self.file) 35 | self.file.flush() 36 | 37 | def write(self, s): 38 | if self.file.isatty(): 39 | b = '\b' * self._width 40 | c = s.ljust(self._width) 41 | print(b + c, end='', file=self.file) 42 | self._width = max(self._width, len(s)) 43 | self.file.flush() 44 | 45 | def finish(self): 46 | if self.file.isatty() and self.hide_cursor: 47 | print(SHOW_CURSOR, end='', file=self.file) 48 | 49 | 50 | class WritelnMixin(object): 51 | hide_cursor = False 52 | 53 | def __init__(self, message=None, **kwargs): 54 | super(WritelnMixin, self).__init__(**kwargs) 55 | if message: 56 | self.message = message 57 | 58 | if self.file.isatty() and self.hide_cursor: 59 | print(HIDE_CURSOR, end='', file=self.file) 60 | 61 | def clearln(self): 62 | if self.file.isatty(): 63 | print('\r\x1b[K', end='', file=self.file) 64 | 65 | def writeln(self, line): 66 | if self.file.isatty(): 67 | self.clearln() 68 | print(line, end='', file=self.file) 69 | self.file.flush() 70 | 71 | def finish(self): 72 | if self.file.isatty(): 73 | print(file=self.file) 74 | if self.hide_cursor: 75 | print(SHOW_CURSOR, end='', file=self.file) 76 | 77 | 78 | from signal import signal, SIGINT 79 | from sys import exit 80 | 81 | 82 | class SigIntMixin(object): 83 | """Registers a signal handler that calls finish on SIGINT""" 84 | 85 | def __init__(self, 
*args, **kwargs): 86 | super(SigIntMixin, self).__init__(*args, **kwargs) 87 | signal(SIGINT, self._sigint_handler) 88 | 89 | def _sigint_handler(self, signum, frame): 90 | self.finish() 91 | exit(0) 92 | -------------------------------------------------------------------------------- /utils/progress/progress/spinner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . import Infinite 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Spinner(WriteMixin, Infinite): 23 | message = '' 24 | phases = ('-', '\\', '|', '/') 25 | hide_cursor = True 26 | 27 | def update(self): 28 | i = self.index % len(self.phases) 29 | self.write(self.phases[i]) 30 | 31 | 32 | class PieSpinner(Spinner): 33 | phases = ['◷', '◶', '◵', '◴'] 34 | 35 | 36 | class MoonSpinner(Spinner): 37 | phases = ['◑', '◒', '◐', '◓'] 38 | 39 | 40 | class LineSpinner(Spinner): 41 | phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] 42 | 43 | class PixelSpinner(Spinner): 44 | phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] 45 | -------------------------------------------------------------------------------- /utils/progress/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | import progress 6 | 7 | 8 | setup( 9 | name='progress', 10 | version=progress.__version__, 11 | description='Easy to use progress bars', 12 | long_description=open('README.rst').read(), 13 | author='Giorgos Verigakis', 14 | author_email='verigak@gmail.com', 15 | url='http://github.com/verigak/progress/', 16 | license='ISC', 17 | packages=['progress'], 18 | classifiers=[ 19 | 'Environment :: Console', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved :: ISC License (ISCL)', 22 | 'Programming Language :: Python :: 2.6', 23 | 'Programming Language :: Python :: 2.7', 24 | 'Programming Language :: Python :: 3.3', 25 | 'Programming Language :: Python :: 3.4', 26 | 'Programming Language :: Python :: 3.5', 27 | 'Programming Language :: Python :: 3.6', 28 | ] 29 | ) 30 | --------------------------------------------------------------------------------
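As a closing orientation aid, here is a minimal, untested sketch of how the pieces above compose: the dual-gated model, the composite `Loss` criterion, and `ExpAnnealing` driving the bound-loss margin `p`. The concrete hyper-parameter values and the manual annealing step are illustrative assumptions; the actual training loop lives in `main.py`, which is not shown here.

```python
import torch

from models.mobilenet_v2 import mobilenet_v2_dg
from regularization import Loss
from utils import ExpAnnealing

# Build the dual-gated MobileNetV2 and attach the composite criterion.
model = mobilenet_v2_dg(num_classes=1000, in_size=(224, 224))
model.set_criterion(Loss())

# Anneal the bound-loss margin p from 1 toward 0 (values are illustrative;
# options.py exposes the rate as --alpha).
annealing = ExpAnnealing(T_ini=0, eta_ini=1, eta_final=0, alpha=5e-2)
p = annealing.step()

x = torch.randn(2, 3, 224, 224)        # dummy batch
label = torch.randint(0, 1000, (2,))   # dummy targets

outputs = model(x, label, den_target=0.5, lbda=5, gamma=1, p=p)
loss = outputs["closs"] + outputs["rloss"] + outputs["bloss"]
loss.backward()
```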