├── .gitignore ├── README.md ├── adaptive_inference.py ├── args.py ├── dataloader.py ├── imgs ├── RANet_overview.png ├── anytime_results.png └── dynamic_results.png ├── main.py ├── models ├── RANet.py └── __init__.py ├── op_counter.py ├── train_cifar.sh └── train_imagenet.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Resolution Adaptive Networks for Efficient Inference (CVPR2020) 2 | [Le Yang*](https://github.com/yangle15), [Yizeng Han*](https://github.com/thuallen), [Xi Chen*](https://github.com/FateDawnLeon), Shiji Song, [Jifeng Dai](https://github.com/daijifeng001), [Gao Huang](https://github.com/gaohuang) 3 | 4 | This repository contains the implementation of the paper '[Resolution Adaptive Networks for Efficient Inference](https://arxiv.org/pdf/2003.07326.pdf)'. The proposed Resolution Adaptive Networks (RANet) conduct adaptive inference by exploiting the ``spatial redundancy`` of input images. Our motivation is that low-resolution representations are sufficient for classifying easy samples containing large objects with prototypical features, while only some hard samples need spatially detailed information, as demonstrated by the following figure. 5 | 6 |
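Conceptually, each intermediate classifier lets a sample exit as soon as its prediction is confident enough, so easy images never pay for the high-resolution computation. Below is a minimal, illustrative sketch of such confidence-thresholded early exiting (the `threshold` value and the `early_exit_predict` helper are examples for exposition, not this repository's actual interface):

```python
import torch
import torch.nn.functional as F

def early_exit_predict(exit_logits, threshold=0.9):
    """Return (predicted class, exit index) for a single sample.

    exit_logits: list of 1-D logit tensors, one per classifier, ordered
    from the cheapest (low-resolution) exit to the most expensive one.
    """
    for k, logits in enumerate(exit_logits):
        probs = F.softmax(logits, dim=-1)
        conf, pred = probs.max(dim=-1)
        # Exit early once confident; the last exit accepts every sample.
        if conf.item() >= threshold or k == len(exit_logits) - 1:
            return pred.item(), k
```

For instance, a sample whose first (low-resolution) exit is uncertain but whose second exit is confident leaves the network at the second classifier and skips all later computation.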
7 | 8 | ## Results 9 | 10 |
11 | 12 | Accuracy (top-1) of anytime prediction models as a function of computational budget on the CIFAR-10 (left), CIFAR-100 13 | (middle) and ImageNet (right) datasets. Higher is better. 14 | 15 |
16 | 17 | Accuracy (top-1) of budgeted batch classification models as a function of the average computational budget per image on the CIFAR- 18 | 10 (left), CIFAR-100 (middle) and ImageNet (right) datasets. Higher is better. 19 | 20 | ## Dependencies: 21 | 22 | * Python3 23 | 24 | * PyTorch >= 1.0 25 | 26 | ## Usage 27 | We provide shell scripts for training a RANet on CIFAR and ImageNet. 28 | 29 | ### Train a RANet on CIFAR 30 | * Modify train_cifar.sh to configure your dataset path, GPU devices and saving directory, then run 31 | ```sh 32 | bash train_cifar.sh 33 | ``` 34 | 35 | * You can also train a RANet with other configurations: 36 | ```sh 37 | python main.py --arch RANet --gpu '0' --data-root YOUR_DATA_PATH --data 'cifar10' --step 2 --nChannels 16 --stepmode 'lg' --scale-list '1-2-3' --grFactor '4-2-1' --bnFactor '4-2-1' 38 | ``` 39 | 40 | ### Train a RANet on ImageNet 41 | * Modify train_imagenet.sh to configure your dataset path, GPU devices and saving directory, then run 42 | ```sh 43 | bash train_imagenet.sh 44 | ``` 45 | 46 | * You can also train a RANet with other configurations: 47 | ```sh 48 | python main.py --arch RANet --gpu '0,1,2,3' --data-root YOUR_DATA_PATH --data 'ImageNet' --step 8 --growthRate 16 --nChannels 32 --stepmode 'even' --scale-list '1-2-3-4' --grFactor '4-2-2-1' --bnFactor '4-2-2-1' 49 | ``` 50 | 51 | 52 | 53 | ### Citation 54 | If you find this work useful or use our code in your own research, please cite with the following bibtex: 55 | ``` 56 | @inproceedings{yang2020resolution, 57 | title={Resolution Adaptive Networks for Efficient Inference}, 58 | author={Yang, Le and Han, Yizeng and Chen, Xi and Song, Shiji and Dai, Jifeng and Huang, Gao}, 59 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 60 | year={2020} 61 | } 62 | ``` 63 | 64 | ### Contact 65 | If you have any questions, please feel free to contact the authors.
66 | 67 | Le Yang: yangle15@mails.tsinghua.edu.cn 68 | 69 | Yizeng Han: [hanyz18@mails.tsinghua.edu.cn](mailto:hanyz18@mails.tsinghua.edu.cn) 70 | 71 | ### Acknowledgments 72 | We use the PyTorch implementation of MSDNet in our experiments. The code can be found [here](https://github.com/kalviny/MSDNet-PyTorch). 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /adaptive_inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from __future__ import print_function 4 | from __future__ import division 5 | 6 | import os 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | 11 | def dynamic_evaluate(model, test_loader, val_loader, args): 12 | tester = Tester(model, args) 13 | if os.path.exists(os.path.join(args.save, 'logits_single.pth')): 14 | val_pred, val_target, test_pred, test_target = \ 15 | torch.load(os.path.join(args.save, 'logits_single.pth')) 16 | else: 17 | val_pred, val_target = tester.calc_logit(val_loader) 18 | test_pred, test_target = tester.calc_logit(test_loader) 19 | torch.save((val_pred, val_target, test_pred, test_target), 20 | os.path.join(args.save, 'logits_single.pth')) 21 | 22 | flops = torch.load(os.path.join(args.save, 'flops.pth')) 23 | 24 | acc_list, exp_flops_list = [], [] 25 | with open(os.path.join(args.save, 'dynamic.txt'), 'w') as fout: 26 | samples = {} 27 | for p in range(1, 40): 28 | print("*********************") 29 | _p = torch.FloatTensor(1).fill_(p * 1.0 / 20) 30 | probs = torch.exp(torch.log(_p) * torch.arange(1, args.num_exits + 1, dtype=torch.float))  # geometric exit-probability schedule; torch.range is deprecated and arange's end is exclusive 31 | probs /= probs.sum() 32 | acc_val, _, T = tester.dynamic_eval_find_threshold( 33 | val_pred, val_target, probs, flops) 34 | acc_test, exp_flops, exit_buckets = tester.dynamic_eval_with_threshold( 35 | test_pred, test_target, flops, T) 36 | print('valid acc: {:.3f}, test acc: {:.3f}, test flops:
{:.2f}M'.format(acc_val, acc_test, exp_flops / 1e6)) 37 | fout.write('{}\t{}\n'.format(acc_test, exp_flops.item())) 38 | acc_list.append(acc_test) 39 | exp_flops_list.append(exp_flops) 40 | samples[p] = exit_buckets 41 | torch.save([exp_flops_list, acc_list], os.path.join(args.save, 'dynamic.pth')) 42 | torch.save(samples, os.path.join(args.save, 'exit_samples_by_p.pth')) 43 | # return acc_list, exp_flops_list 44 | 45 | 46 | class Tester(object): 47 | def __init__(self, model, args=None): 48 | self.args = args 49 | self.model = model 50 | self.softmax = nn.Softmax(dim=1).cuda() 51 | 52 | def calc_logit(self, dataloader): 53 | self.model.eval() 54 | n_stage = self.args.num_exits 55 | logits = [[] for _ in range(n_stage)] 56 | targets = [] 57 | for i, (input, target) in enumerate(dataloader): 58 | targets.append(target) 59 | with torch.no_grad(): 60 | input_var = torch.autograd.Variable(input) 61 | output = self.model(input_var) 62 | if not isinstance(output, list): 63 | output = [output] 64 | for b in range(n_stage): 65 | _t = self.softmax(output[b]) 66 | 67 | logits[b].append(_t) 68 | 69 | if i % self.args.print_freq == 0: 70 | print('Generate Logit: [{0}/{1}]'.format(i, len(dataloader))) 71 | 72 | for b in range(n_stage): 73 | logits[b] = torch.cat(logits[b], dim=0) 74 | 75 | size = (n_stage, logits[0].size(0), logits[0].size(1)) 76 | ts_logits = torch.Tensor().resize_(size).zero_() 77 | for b in range(n_stage): 78 | ts_logits[b].copy_(logits[b]) 79 | 80 | targets = torch.cat(targets, dim=0) 81 | ts_targets = torch.Tensor().resize_(size[1]).copy_(targets) 82 | 83 | return ts_logits, ts_targets 84 | 85 | def dynamic_eval_find_threshold(self, logits, targets, p, flops): 86 | """ 87 | logits: m * n * c 88 | m: Stages 89 | n: Samples 90 | c: Classes 91 | """ 92 | n_stage, n_sample, c = logits.size() 93 | 94 | max_preds, argmax_preds = logits.max(dim=2, keepdim=False) 95 | 96 | _, sorted_idx = max_preds.sort(dim=1, descending=True) 97 | 98 | filtered = 
torch.zeros(n_sample) 99 | T = torch.Tensor(n_stage).fill_(1e8) 100 | 101 | for k in range(n_stage - 1): 102 | acc, count = 0.0, 0 103 | out_n = math.floor(n_sample * p[k]) 104 | for i in range(n_sample): 105 | ori_idx = sorted_idx[k][i] 106 | if filtered[ori_idx] == 0: 107 | count += 1 108 | if count == out_n: 109 | T[k] = max_preds[k][ori_idx] 110 | break 111 | filtered.add_(max_preds[k].ge(T[k]).type_as(filtered)) 112 | 113 | T[n_stage -1] = -1e8 # accept all of the samples at the last stage 114 | 115 | acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage) 116 | acc, expected_flops = 0, 0 117 | for i in range(n_sample): 118 | gold_label = targets[i] 119 | for k in range(n_stage): 120 | if max_preds[k][i].item() >= T[k]: # force the sample to exit at k 121 | if int(gold_label.item()) == int(argmax_preds[k][i].item()): 122 | acc += 1 123 | acc_rec[k] += 1 124 | exp[k] += 1 125 | break 126 | acc_all = 0 127 | for k in range(n_stage): 128 | _t = 1.0 * exp[k] / n_sample 129 | expected_flops += _t * flops[k] 130 | acc_all += acc_rec[k] 131 | 132 | return acc * 100.0 / n_sample, expected_flops, T 133 | 134 | def dynamic_eval_with_threshold(self, logits, targets, flops, T): 135 | n_stage, n_sample, _ = logits.size() 136 | max_preds, argmax_preds = logits.max(dim=2, keepdim=False) # take the max logits as confidence 137 | 138 | exit_buckets = {i:{j:[] for j in range(n_stage)} for i in range(1000)} # for each exit use a bucket to keep track of samples outputing from it 139 | 140 | acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage) 141 | acc, expected_flops = 0, 0 142 | for i in range(n_sample): 143 | gold_label = targets[i] 144 | for k in range(n_stage): 145 | if max_preds[k][i].item() >= T[k]: # force to exit at k 146 | _g = int(gold_label.item()) 147 | _pred = int(argmax_preds[k][i].item()) 148 | if _g == _pred: 149 | acc += 1 150 | acc_rec[k] += 1 151 | exp[k] += 1 152 | exit_buckets[int(gold_label)][k].append(i) 153 | break 154 | 155 | acc_all, sample_all = 
0, 0 156 | for k in range(n_stage): 157 | _t = exp[k] * 1.0 / n_sample 158 | sample_all += exp[k] 159 | expected_flops += _t * flops[k] 160 | acc_all += acc_rec[k] 161 | 162 | return acc * 100.0 / n_sample, expected_flops, exit_buckets 163 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | import argparse 5 | 6 | model_names = ['RANet'] 7 | 8 | arg_parser = argparse.ArgumentParser(description='RANet Image classification') 9 | 10 | exp_group = arg_parser.add_argument_group('exp', 'experiment setting') 11 | exp_group.add_argument('--save', default='save/default-{}'.format(time.time()), 12 | type=str, metavar='SAVE', 13 | help='path to the experiment logging directory' 14 | ' (default: save/default-<timestamp>)') 15 | exp_group.add_argument('--resume', action='store_true', default=None, 16 | help='resume from the latest checkpoint in the save directory') 17 | exp_group.add_argument('--evalmode', default=None, 18 | choices=['anytime', 'dynamic', 'both'], 19 | help='which mode to evaluate') 20 | exp_group.add_argument('--evaluate-from', default='', type=str, metavar='PATH', 21 | help='path to saved checkpoint (default: none)') 22 | exp_group.add_argument('--print-freq', '-p', default=10, type=int, 23 | metavar='N', help='print frequency (default: 10)') 24 | exp_group.add_argument('--seed', default=0, type=int, 25 | help='random seed') 26 | exp_group.add_argument('--gpu', default='0', type=str, help='available GPU devices, e.g. \'0,1\'') 27 | 28 | # dataset related 29 | data_group = arg_parser.add_argument_group('data', 'dataset setting') 30 | data_group.add_argument('--data', metavar='D', default='cifar10', 31 | choices=['cifar10', 'cifar100', 'ImageNet'], 32 | help='data to work on') 33 | data_group.add_argument('--data-root', metavar='DIR', default='/data/cx/data', 34 | help='path to dataset (default: /data/cx/data)') 35 |
data_group.add_argument('--use-valid', action='store_true', default=False, 36 | help='use validation set or not') 37 | data_group.add_argument('-j', '--workers', default=4, type=int, metavar='N', 38 | help='number of data loading workers (default: 4)') 39 | 40 | # model arch related 41 | arch_group = arg_parser.add_argument_group('arch', 'model architecture setting') 42 | arch_group.add_argument('--arch', type=str, default='RANet') 43 | arch_group.add_argument('--reduction', default=0.5, type=float, 44 | metavar='C', help='compression ratio of DenseNet' 45 | ' (1 means don\'t use compression) (default: 0.5)') 46 | 47 | # msdnet config 48 | arch_group.add_argument('--nBlocks', type=int, default=2) 49 | arch_group.add_argument('--nChannels', type=int, default=16) 50 | arch_group.add_argument('--growthRate', type=int, default=6) 51 | arch_group.add_argument('--grFactor', default='4-2-1', type=str) 52 | arch_group.add_argument('--bnFactor', default='4-2-1', type=str) 53 | arch_group.add_argument('--block-step', type=int, default=2) 54 | arch_group.add_argument('--scale-list', default='1-2-3', type=str) 55 | arch_group.add_argument('--compress-factor', default=0.25, type=float) 56 | arch_group.add_argument('--step', type=int, default=4) 57 | arch_group.add_argument('--stepmode', type=str, default='even', choices=['even', 'lg']) 58 | arch_group.add_argument('--bnAfter', action='store_true', default=True) 59 | 60 | 61 | # training related 62 | optim_group = arg_parser.add_argument_group('optimization', 'optimization setting') 63 | optim_group.add_argument('--epochs', default=300, type=int, metavar='N', 64 | help='number of total epochs to run (default: 300)') 65 | optim_group.add_argument('--start-epoch', default=0, type=int, metavar='N', 66 | help='manual epoch number (useful on restarts)') 67 | optim_group.add_argument('-b', '--batch-size', default=64, type=int, 68 | metavar='N', help='mini-batch size (default: 64)') 69 | optim_group.add_argument('--optimizer',
default='sgd', 70 | choices=['sgd', 'rmsprop', 'adam'], metavar='N', 71 | help='optimizer (default=sgd)') 72 | optim_group.add_argument('--lr', '--learning-rate', default=0.1, type=float, 73 | metavar='LR', 74 | help='initial learning rate (default: 0.1)') 75 | optim_group.add_argument('--lr-type', default='multistep', type=str, metavar='T', 76 | help='learning rate strategy (default: multistep)', 77 | choices=['cosine', 'multistep']) 78 | optim_group.add_argument('--decay-rate', default=0.1, type=float, metavar='N', 79 | help='decay rate of learning rate (default: 0.1)') 80 | optim_group.add_argument('--momentum', default=0.9, type=float, metavar='M', 81 | help='momentum (default=0.9)') 82 | optim_group.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 83 | metavar='W', help='weight decay (default: 1e-4)') 84 | 85 | args = arg_parser.parse_args() 86 | 87 | args.grFactor = list(map(int, args.grFactor.split('-'))) 88 | args.bnFactor = list(map(int, args.bnFactor.split('-'))) 89 | args.scale_list = list(map(int, args.scale_list.split('-'))) 90 | args.nScales = len(args.grFactor) 91 | 92 | if args.use_valid: 93 | args.splits = ['train', 'val', 'test'] 94 | else: 95 | args.splits = ['train', 'val'] 96 | 97 | if args.data == 'cifar10': 98 | args.num_classes = 10 99 | elif args.data == 'cifar100': 100 | args.num_classes = 100 101 | else: 102 | args.num_classes = 1000 103 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | 6 | 7 | def get_dataloaders(args): 8 | train_loader, val_loader, test_loader = None, None, None 9 | if args.data == 'cifar10': 10 | normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467], 11 | std=[0.2471, 0.2435, 0.2616]) 12 | train_set = datasets.CIFAR10(args.data_root, 
train=True, 13 | transform=transforms.Compose([ 14 | transforms.RandomCrop(32, padding=4), 15 | transforms.RandomHorizontalFlip(), 16 | transforms.ToTensor(), 17 | normalize 18 | ])) 19 | val_set = datasets.CIFAR10(args.data_root, train=False, 20 | transform=transforms.Compose([ 21 | transforms.ToTensor(), 22 | normalize 23 | ])) 24 | elif args.data == 'cifar100': 25 | normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], 26 | std=[0.2675, 0.2565, 0.2761]) 27 | train_set = datasets.CIFAR100(args.data_root, train=True, 28 | transform=transforms.Compose([ 29 | transforms.RandomCrop(32, padding=4), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.ToTensor(), 32 | normalize 33 | ])) 34 | val_set = datasets.CIFAR100(args.data_root, train=False, 35 | transform=transforms.Compose([ 36 | transforms.ToTensor(), 37 | normalize 38 | ])) 39 | else: 40 | # ImageNet 41 | traindir = os.path.join(args.data_root, 'train') 42 | valdir = os.path.join(args.data_root, 'val') 43 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 44 | std=[0.229, 0.224, 0.225]) 45 | train_set = datasets.ImageFolder(traindir, transforms.Compose([ 46 | transforms.RandomResizedCrop(224), 47 | transforms.RandomHorizontalFlip(), 48 | transforms.ToTensor(), 49 | normalize 50 | ])) 51 | val_set = datasets.ImageFolder(valdir, transforms.Compose([ 52 | transforms.Resize(256), 53 | transforms.CenterCrop(224), 54 | transforms.ToTensor(), 55 | normalize 56 | ])) 57 | if args.use_valid: 58 | train_set_index = torch.randperm(len(train_set)) 59 | if os.path.exists(os.path.join(args.save, 'index.pth')): 60 | print('!!!!!! Load train_set_index !!!!!!') 61 | train_set_index = torch.load(os.path.join(args.save, 'index.pth')) 62 | else: 63 | print('!!!!!! 
Save train_set_index !!!!!!') 64 | torch.save(train_set_index, os.path.join(args.save, 'index.pth')) 65 | if args.data.startswith('cifar'): 66 | num_sample_valid = 5000 67 | else: 68 | num_sample_valid = 50000 69 | 70 | if 'train' in args.splits: 71 | train_loader = torch.utils.data.DataLoader( 72 | train_set, batch_size=args.batch_size, 73 | sampler=torch.utils.data.sampler.SubsetRandomSampler( 74 | train_set_index[:-num_sample_valid]), 75 | num_workers=args.workers, pin_memory=False) 76 | if 'val' in args.splits: 77 | val_loader = torch.utils.data.DataLoader( 78 | train_set, batch_size=args.batch_size, 79 | sampler=torch.utils.data.sampler.SubsetRandomSampler( 80 | train_set_index[-num_sample_valid:]), 81 | num_workers=args.workers, pin_memory=False) 82 | if 'test' in args.splits: 83 | test_loader = torch.utils.data.DataLoader( 84 | val_set, 85 | batch_size=args.batch_size, shuffle=False, 86 | num_workers=args.workers, pin_memory=False) 87 | else: 88 | if 'train' in args.splits: 89 | train_loader = torch.utils.data.DataLoader( 90 | train_set, 91 | batch_size=args.batch_size, shuffle=True, 92 | num_workers=args.workers, pin_memory=False) 93 | if 'val' in args.splits or 'test' in args.splits:  # check each split explicitly; ('val' or 'test' in ...) is always truthy 94 | val_loader = torch.utils.data.DataLoader( 95 | val_set, 96 | batch_size=args.batch_size, shuffle=False, 97 | num_workers=args.workers, pin_memory=False) 98 | test_loader = val_loader 99 | 100 | return train_loader, val_loader, test_loader 101 | -------------------------------------------------------------------------------- /imgs/RANet_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangle15/RANet-pytorch/be0f25a2286160bc612181507cb44a3ff2cd0e46/imgs/RANet_overview.png -------------------------------------------------------------------------------- /imgs/anytime_results.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangle15/RANet-pytorch/be0f25a2286160bc612181507cb44a3ff2cd0e46/imgs/anytime_results.png -------------------------------------------------------------------------------- /imgs/dynamic_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangle15/RANet-pytorch/be0f25a2286160bc612181507cb44a3ff2cd0e46/imgs/dynamic_results.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import shutil 5 | import models 6 | 7 | from dataloader import get_dataloaders 8 | from args import args 9 | from adaptive_inference import dynamic_evaluate 10 | from op_counter import measure_model 11 | 12 | import torch 13 | import torch.optim 14 | import torch.nn as nn 15 | import torch.backends.cudnn as cudnn 16 | 17 | torch.manual_seed(args.seed) 18 | 19 | if args.gpu: 20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 21 | 22 | 23 | def main(): 24 | 25 | global args 26 | best_prec1, best_epoch = 0.0, 0 27 | 28 | if not os.path.exists(args.save): 29 | os.makedirs(args.save) 30 | 31 | if args.data.startswith('cifar'): 32 | IM_SIZE = 32 33 | else: 34 | IM_SIZE = 224 35 | 36 | print(args.arch) 37 | model = getattr(models, args.arch)(args) 38 | args.num_exits = len(model.classifier) 39 | global n_flops 40 | 41 | n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE) 42 | 43 | torch.save(n_flops, os.path.join(args.save, 'flops.pth')) 44 | del(model) 45 | 46 | print(args) 47 | with open('{}/args.txt'.format(args.save), 'w') as f: 48 | print(args, file=f) 49 | 50 | model = getattr(models, args.arch)(args) 51 | model = torch.nn.DataParallel(model.cuda()) 52 | criterion = nn.CrossEntropyLoss().cuda() 53 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 
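RANet is trained with one cross-entropy term per classifier head, summed without weighting (see the `loss += criterion(output[j], target_var)` loop in `train()` in this file). A self-contained sketch of that multi-exit objective on toy tensors (the `multi_exit_loss` name is illustrative, not a function defined in this repository):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

def multi_exit_loss(outputs, target):
    """Sum cross-entropy over every exit, mirroring main.py's train loop.

    outputs: list of logit tensors, one (batch, classes) tensor per exit.
    target:  (batch,) tensor of class indices, shared by all exits.
    """
    return sum(criterion(out, target) for out in outputs)
```

With two exits and uniform (all-zero) logits over 3 classes, each exit contributes ln(3), so the total is 2·ln(3) ≈ 2.197.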
54 | 55 | if args.resume: 56 | checkpoint = load_checkpoint(args) 57 | if checkpoint is not None: 58 | args.start_epoch = checkpoint['epoch'] + 1 59 | best_prec1 = checkpoint['best_prec1'] 60 | model.load_state_dict(checkpoint['state_dict']) 61 | optimizer.load_state_dict(checkpoint['optimizer']) 62 | 63 | cudnn.benchmark = True 64 | 65 | train_loader, val_loader, test_loader = get_dataloaders(args) 66 | 67 | if args.evalmode is not None: 68 | state_dict = torch.load(args.evaluate_from)['state_dict'] 69 | model.load_state_dict(state_dict) 70 | 71 | if args.evalmode == 'anytime': 72 | validate(test_loader, model, criterion) 73 | elif args.evalmode == 'dynamic': 74 | dynamic_evaluate(model, test_loader, val_loader, args) 75 | else: 76 | validate(test_loader, model, criterion) 77 | dynamic_evaluate(model, test_loader, val_loader, args) 78 | return 79 | 80 | scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1' 81 | '\tval_prec1\ttrain_prec5\tval_prec5'] 82 | 83 | for epoch in range(args.start_epoch, args.epochs): 84 | 85 | train_loss, train_prec1, train_prec5, lr = train(train_loader, model, criterion, optimizer, epoch) 86 | 87 | val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion) 88 | 89 | scores.append(('{}\t{:.3f}' + '\t{:.4f}' * 6) 90 | .format(epoch, lr, train_loss, val_loss, 91 | train_prec1, val_prec1, train_prec5, val_prec5)) 92 | 93 | is_best = val_prec1 > best_prec1 94 | if is_best: 95 | best_prec1 = val_prec1 96 | best_epoch = epoch 97 | print('Best val_prec1 {}'.format(best_prec1)) 98 | 99 | model_filename = 'checkpoint_%03d.pth.tar' % epoch 100 | save_checkpoint({ 101 | 'epoch': epoch, 102 | 'arch': args.arch, 103 | 'state_dict': model.state_dict(), 104 | 'best_prec1': best_prec1, 105 | 'optimizer': optimizer.state_dict(), 106 | }, args, is_best, model_filename, scores) 107 | 108 | model_path = '%s/save_models/checkpoint_%03d.pth.tar' % (args.save, epoch-1) 109 | if os.path.exists(model_path): 110 | os.remove(model_path) 111 | 112 |
print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch)) 113 | 114 | ### Test the final model 115 | print('********** Final prediction results **********') 116 | validate(test_loader, model, criterion) 117 | 118 | return 119 | 120 | def train(train_loader, model, criterion, optimizer, epoch): 121 | batch_time = AverageMeter() 122 | data_time = AverageMeter() 123 | losses = AverageMeter() 124 | top1, top5 = [], [] 125 | for i in range(args.num_exits): 126 | top1.append(AverageMeter()) 127 | top5.append(AverageMeter()) 128 | 129 | # switch to train mode 130 | model.train() 131 | 132 | end = time.time() 133 | 134 | running_lr = None 135 | for i, (input, target) in enumerate(train_loader): 136 | lr = adjust_learning_rate(optimizer, epoch, args, batch=i, 137 | nBatch=len(train_loader), method=args.lr_type) 138 | 139 | if running_lr is None: 140 | running_lr = lr 141 | 142 | data_time.update(time.time() - end) 143 | 144 | target = target.cuda(non_blocking=True) 145 | input_var = torch.autograd.Variable(input) 146 | target_var = torch.autograd.Variable(target) 147 | 148 | output = model(input_var) 149 | if not isinstance(output, list): 150 | output = [output] 151 | 152 | loss = 0.0 153 | for j in range(len(output)): 154 | loss += criterion(output[j], target_var) 155 | 156 | losses.update(loss.item(), input.size(0)) 157 | 158 | for j in range(len(output)): 159 | prec1, prec5 = accuracy(output[j].data, target, topk=(1, 5)) 160 | top1[j].update(prec1.item(), input.size(0)) 161 | top5[j].update(prec5.item(), input.size(0)) 162 | 163 | # compute gradient and do SGD step 164 | optimizer.zero_grad() 165 | loss.backward() 166 | optimizer.step() 167 | 168 | # measure elapsed time 169 | batch_time.update(time.time() - end) 170 | end = time.time() 171 | 172 | if i % args.print_freq == 0: 173 | print('Epoch: [{0}][{1}/{2}]\t' 174 | 'Time {batch_time.avg:.3f}\t' 175 | 'Data {data_time.avg:.3f}\t' 176 | 'Loss {loss.val:.4f}\t' 177 | 'Acc@1 {top1.val:.4f}\t' 178 | 
'Acc@5 {top5.val:.4f}'.format( 179 | epoch, i + 1, len(train_loader), 180 | batch_time=batch_time, data_time=data_time, 181 | loss=losses, top1=top1[-1], top5=top5[-1])) 182 | 183 | return losses.avg, top1[-1].avg, top5[-1].avg, running_lr 184 | 185 | def validate(val_loader, model, criterion): 186 | batch_time = AverageMeter() 187 | losses = AverageMeter() 188 | data_time = AverageMeter() 189 | top1, top5 = [], [] 190 | for i in range(args.num_exits): 191 | top1.append(AverageMeter()) 192 | top5.append(AverageMeter()) 193 | 194 | model.eval() 195 | 196 | end = time.time() 197 | with torch.no_grad(): 198 | for i, (input, target) in enumerate(val_loader): 199 | target = target.cuda(non_blocking=True) 200 | input = input.cuda() 201 | 202 | input_var = torch.autograd.Variable(input) 203 | target_var = torch.autograd.Variable(target) 204 | 205 | data_time.update(time.time() - end) 206 | 207 | output = model(input_var) 208 | if not isinstance(output, list): 209 | output = [output] 210 | 211 | loss = 0.0 212 | for j in range(len(output)): 213 | loss += criterion(output[j], target_var) 214 | 215 | losses.update(loss.item(), input.size(0)) 216 | 217 | for j in range(len(output)): 218 | prec1, prec5 = accuracy(output[j].data, target, topk=(1, 5)) 219 | top1[j].update(prec1.item(), input.size(0)) 220 | top5[j].update(prec5.item(), input.size(0)) 221 | 222 | # measure elapsed time 223 | batch_time.update(time.time() - end) 224 | end = time.time() 225 | 226 | if i % args.print_freq == 0: 227 | print('Epoch: [{0}/{1}]\t' 228 | 'Time {batch_time.avg:.3f}\t' 229 | 'Data {data_time.avg:.3f}\t' 230 | 'Loss {loss.val:.4f}\t' 231 | 'Acc@1 {top1.val:.4f}\t' 232 | 'Acc@5 {top5.val:.4f}'.format( 233 | i + 1, len(val_loader), 234 | batch_time=batch_time, data_time=data_time, 235 | loss=losses, top1=top1[-1], top5=top5[-1])) 236 | 237 | result_file = os.path.join(args.save, 'AnytimeResults.txt') 238 | 239 | fd = open(result_file, 'w+') 240 | fd.write('AnytimeResults' + '\n') 241 | for j 
in range(args.num_exits): 242 | test_str = (' @{ext}** flops {flops:.2f}M prec@1 {top1.avg:.3f} prec@5 {top5.avg:.3f}'.format(ext = j+1, flops=n_flops[j]/1e6, top1=top1[j], top5=top5[j])) 243 | print(test_str) 244 | fd = open(result_file, 'a+') 245 | fd.write(test_str + '\n') 246 | fd.close() 247 | torch.save([e.avg for e in top1], os.path.join(args.save, 'acc.pth')) 248 | return losses.avg, top1[-1].avg, top5[-1].avg 249 | 250 | def save_checkpoint(state, args, is_best, filename, result): 251 | print(args) 252 | result_filename = os.path.join(args.save, 'scores.tsv') 253 | model_dir = os.path.join(args.save, 'save_models') 254 | latest_filename = os.path.join(model_dir, 'latest.txt') 255 | model_filename = os.path.join(model_dir, filename) 256 | best_filename = os.path.join(model_dir, 'model_best.pth.tar') 257 | os.makedirs(args.save, exist_ok=True) 258 | os.makedirs(model_dir, exist_ok=True) 259 | print("=> saving checkpoint '{}'".format(model_filename)) 260 | 261 | torch.save(state, model_filename) 262 | 263 | with open(result_filename, 'w') as f: 264 | print('\n'.join(result), file=f) 265 | 266 | with open(latest_filename, 'w') as fout: 267 | fout.write(model_filename) 268 | 269 | if is_best: 270 | shutil.copyfile(model_filename, best_filename) 271 | best_filename_epoch = os.path.join(model_dir, 'best_model_epoch.txt') 272 | with open(best_filename_epoch, 'w') as fout: 273 | fout.write(model_filename) 274 | 275 | print("=> saved checkpoint '{}'".format(model_filename)) 276 | return 277 | 278 | def load_checkpoint(args): 279 | model_dir = os.path.join(args.save, 'save_models') 280 | latest_filename = os.path.join(model_dir, 'latest.txt') 281 | if os.path.exists(latest_filename): 282 | with open(latest_filename, 'r') as fin: 283 | model_filename = fin.readlines()[0].strip() 284 | else: 285 | return None 286 | print("=> loading checkpoint '{}'".format(model_filename)) 287 | state = torch.load(model_filename) 288 | print("=> loaded checkpoint 
'{}'".format(model_filename)) 289 | return state 290 | 291 | class AverageMeter(object): 292 | """Computes and stores the average and current value""" 293 | 294 | def __init__(self): 295 | self.reset() 296 | 297 | def reset(self): 298 | self.val = 0 299 | self.avg = 0 300 | self.sum = 0 301 | self.count = 0 302 | 303 | def update(self, val, n=1): 304 | self.val = val 305 | self.sum += val * n 306 | self.count += n 307 | self.avg = self.sum / self.count 308 | 309 | def accuracy(output, target, topk=(1,)): 310 | """Computes the precision@k for the specified values of k""" 311 | maxk = max(topk) 312 | batch_size = target.size(0) 313 | 314 | _, pred = output.topk(maxk, 1, True, True) 315 | pred = pred.t() 316 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 317 | 318 | res = [] 319 | for k in topk: 320 | correct_k = correct[:k].view(-1).float().sum(0) 321 | res.append(correct_k.mul_(100.0 / batch_size)) 322 | return res 323 | 324 | def adjust_learning_rate(optimizer, epoch, args, batch=None, 325 | nBatch=None, method='multistep'): 326 | if method == 'cosine': 327 | T_total = args.epochs * nBatch 328 | T_cur = (epoch % args.epochs) * nBatch + batch 329 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * T_cur / T_total)) 330 | elif method == 'multistep': 331 | if args.data.startswith('cifar'): 332 | lr, decay_rate = args.lr, 0.1 333 | if epoch >= args.epochs * 0.75: 334 | lr *= decay_rate ** 2 335 | elif epoch >= args.epochs * 0.5: 336 | lr *= decay_rate 337 | else: 338 | lr = args.lr * (0.1 ** (epoch // 30)) 339 | for param_group in optimizer.param_groups: 340 | param_group['lr'] = lr 341 | return lr 342 | 343 | if __name__ == '__main__': 344 | main() 345 | -------------------------------------------------------------------------------- /models/RANet.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import os 3 | import copy 4 | import math 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 |
import torch.nn.functional as F 10 | 11 | 12 | class ConvBasic(nn.Module): 13 | def __init__(self, nIn, nOut, kernel=3, stride=1, padding=1): 14 | super(ConvBasic, self).__init__() 15 | self.net = nn.Sequential( 16 | nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride, 17 | padding=padding, bias=False), 18 | nn.BatchNorm2d(nOut), 19 | nn.ReLU(True) 20 | ) 21 | 22 | def forward(self, x): 23 | return self.net(x) 24 | 25 | 26 | class ConvBN(nn.Module): 27 | def __init__(self, nIn, nOut, type: str, bnAfter, bnWidth): 28 | """ 29 | A basic conv unit in RANet with two types. 30 | :param nIn: number of input channels 31 | :param nOut: number of output channels 32 | :param type: 'normal' (stride 1) or 'down' (stride 2) 33 | :param bnAfter: if True, use conv-BN-ReLU ordering; otherwise pre-activation (BN-ReLU-conv) 34 | :param bnWidth: bottleneck width factor 35 | """ 36 | super(ConvBN, self).__init__() 37 | layer = [] 38 | nInner = nIn 39 | if bnAfter is True: 40 | nInner = min(nInner, bnWidth * nOut) 41 | layer.append(nn.Conv2d( 42 | nIn, nInner, kernel_size=1, stride=1, padding=0, bias=False)) 43 | layer.append(nn.BatchNorm2d(nInner)) 44 | layer.append(nn.ReLU(True)) 45 | if type == 'normal': 46 | layer.append(nn.Conv2d(nInner, nOut, kernel_size=3, 47 | stride=1, padding=1, bias=False)) 48 | elif type == 'down': 49 | layer.append(nn.Conv2d(nInner, nOut, kernel_size=3, 50 | stride=2, padding=1, bias=False)) 51 | else: 52 | raise ValueError 53 | layer.append(nn.BatchNorm2d(nOut)) 54 | layer.append(nn.ReLU(True)) 55 | 56 | else: 57 | nInner = min(nInner, bnWidth * nOut) 58 | layer.append(nn.BatchNorm2d(nIn)) 59 | layer.append(nn.ReLU(True)) 60 | layer.append(nn.Conv2d( 61 | nIn, nInner, kernel_size=1, stride=1, padding=0, bias=False)) 62 | layer.append(nn.BatchNorm2d(nInner)) 63 | layer.append(nn.ReLU(True)) 64 | if type == 'normal': 65 | layer.append(nn.Conv2d(nInner, nOut, kernel_size=3, 66 | stride=1, padding=1, bias=False)) 67 | elif type == 'down': 68 | layer.append(nn.Conv2d(nInner, nOut, kernel_size=3, 69 | stride=2, padding=1, bias=False)) 70 | else: 71 | raise ValueError 72 | 73 | self.net = 
nn.Sequential(*layer) 74 | 75 | def forward(self, x): 76 | return self.net(x) 77 | 78 | 79 | class ConvUpNormal(nn.Module): 80 | def __init__(self, nIn1, nIn2, nOut, bottleneck, bnWidth1, bnWidth2, compress_factor, down_sample): 81 | ''' 82 | Convolution with a normal connection plus an up-sampling connection from the lower-resolution sub-network. 83 | ''' 84 | super(ConvUpNormal, self).__init__() 85 | self.conv_up = ConvBN(nIn2, math.floor(nOut*compress_factor), 'normal', 86 | bottleneck, bnWidth2) 87 | if down_sample: 88 | self.conv_normal = ConvBN(nIn1, nOut-math.floor(nOut*compress_factor), 'down', 89 | bottleneck, bnWidth1) 90 | else: 91 | self.conv_normal = ConvBN(nIn1, nOut-math.floor(nOut*compress_factor), 'normal', 92 | bottleneck, bnWidth1) 93 | 94 | def forward(self, x): 95 | res = self.conv_normal(x[1]) 96 | _, _, h, w = res.size() 97 | res = [F.interpolate(x[1], size=(h,w), mode='bilinear', align_corners=True), 98 | F.interpolate(self.conv_up(x[0]), size=(h,w), mode='bilinear', align_corners=True), 99 | res] 100 | return torch.cat(res, dim=1) 101 | 102 | 103 | class ConvNormal(nn.Module): 104 | def __init__(self, nIn, nOut, bottleneck, bnWidth): 105 | ''' 106 | Convolution with a normal (dense) connection. 107 | ''' 108 | super(ConvNormal, self).__init__() 109 | self.conv_normal = ConvBN(nIn, nOut, 'normal', 110 | bottleneck, bnWidth) 111 | 112 | def forward(self, x): 113 | if not isinstance(x, list): 114 | x = [x] 115 | res = [x[0], self.conv_normal(x[0])] 116 | return torch.cat(res, dim=1) 117 | 118 | 119 | class _BlockNormal(nn.Module): 120 | def __init__(self, num_layers, nIn, growth_rate, reduction_rate, trans, bnFactor): 121 | ''' 122 | The basic computational block in RANet with num_layers layers. 123 | trans: If True, the block will add a transition layer at the end of the block 124 | with reduction_rate. 
125 | ''' 126 | super(_BlockNormal, self).__init__() 127 | self.layers = nn.ModuleList() 128 | self.num_layers = num_layers 129 | for i in range(num_layers): 130 | self.layers.append(ConvNormal(nIn + i*growth_rate, growth_rate, True, bnFactor)) 131 | nOut = nIn + num_layers*growth_rate 132 | self.trans_flag = trans 133 | if trans: 134 | self.trans = ConvBasic(nOut, math.floor(1.0 * reduction_rate * nOut), kernel=1, stride=1, padding=0) 135 | 136 | def forward(self, x): 137 | output = [x] 138 | for i in range(self.num_layers): 139 | x = self.layers[i](x) 140 | # print(x.size()) 141 | output.append(x) 142 | x = output[-1] 143 | if self.trans_flag: 144 | x = self.trans(x) 145 | return x, output 146 | 147 | def _blockType(self): 148 | return 'norm' 149 | 150 | 151 | class _BlockUpNormal(nn.Module): 152 | def __init__(self, num_layers, nIn, nIn_lowFtrs, growth_rate, reduction_rate, trans, down, compress_factor, bnFactor1, bnFactor2): 153 | ''' 154 | The basic fusion block in RANet with num_layers layers. 155 | trans: If True, the block will add a transition layer at the end of the block 156 | with reduction_rate. 157 | compress_factor: compress_factor*100% of the information comes from the previous 158 | sub-network. 
159 | ''' 160 | super(_BlockUpNormal, self).__init__() 161 | 162 | self.layers = nn.ModuleList() 163 | self.num_layers = num_layers 164 | for i in range(num_layers-1): 165 | self.layers.append(ConvUpNormal(nIn + i*growth_rate, nIn_lowFtrs[i], growth_rate, True, bnFactor1, bnFactor2, compress_factor, False)) 166 | 167 | self.layers.append(ConvUpNormal(nIn + (num_layers-1)*growth_rate, nIn_lowFtrs[num_layers-1], growth_rate, True, bnFactor1, bnFactor2, compress_factor, down))  # use num_layers-1, not the leaked loop index, which is undefined when num_layers == 1 168 | nOut = nIn + num_layers*growth_rate 169 | 170 | self.conv_last = ConvBasic(nIn_lowFtrs[num_layers], math.floor(nOut*compress_factor), kernel=1, stride=1, padding=0) 171 | nOut = nOut + math.floor(nOut*compress_factor) 172 | self.trans_flag = trans 173 | if trans: 174 | self.trans = ConvBasic(nOut, math.floor(1.0 * reduction_rate * nOut), kernel=1, stride=1, padding=0) 175 | 176 | def forward(self, x, low_feat): 177 | output = [x] 178 | for i in range(self.num_layers): 179 | inp = [low_feat[i]] 180 | inp.append(x) 181 | x = self.layers[i](inp) 182 | output.append(x) 183 | x = output[-1] 184 | _, _, h, w = x.size() 185 | x = [x] 186 | x.append(F.interpolate(self.conv_last(low_feat[self.num_layers]), size=(h,w), mode='bilinear', align_corners=True)) 187 | x = torch.cat(x, dim=1) 188 | if self.trans_flag: 189 | x = self.trans(x) 190 | return x, output 191 | 192 | def _blockType(self): 193 | return 'up' 194 | 195 | 196 | class RAFirstLayer(nn.Module): 197 | def __init__(self, nIn, nOut, args): 198 | ''' 199 | RAFirstLayer generates the base features for RANet. 200 | The scale 1 means the lowest resolution in the network. 
201 | ''' 202 | super(RAFirstLayer, self).__init__() 203 | _grFactor = args.grFactor[::-1] # 1-2-4 204 | _scale_list = args.scale_list[::-1] # 3-2-1 205 | self.layers = nn.ModuleList() 206 | if args.data.startswith('cifar'): 207 | self.layers.append(ConvBasic(nIn, nOut * _grFactor[0], 208 | kernel=3, stride=1, padding=1)) 209 | elif args.data == 'ImageNet': 210 | conv = nn.Sequential( 211 | nn.Conv2d(nIn, nOut * _grFactor[0], 7, 2, 3), 212 | nn.BatchNorm2d(nOut * _grFactor[0]), 213 | nn.ReLU(inplace=True), 214 | nn.MaxPool2d(3, 2, 1)) 215 | self.layers.append(conv) 216 | 217 | nIn = nOut * _grFactor[0] 218 | 219 | s = _scale_list[0] 220 | for i in range(1, args.nScales): 221 | if s == _scale_list[i]: 222 | self.layers.append(ConvBasic(nIn, nOut * _grFactor[i], 223 | kernel=3, stride=1, padding=1)) 224 | else: 225 | self.layers.append(ConvBasic(nIn, nOut * _grFactor[i], 226 | kernel=3, stride=2, padding=1)) 227 | s = _scale_list[i] 228 | nIn = nOut * _grFactor[i] 229 | 230 | def forward(self, x): 231 | # res[0] with the smallest resolutions 232 | res = [] 233 | for i in range(len(self.layers)): 234 | x = self.layers[i](x) 235 | res.append(x) 236 | return res[::-1] 237 | 238 | 239 | class RANet(nn.Module): 240 | def __init__(self, args): 241 | super(RANet, self).__init__() 242 | self.scale_flows = nn.ModuleList() 243 | self.classifier = nn.ModuleList() 244 | 245 | # self.args = args 246 | self.compress_factor = args.compress_factor 247 | self.bnFactor = copy.copy(args.bnFactor) 248 | 249 | scale_list = args.scale_list # 1-2-3 250 | self.nScales = len(args.scale_list) # 3 251 | 252 | # The number of blocks in each scale flow 253 | self.nBlocks = [0] 254 | for i in range(self.nScales): 255 | self.nBlocks.append(args.block_step*i + args.nBlocks) # [0, 2, 4, 6] 256 | 257 | # The number of layers in each block 258 | self.steps = args.step 259 | 260 | self.FirstLayer = RAFirstLayer(3, args.nChannels, args) 261 | 262 | steps = [args.step] 263 | for ii in 
range(self.nScales): 264 | 265 | scale_flow = nn.ModuleList() 266 | 267 | n_block_curr = 1 268 | nIn = args.nChannels*args.grFactor[ii] # grFactor = [4,2,1] 269 | _nIn_lowFtrs = [] 270 | 271 | for i in range(self.nBlocks[ii+1]): 272 | growth_rate = args.growthRate*args.grFactor[ii] 273 | 274 | # Whether this block ends with a transition layer 275 | trans = self._trans_flag(n_block_curr, n_block_all = self.nBlocks[ii+1], inScale = scale_list[ii]) 276 | 277 | if n_block_curr > self.nBlocks[ii]: 278 | m, nOuts = self._build_norm_block(nIn, steps[n_block_curr-1], growth_rate, args.reduction, trans, bnFactor=self.bnFactor[ii]) 279 | if args.stepmode == 'even': 280 | steps.append(args.step) 281 | elif args.stepmode == 'lg': 282 | steps.append(steps[-1]+args.step) 283 | else: 284 | raise NotImplementedError 285 | else: 286 | if n_block_curr in self.nBlocks[:ii+1][-(scale_list[ii]-1):]: 287 | m, nOuts = self._build_upNorm_block(nIn, nIn_lowFtrs[i], steps[n_block_curr-1], growth_rate, args.reduction, trans, down=True, bnFactor1=self.bnFactor[ii], bnFactor2=self.bnFactor[ii-1]) 288 | else: 289 | m, nOuts = self._build_upNorm_block(nIn, nIn_lowFtrs[i], steps[n_block_curr-1], growth_rate, args.reduction, trans, down=False, bnFactor1=self.bnFactor[ii], bnFactor2=self.bnFactor[ii-1]) 290 | 291 | nIn = nOuts[-1] 292 | scale_flow.append(m) 293 | 294 | if n_block_curr > self.nBlocks[ii]: 295 | if args.data.startswith('cifar100'): 296 | self.classifier.append( 297 | self._build_classifier_cifar(nIn, 100)) 298 | elif args.data.startswith('cifar10'): 299 | self.classifier.append(self._build_classifier_cifar(nIn, 10)) 300 | elif args.data == 'ImageNet': 301 | self.classifier.append( 302 | self._build_classifier_imagenet(nIn, 1000)) 303 | else: 304 | raise NotImplementedError 305 | 306 | _nIn_lowFtrs.append(nOuts[:-1]) 307 | n_block_curr += 1 308 | 309 | nIn_lowFtrs = _nIn_lowFtrs 310 | self.scale_flows.append(scale_flow) 311 | 312 | args.num_exits = len(self.classifier) 313 | 314 | for m in self.scale_flows: 315 | for 
_m in m.modules(): 316 | self._init_weights(_m) 317 | 318 | for m in self.classifier: 319 | for _m in m.modules(): 320 | self._init_weights(_m) 321 | 322 | def _init_weights(self, m): 323 | if isinstance(m, nn.Conv2d): 324 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 325 | m.weight.data.normal_(0, math.sqrt(2. / n)) 326 | elif isinstance(m, nn.BatchNorm2d): 327 | m.weight.data.fill_(1) 328 | m.bias.data.zero_() 329 | elif isinstance(m, nn.Linear): 330 | m.bias.data.zero_() 331 | 332 | def _build_norm_block(self, nIn, step, growth_rate, reduction_rate, trans, bnFactor=2): 333 | 334 | block = _BlockNormal(step, nIn, growth_rate, reduction_rate, trans, bnFactor=bnFactor) 335 | nOuts = [] 336 | for i in range(step+1): 337 | nOut = (nIn + i * growth_rate) 338 | nOuts.append(nOut) 339 | if trans: 340 | nOut = math.floor(1.0 * reduction_rate * nOut) 341 | nOuts.append(nOut) 342 | 343 | return block, nOuts 344 | 345 | def _build_upNorm_block(self, nIn, nIn_lowFtr, step, growth_rate, reduction_rate, trans, down, bnFactor1=1, bnFactor2=2): 346 | compress_factor = self.compress_factor 347 | 348 | block = _BlockUpNormal(step, nIn, nIn_lowFtr, growth_rate, reduction_rate, trans, down, compress_factor, bnFactor1=bnFactor1, bnFactor2=bnFactor2) 349 | nOuts = [] 350 | for i in range(step+1): 351 | nOut = (nIn + i * growth_rate) 352 | nOuts.append(nOut) 353 | nOut = nOut + math.floor(nOut*compress_factor) 354 | if trans: 355 | nOut = math.floor(1.0 * reduction_rate * nOut) 356 | nOuts.append(nOut) 357 | 358 | return block, nOuts 359 | 360 | def _trans_flag(self, n_block_curr, n_block_all, inScale): 361 | flag = False 362 | for i in range(inScale-1): 363 | if n_block_curr == math.floor((i+1)*n_block_all /inScale): 364 | flag = True 365 | return flag 366 | 367 | def forward(self, x): 368 | inp = self.FirstLayer(x) 369 | res, low_ftrs = [], [] 370 | classifier_idx = 0 371 | for ii in range(self.nScales): 372 | _x = inp[ii] 373 | _low_ftrs = [] 374 | n_block_curr = 0 375 
| for i in range(self.nBlocks[ii+1]): 376 | if self.scale_flows[ii][i]._blockType() == 'norm': 377 | _x, _low_ftr = self.scale_flows[ii][i](_x) 378 | _low_ftrs.append(_low_ftr) 379 | else: 380 | _x, _low_ftr = self.scale_flows[ii][i](_x, low_ftrs[i]) 381 | _low_ftrs.append(_low_ftr) 382 | n_block_curr += 1 383 | 384 | if n_block_curr > self.nBlocks[ii]: 385 | res.append(self.classifier[classifier_idx](_x)) 386 | classifier_idx += 1 387 | 388 | low_ftrs = _low_ftrs 389 | return res 390 | 391 | def _build_classifier_cifar(self, nIn, num_classes): 392 | interChannels1, interChannels2 = 128, 128 393 | conv = nn.Sequential( 394 | ConvBasic(nIn, interChannels1, kernel=3, stride=2, padding=1), 395 | ConvBasic(interChannels1, interChannels2, kernel=3, stride=2, padding=1), 396 | nn.AvgPool2d(2), 397 | ) 398 | return ClassifierModule(conv, interChannels2, num_classes) 399 | 400 | def _build_classifier_imagenet(self, nIn, num_classes): 401 | conv = nn.Sequential( 402 | ConvBasic(nIn, nIn, kernel=3, stride=2, padding=1), 403 | ConvBasic(nIn, nIn, kernel=3, stride=2, padding=1), 404 | nn.AvgPool2d(2) 405 | ) 406 | return ClassifierModule(conv, nIn, num_classes) 407 | 408 | class ClassifierModule(nn.Module): 409 | def __init__(self, m, channel, num_classes): 410 | super(ClassifierModule, self).__init__() 411 | self.m = m 412 | self.linear = nn.Linear(channel, num_classes) 413 | def forward(self, x): 414 | res = self.m(x) 415 | res = res.view(res.size(0), -1) 416 | return self.linear(res) 417 | 418 | 419 | if __name__ == '__main__': 420 | from args import arg_parser 421 | from op_counter import measure_model 422 | 423 | args = arg_parser.parse_args() 424 | # if args.gpu: 425 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 426 | 427 | args.nBlocks = 2 428 | args.Block_base = 2 429 | args.step = 8 430 | args.stepmode = 'even' 431 | args.compress_factor = 0.25 432 | args.nChannels = 64 433 | args.data = 'ImageNet' 434 | args.growthRate = 16 435 | 436 | args.grFactor = '4-2-2-1' 
437 | args.bnFactor = '4-2-2-1' 438 | args.scale_list = '1-2-3-4' 439 | 440 | args.reduction = 0.5 441 | 442 | args.grFactor = list(map(int, args.grFactor.split('-'))) 443 | args.bnFactor = list(map(int, args.bnFactor.split('-'))) 444 | args.scale_list = list(map(int, args.scale_list.split('-'))) 445 | args.nScales = len(args.grFactor) 446 | # print(args.grFactor) 447 | if args.use_valid: 448 | args.splits = ['train', 'val', 'test'] 449 | else: 450 | args.splits = ['train', 'val'] 451 | 452 | if args.data == 'cifar10': 453 | args.num_classes = 10 454 | elif args.data == 'cifar100': 455 | args.num_classes = 100 456 | else: 457 | args.num_classes = 1000 458 | 459 | inp_c = torch.rand(16, 3, 224, 224) 460 | 461 | model = RANet(args)  # the model class defined in this file is RANet 462 | # output = model(inp_c) 463 | # oup = net_head(inp_c) 464 | # print(len(oup)) 465 | 466 | n_flops, n_params = measure_model(model, 224, 224) 467 | # net = _BlockNormal(num_layers = 4, nIn = 64, growth_rate = 24, reduction_rate = 0.5, trans_down = True) 468 | 469 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # from .msdnet import MSDNet as msdnet 2 | # from .msdnetV5 import MSDNet as msdnetV5 3 | # from .msdnetV5_imagenet import MSDNet as msdnetV5_imagenet 4 | #from .msdnetV5_bnf import MSDNet as msdnetV5_bnf 5 | #from .msdnetV5_bnf2 import MSDNet as msdnetV5_bnf2 6 | from .RANet import RANet 7 | 8 | #from .msdnetV5_bnf_lg_ba import MSDNet as msdnetv5_ba 9 | #from .msdnetV5_bnf_lg_ba_drop import MSDNet as msdnetv5_ba_drop 10 | 11 | #from .ranet_1 import MSDNet as ranet1 12 | 13 | -------------------------------------------------------------------------------- /op_counter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import unicode_literals 3 | from __future__ import print_function 4 | from 
__future__ import division 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.autograd import Variable 9 | from functools import reduce 10 | import operator 11 | 12 | ''' 13 | Calculate the FLOPs of each exit (without lazy prediction pruning). 14 | ''' 15 | 16 | count_ops = 0 17 | count_params = 0 18 | cls_ops = [] 19 | cls_params = [] 20 | 21 | def get_num_gen(gen): 22 | return sum(1 for x in gen) 23 | 24 | 25 | def is_leaf(model): 26 | return get_num_gen(model.children()) == 0 27 | 28 | 29 | def get_layer_info(layer): 30 | layer_str = str(layer) 31 | type_name = layer_str[:layer_str.find('(')].strip() 32 | return type_name 33 | 34 | 35 | def get_layer_param(model): 36 | return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()]) 37 | 38 | 39 | ### The input batch size should be 1 to call this function 40 | def measure_layer(layer, x): 41 | global count_ops, count_params, cls_ops, cls_params 42 | delta_ops = 0 43 | delta_params = 0 44 | multi_add = 1 45 | type_name = get_layer_info(layer) 46 | 47 | ### ops_conv 48 | if type_name in ['Conv2d']: 49 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) / 50 | layer.stride[0] + 1) 51 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) / 52 | layer.stride[1] + 1) 53 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \ 54 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add 55 | delta_params = get_layer_param(layer) 56 | 57 | ### ops_nonlinearity 58 | elif type_name in ['ReLU']: 59 | delta_ops = x.numel() 60 | delta_params = get_layer_param(layer) 61 | 62 | ### ops_pooling 63 | elif type_name in ['AvgPool2d', 'MaxPool2d']: 64 | in_w = x.size()[2] 65 | kernel_ops = layer.kernel_size * layer.kernel_size 66 | out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 67 | out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1) 68 | delta_ops = x.size()[0] * x.size()[1] * 
out_w * out_h * kernel_ops 69 | delta_params = get_layer_param(layer) 70 | 71 | elif type_name in ['AdaptiveAvgPool2d']: 72 | delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] 73 | delta_params = get_layer_param(layer) 74 | 75 | ### ops_linear 76 | elif type_name in ['Linear']: 77 | weight_ops = layer.weight.numel() * multi_add 78 | bias_ops = layer.bias.numel() 79 | delta_ops = x.size()[0] * (weight_ops + bias_ops) 80 | delta_params = get_layer_param(layer) 81 | 82 | ### ops_nothing 83 | elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout', 84 | 'MSDNFirstLayer', 'ConvBasic', 'ConvBN', 85 | 'ParallelModule', 'MSDNet', 'Sequential', 86 | 'MSDNLayer', 'ConvDownNormal', 'ConvNormal', 'ClassifierModule']: 87 | delta_params = get_layer_param(layer) 88 | 89 | 90 | ### unknown layer type 91 | else: 92 | raise TypeError('unknown layer type: %s' % type_name) 93 | 94 | count_ops += delta_ops 95 | count_params += delta_params 96 | if type_name == 'Linear': 97 | print('---------------------') 98 | print('FLOPs: %.2fM, Params: %.2fM' % (count_ops / 1e6, count_params / 1e6)) 99 | cls_ops.append(count_ops) 100 | cls_params.append(count_params) 101 | 102 | return 103 | 104 | 105 | def measure_model(model, H, W): 106 | global count_ops, count_params, cls_ops, cls_params 107 | count_ops = 0 108 | count_params = 0 109 | data = Variable(torch.zeros(1, 3, H, W)) 110 | 111 | def should_measure(x): 112 | return is_leaf(x) 113 | 114 | def modify_forward(model): 115 | for child in model.children(): 116 | if should_measure(child): 117 | def new_forward(m): 118 | def lambda_forward(x): 119 | measure_layer(m, x) 120 | return m.old_forward(x) 121 | return lambda_forward 122 | child.old_forward = child.forward 123 | child.forward = new_forward(child) 124 | else: 125 | modify_forward(child) 126 | 127 | def restore_forward(model): 128 | for child in model.children(): 129 | # leaf node 130 | if is_leaf(child) and hasattr(child, 'old_forward'): 131 | 
child.forward = child.old_forward 132 | child.old_forward = None 133 | else: 134 | restore_forward(child) 135 | 136 | model.eval() 137 | modify_forward(model) 138 | model.forward(data) 139 | restore_forward(model) 140 | return cls_ops, cls_params 141 | -------------------------------------------------------------------------------- /train_cifar.sh: -------------------------------------------------------------------------------- 1 | python main.py --arch RANet --gpu '0' --data-root {your data root} --data 'cifar10' --step 4 --stepmode 'even' --scale-list '1-2-3-3' --grFactor '4-2-1-1' --bnFactor '4-2-1-1' -------------------------------------------------------------------------------- /train_imagenet.sh: -------------------------------------------------------------------------------- 1 | python main.py --arch RANet --gpu '0,1,2,3' --data-root {your data root} --data 'ImageNet' --growthRate 16 --step 8 --stepmode 'even' --scale-list '1-2-3-4' --grFactor '4-2-1-1' --bnFactor '4-2-1-1' --------------------------------------------------------------------------------
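
The scripts above train anytime models; at test time RANet attaches a classifier to every exit, and the budgeted-inference logic (implemented in `adaptive_inference.py`) stops at the first exit whose prediction is confident enough. A minimal sketch of that first-confident-exit rule in plain Python — `pick_exit` and the sample probabilities below are illustrative, not the repository's actual API:

```python
def pick_exit(exit_probs, threshold):
    """Return (exit_index, predicted_class) for the first exit whose
    top-1 probability reaches the confidence threshold; fall back to
    the final exit if no earlier exit is confident enough."""
    for k, probs in enumerate(exit_probs):
        top1 = max(probs)
        if top1 >= threshold:
            return k, probs.index(top1)
    last = exit_probs[-1]
    return len(exit_probs) - 1, last.index(max(last))

# An "easy" sample: the first (low-resolution) exit is already confident.
easy = [[0.95, 0.03, 0.02], [0.97, 0.02, 0.01]]
# A "hard" sample: no exit clears the threshold, so the last one decides.
hard = [[0.40, 0.35, 0.25], [0.30, 0.60, 0.10]]

print(pick_exit(easy, 0.9))  # (0, 0) -- exits early
print(pick_exit(hard, 0.9))  # (1, 1) -- falls through to the last exit
```

Lowering the threshold makes more samples exit early (fewer average FLOPs, possibly lower accuracy); in the paper's budgeted setting the thresholds are chosen on a validation set to meet a target compute budget.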