├── .gitignore
├── README.md
├── adaptive_inference.py
├── args.py
├── dataloader.py
├── imgs
│   ├── RANet_overview.png
│   ├── anytime_results.png
│   └── dynamic_results.png
├── main.py
├── models
│   ├── RANet.py
│   └── __init__.py
├── op_counter.py
├── train_cifar.sh
└── train_imagenet.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Resolution Adaptive Networks for Efficient Inference (CVPR2020)
2 | [Le Yang*](https://github.com/yangle15), [Yizeng Han*](https://github.com/thuallen), [Xi Chen*](https://github.com/FateDawnLeon), Shiji Song, [Jifeng Dai](https://github.com/daijifeng001), [Gao Huang](https://github.com/gaohuang)
3 |
4 | This repository contains the implementation of the paper '[Resolution Adaptive Networks for Efficient Inference](https://arxiv.org/pdf/2003.07326.pdf)'. The proposed Resolution Adaptive Network (RANet) conducts adaptive inference by exploiting the ``spatial redundancy`` of input images. Our motivation is that low-resolution representations are sufficient for classifying easy samples containing large objects with prototypical features, while only some hard samples need spatially detailed information, as illustrated in the figure below.
5 |
6 | 
7 |
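8 | At test time this amounts to a simple rule, sketched below (a minimal illustration; adaptive_inference.py implements the full budgeted version): exits are visited from cheap to expensive, and a sample stops as soon as its softmax confidence clears that exit's threshold.
9 |
10 | ```python
11 | import torch
12 |
13 | def early_exit_predict(exit_logits, thresholds):
14 |     """Illustrative only. exit_logits: list of per-exit logits, each of shape (1, C);
15 |     thresholds: one confidence threshold per exit."""
16 |     for k, logits in enumerate(exit_logits):
17 |         conf, pred = torch.softmax(logits, dim=1).max(dim=1)
18 |         if conf.item() >= thresholds[k] or k == len(exit_logits) - 1:
19 |             return pred.item(), k  # predicted class and the exit that produced it
20 | ```
21 |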
8 | ## Results
9 |
10 | 
11 |
12 | Accuracy (top-1) of anytime prediction models as a function of computational budget on the CIFAR-10 (left), CIFAR-100
13 | (middle) and ImageNet (right) datasets. Higher is better.
14 |
15 | 
16 |
17 | Accuracy (top-1) of budgeted batch classification models as a function of average computational budget per image on the
18 | CIFAR-10 (left), CIFAR-100 (middle) and ImageNet (right) datasets. Higher is better.
19 |
20 | ## Dependencies:
21 |
22 | * Python3
23 |
24 | * PyTorch >= 1.0
25 |
26 | ## Usage
27 | We provide shell scripts for training a RANet on CIFAR and ImageNet.
28 |
29 | ### Train a RANet on CIFAR
30 | * Modify train_cifar.sh to configure the path to your dataset, your GPU devices and your saving directory. Then run
31 | ```sh
32 | bash train_cifar.sh
33 | ```
34 |
35 | * You can train your RANet with other configurations.
36 | ```sh
37 | python main.py --arch RANet --gpu '0' --data-root YOUR_DATA_PATH --data 'cifar10' --step 2 --nChannels 16 --stepmode 'lg' --scale-list '1-2-3' --grFactor '4-2-1' --bnFactor '4-2-1'
38 | ```
39 |
40 | ### Train a RANet on ImageNet
41 | * Modify train_imagenet.sh to configure the path to your dataset, your GPU devices and your saving directory. Then run
42 | ```sh
43 | bash train_imagenet.sh
44 | ```
45 |
46 | * You can train your RANet with other configurations.
47 | ```sh
48 | python main.py --arch RANet --gpu '0,1,2,3' --data-root YOUR_DATA_PATH --data 'ImageNet' --step 8 --growthRate 16 --nChannels 32 --stepmode 'even' --scale-list '1-2-3-4' --grFactor '4-2-2-1' --bnFactor '4-2-2-1'
49 | ```
50 |
51 |
52 |
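53 | ### Evaluate a trained RANet
54 | Checkpoints are written to `YOUR_SAVE_PATH/save_models` during training. Below is a minimal sketch of evaluation commands, based on the `--evalmode` and `--evaluate-from` flags in args.py (append the same architecture flags used at training time so the checkpoint matches the rebuilt model; dynamic evaluation also relies on the validation split created by `--use-valid`):
55 | ```sh
56 | # anytime prediction: report the accuracy of every exit
57 | python main.py --arch RANet --gpu '0' --data-root YOUR_DATA_PATH --data 'cifar10' --evalmode anytime --evaluate-from YOUR_SAVE_PATH/save_models/model_best.pth.tar
58 | # budgeted batch classification: exit thresholds are tuned on the validation split
59 | python main.py --arch RANet --gpu '0' --data-root YOUR_DATA_PATH --data 'cifar10' --evalmode dynamic --evaluate-from YOUR_SAVE_PATH/save_models/model_best.pth.tar --use-valid
60 | ```
61 |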
53 | ### Citation
54 | If you find this work useful or use our code in your own research, please use the following bibtex:
55 | ```
56 | @inproceedings{yang2020resolution,
57 | title={Resolution Adaptive Networks for Efficient Inference},
58 | author={Yang, Le and Han, Yizeng and Chen, Xi and Song, Shiji and Dai, Jifeng and Huang, Gao},
59 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
60 | year={2020}
61 | }
62 | ```
63 |
64 | ### Contact
65 | If you have any questions, please feel free to contact the authors.
66 |
67 | Le Yang: yangle15@mails.tsinghua.edu.cn
68 |
69 | Yizeng Han: [hanyz18@mails.tsinghua.edu.cn](mailto:hanyz18@mails.tsinghua.edu.cn)
70 |
71 | ### Acknowledgments
72 | We use the PyTorch implementation of MSDNet in our experiments. The code can be found [here](https://github.com/kalviny/MSDNet-PyTorch).
73 |
74 |
75 |
76 |
--------------------------------------------------------------------------------
/adaptive_inference.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import os
7 | import math
8 | import torch
9 | import torch.nn as nn
10 |
11 | def dynamic_evaluate(model, test_loader, val_loader, args):
12 | tester = Tester(model, args)
13 | if os.path.exists(os.path.join(args.save, 'logits_single.pth')):
14 | val_pred, val_target, test_pred, test_target = \
15 | torch.load(os.path.join(args.save, 'logits_single.pth'))
16 | else:
17 | val_pred, val_target = tester.calc_logit(val_loader)
18 | test_pred, test_target = tester.calc_logit(test_loader)
19 | torch.save((val_pred, val_target, test_pred, test_target),
20 | os.path.join(args.save, 'logits_single.pth'))
21 |
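22 |     # flops.pth is saved by main.py via op_counter.measure_model; flops[k] is the
23 |     # cumulative cost of computing the network up to exit k.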
22 | flops = torch.load(os.path.join(args.save, 'flops.pth'))
23 |
24 | acc_list, exp_flops_list = [], []
25 | with open(os.path.join(args.save, 'dynamic.txt'), 'w') as fout:
26 | samples = {}
27 | for p in range(1, 40):
28 | print("*********************")
29 |             # Each p defines a geometric distribution over the exits: the expected
30 |             # fraction of samples exiting at stage k is proportional to (p/20)^k.
31 |             _p = torch.FloatTensor(1).fill_(p * 1.0 / 20)
32 |             probs = torch.exp(torch.log(_p) * torch.arange(1, args.num_exits + 1, dtype=torch.float))
33 |             probs /= probs.sum()
32 | acc_val, _, T = tester.dynamic_eval_find_threshold(
33 | val_pred, val_target, probs, flops)
34 | acc_test, exp_flops, exit_buckets = tester.dynamic_eval_with_threshold(
35 | test_pred, test_target, flops, T)
36 | print('valid acc: {:.3f}, test acc: {:.3f}, test flops: {:.2f}M'.format(acc_val, acc_test, exp_flops / 1e6))
37 | fout.write('{}\t{}\n'.format(acc_test, exp_flops.item()))
38 | acc_list.append(acc_test)
39 | exp_flops_list.append(exp_flops)
40 | samples[p] = exit_buckets
41 | torch.save([exp_flops_list, acc_list], os.path.join(args.save, 'dynamic.pth'))
42 | torch.save(samples, os.path.join(args.save, 'exit_samples_by_p.pth'))
43 | # return acc_list, exp_flops_list
44 |
45 |
46 | class Tester(object):
47 | def __init__(self, model, args=None):
48 | self.args = args
49 | self.model = model
50 | self.softmax = nn.Softmax(dim=1).cuda()
51 |
52 | def calc_logit(self, dataloader):
53 | self.model.eval()
54 | n_stage = self.args.num_exits
55 | logits = [[] for _ in range(n_stage)]
56 | targets = []
57 | for i, (input, target) in enumerate(dataloader):
58 | targets.append(target)
59 | with torch.no_grad():
60 | input_var = torch.autograd.Variable(input)
61 | output = self.model(input_var)
62 | if not isinstance(output, list):
63 | output = [output]
64 | for b in range(n_stage):
65 | _t = self.softmax(output[b])
66 |
67 | logits[b].append(_t)
68 |
69 | if i % self.args.print_freq == 0:
70 | print('Generate Logit: [{0}/{1}]'.format(i, len(dataloader)))
71 |
72 | for b in range(n_stage):
73 | logits[b] = torch.cat(logits[b], dim=0)
74 |
75 | size = (n_stage, logits[0].size(0), logits[0].size(1))
76 | ts_logits = torch.Tensor().resize_(size).zero_()
77 | for b in range(n_stage):
78 | ts_logits[b].copy_(logits[b])
79 |
80 | targets = torch.cat(targets, dim=0)
81 | ts_targets = torch.Tensor().resize_(size[1]).copy_(targets)
82 |
83 | return ts_logits, ts_targets
84 |
85 | def dynamic_eval_find_threshold(self, logits, targets, p, flops):
86 | """
87 | logits: m * n * c
88 | m: Stages
89 | n: Samples
90 | c: Classes
91 | """
92 | n_stage, n_sample, c = logits.size()
93 |
94 | max_preds, argmax_preds = logits.max(dim=2, keepdim=False)
95 |
96 | _, sorted_idx = max_preds.sort(dim=1, descending=True)
97 |
98 | filtered = torch.zeros(n_sample)
99 | T = torch.Tensor(n_stage).fill_(1e8)
100 |
101 | for k in range(n_stage - 1):
102 | acc, count = 0.0, 0
103 | out_n = math.floor(n_sample * p[k])
104 | for i in range(n_sample):
105 | ori_idx = sorted_idx[k][i]
106 | if filtered[ori_idx] == 0:
107 | count += 1
108 | if count == out_n:
109 | T[k] = max_preds[k][ori_idx]
110 | break
111 | filtered.add_(max_preds[k].ge(T[k]).type_as(filtered))
112 |
113 |         T[n_stage - 1] = -1e8  # accept all of the samples at the last stage
114 |
115 | acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage)
116 | acc, expected_flops = 0, 0
117 | for i in range(n_sample):
118 | gold_label = targets[i]
119 | for k in range(n_stage):
120 | if max_preds[k][i].item() >= T[k]: # force the sample to exit at k
121 | if int(gold_label.item()) == int(argmax_preds[k][i].item()):
122 | acc += 1
123 | acc_rec[k] += 1
124 | exp[k] += 1
125 | break
126 | acc_all = 0
127 | for k in range(n_stage):
128 | _t = 1.0 * exp[k] / n_sample
129 | expected_flops += _t * flops[k]
130 | acc_all += acc_rec[k]
131 |
132 | return acc * 100.0 / n_sample, expected_flops, T
133 |
134 | def dynamic_eval_with_threshold(self, logits, targets, flops, T):
135 | n_stage, n_sample, _ = logits.size()
136 | max_preds, argmax_preds = logits.max(dim=2, keepdim=False) # take the max logits as confidence
137 |
138 |         exit_buckets = {i: {j: [] for j in range(n_stage)} for i in range(1000)}  # for each class and exit, a bucket tracks the samples exiting there (1000 covers up to ImageNet's class count)
139 |
140 | acc_rec, exp = torch.zeros(n_stage), torch.zeros(n_stage)
141 | acc, expected_flops = 0, 0
142 | for i in range(n_sample):
143 | gold_label = targets[i]
144 | for k in range(n_stage):
145 | if max_preds[k][i].item() >= T[k]: # force to exit at k
146 | _g = int(gold_label.item())
147 | _pred = int(argmax_preds[k][i].item())
148 | if _g == _pred:
149 | acc += 1
150 | acc_rec[k] += 1
151 | exp[k] += 1
152 | exit_buckets[int(gold_label)][k].append(i)
153 | break
154 |
155 | acc_all, sample_all = 0, 0
156 | for k in range(n_stage):
157 | _t = exp[k] * 1.0 / n_sample
158 | sample_all += exp[k]
159 | expected_flops += _t * flops[k]
160 | acc_all += acc_rec[k]
161 |
162 | return acc * 100.0 / n_sample, expected_flops, exit_buckets
163 |
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import time
4 | import argparse
5 |
6 | model_names = ['RANet']
7 |
8 | arg_parser = argparse.ArgumentParser(description='RANet Image classification')
9 |
10 | exp_group = arg_parser.add_argument_group('exp', 'experiment setting')
11 | exp_group.add_argument('--save', default='save/default-{}'.format(time.time()),
12 |                        type=str, metavar='SAVE',
13 |                        help='path to the experiment logging directory '
14 |                        '(default: save/default-<timestamp>)')
15 | exp_group.add_argument('--resume', action='store_true', default=None,
16 |                        help='resume from the latest checkpoint in the save directory (default: off)')
17 | exp_group.add_argument('--evalmode', default=None,
18 | choices=['anytime', 'dynamic', 'both'],
19 | help='which mode to evaluate')
20 | exp_group.add_argument('--evaluate-from', default='', type=str, metavar='PATH',
21 | help='path to saved checkpoint (default: none)')
22 | exp_group.add_argument('--print-freq', '-p', default=10, type=int,
23 |                        metavar='N', help='print frequency (default: 10)')
24 | exp_group.add_argument('--seed', default=0, type=int,
25 | help='random seed')
26 | exp_group.add_argument('--gpu', default='0', type=str, help='GPU available.')
27 |
28 | # dataset related
29 | data_group = arg_parser.add_argument_group('data', 'dataset setting')
30 | data_group.add_argument('--data', metavar='D', default='cifar10',
31 | choices=['cifar10', 'cifar100', 'ImageNet'],
32 | help='data to work on')
33 | data_group.add_argument('--data-root', metavar='DIR', default='/data/cx/data',
34 |                         help='path to dataset (default: /data/cx/data)')
35 | data_group.add_argument('--use-valid', action='store_true', default=False,
36 | help='use validation set or not')
37 | data_group.add_argument('-j', '--workers', default=4, type=int, metavar='N',
38 | help='number of data loading workers (default: 4)')
39 |
40 | # model arch related
41 | arch_group = arg_parser.add_argument_group('arch', 'model architecture setting')
42 | arch_group.add_argument('--arch', type=str, default='RANet')
43 | arch_group.add_argument('--reduction', default=0.5, type=float,
44 |                         metavar='C', help='compression ratio of DenseNet'
45 |                         ' (1 means don\'t use compression) (default: 0.5)')
46 |
47 | # msdnet config
48 | arch_group.add_argument('--nBlocks', type=int, default=2)
49 | arch_group.add_argument('--nChannels', type=int, default=16)
50 | arch_group.add_argument('--growthRate', type=int, default=6)
51 | arch_group.add_argument('--grFactor', default='4-2-1', type=str)
52 | arch_group.add_argument('--bnFactor', default='4-2-1', type=str)
53 | arch_group.add_argument('--block-step', type=int, default=2)
54 | arch_group.add_argument('--scale-list', default='1-2-3', type=str)
55 | arch_group.add_argument('--compress-factor', default=0.25, type=float)
56 | arch_group.add_argument('--step', type=int, default=4)
57 | arch_group.add_argument('--stepmode', type=str, default='even', choices=['even', 'lg'])
58 | arch_group.add_argument('--bnAfter', action='store_true', default=True)
59 |
60 |
61 | # training related
62 | optim_group = arg_parser.add_argument_group('optimization', 'optimization setting')
63 | optim_group.add_argument('--epochs', default=300, type=int, metavar='N',
64 | help='number of total epochs to run (default: 300)')
65 | optim_group.add_argument('--start-epoch', default=0, type=int, metavar='N',
66 | help='manual epoch number (useful on restarts)')
67 | optim_group.add_argument('-b', '--batch-size', default=64, type=int,
68 | metavar='N', help='mini-batch size (default: 64)')
69 | optim_group.add_argument('--optimizer', default='sgd',
70 | choices=['sgd', 'rmsprop', 'adam'], metavar='N',
71 | help='optimizer (default=sgd)')
72 | optim_group.add_argument('--lr', '--learning-rate', default=0.1, type=float,
73 | metavar='LR',
74 | help='initial learning rate (default: 0.1)')
75 | optim_group.add_argument('--lr-type', default='multistep', type=str, metavar='T',
76 | help='learning rate strategy (default: multistep)',
77 | choices=['cosine', 'multistep'])
78 | optim_group.add_argument('--decay-rate', default=0.1, type=float, metavar='N',
79 | help='decay rate of learning rate (default: 0.1)')
80 | optim_group.add_argument('--momentum', default=0.9, type=float, metavar='M',
81 | help='momentum (default=0.9)')
82 | optim_group.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
83 | metavar='W', help='weight decay (default: 1e-4)')
84 |
85 | args = arg_parser.parse_args()
86 |
87 | args.grFactor = list(map(int, args.grFactor.split('-')))
88 | args.bnFactor = list(map(int, args.bnFactor.split('-')))
89 | args.scale_list = list(map(int, args.scale_list.split('-')))
90 | args.nScales = len(args.grFactor)
91 |
92 | if args.use_valid:
93 | args.splits = ['train', 'val', 'test']
94 | else:
95 | args.splits = ['train', 'val']
96 |
97 | if args.data == 'cifar10':
98 | args.num_classes = 10
99 | elif args.data == 'cifar100':
100 | args.num_classes = 100
101 | else:
102 | args.num_classes = 1000
103 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as datasets
5 |
6 |
7 | def get_dataloaders(args):
8 | train_loader, val_loader, test_loader = None, None, None
9 | if args.data == 'cifar10':
10 | normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467],
11 | std=[0.2471, 0.2435, 0.2616])
12 | train_set = datasets.CIFAR10(args.data_root, train=True,
13 | transform=transforms.Compose([
14 | transforms.RandomCrop(32, padding=4),
15 | transforms.RandomHorizontalFlip(),
16 | transforms.ToTensor(),
17 | normalize
18 | ]))
19 | val_set = datasets.CIFAR10(args.data_root, train=False,
20 | transform=transforms.Compose([
21 | transforms.ToTensor(),
22 | normalize
23 | ]))
24 | elif args.data == 'cifar100':
25 | normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
26 | std=[0.2675, 0.2565, 0.2761])
27 | train_set = datasets.CIFAR100(args.data_root, train=True,
28 | transform=transforms.Compose([
29 | transforms.RandomCrop(32, padding=4),
30 | transforms.RandomHorizontalFlip(),
31 | transforms.ToTensor(),
32 | normalize
33 | ]))
34 | val_set = datasets.CIFAR100(args.data_root, train=False,
35 | transform=transforms.Compose([
36 | transforms.ToTensor(),
37 | normalize
38 | ]))
39 | else:
40 | # ImageNet
41 | traindir = os.path.join(args.data_root, 'train')
42 | valdir = os.path.join(args.data_root, 'val')
43 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
44 | std=[0.229, 0.224, 0.225])
45 | train_set = datasets.ImageFolder(traindir, transforms.Compose([
46 | transforms.RandomResizedCrop(224),
47 | transforms.RandomHorizontalFlip(),
48 | transforms.ToTensor(),
49 | normalize
50 | ]))
51 | val_set = datasets.ImageFolder(valdir, transforms.Compose([
52 | transforms.Resize(256),
53 | transforms.CenterCrop(224),
54 | transforms.ToTensor(),
55 | normalize
56 | ]))
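57 |     # With --use-valid, a held-out chunk of the training set becomes the validation
58 |     # split (e.g. for tuning exit thresholds) and val_set above serves as the test set.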
57 | if args.use_valid:
58 | train_set_index = torch.randperm(len(train_set))
59 | if os.path.exists(os.path.join(args.save, 'index.pth')):
60 | print('!!!!!! Load train_set_index !!!!!!')
61 | train_set_index = torch.load(os.path.join(args.save, 'index.pth'))
62 | else:
63 | print('!!!!!! Save train_set_index !!!!!!')
64 | torch.save(train_set_index, os.path.join(args.save, 'index.pth'))
65 | if args.data.startswith('cifar'):
66 | num_sample_valid = 5000
67 | else:
68 | num_sample_valid = 50000
69 |
70 | if 'train' in args.splits:
71 | train_loader = torch.utils.data.DataLoader(
72 | train_set, batch_size=args.batch_size,
73 | sampler=torch.utils.data.sampler.SubsetRandomSampler(
74 | train_set_index[:-num_sample_valid]),
75 | num_workers=args.workers, pin_memory=False)
76 | if 'val' in args.splits:
77 | val_loader = torch.utils.data.DataLoader(
78 | train_set, batch_size=args.batch_size,
79 | sampler=torch.utils.data.sampler.SubsetRandomSampler(
80 | train_set_index[-num_sample_valid:]),
81 | num_workers=args.workers, pin_memory=False)
82 | if 'test' in args.splits:
83 | test_loader = torch.utils.data.DataLoader(
84 | val_set,
85 | batch_size=args.batch_size, shuffle=False,
86 | num_workers=args.workers, pin_memory=False)
87 | else:
88 | if 'train' in args.splits:
89 | train_loader = torch.utils.data.DataLoader(
90 | train_set,
91 | batch_size=args.batch_size, shuffle=True,
92 | num_workers=args.workers, pin_memory=False)
93 |         if 'val' in args.splits or 'test' in args.splits:
94 | val_loader = torch.utils.data.DataLoader(
95 | val_set,
96 | batch_size=args.batch_size, shuffle=False,
97 | num_workers=args.workers, pin_memory=False)
98 | test_loader = val_loader
99 |
100 | return train_loader, val_loader, test_loader
101 |
--------------------------------------------------------------------------------
/imgs/RANet_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangle15/RANet-pytorch/be0f25a2286160bc612181507cb44a3ff2cd0e46/imgs/RANet_overview.png
--------------------------------------------------------------------------------
/imgs/anytime_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangle15/RANet-pytorch/be0f25a2286160bc612181507cb44a3ff2cd0e46/imgs/anytime_results.png
--------------------------------------------------------------------------------
/imgs/dynamic_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yangle15/RANet-pytorch/be0f25a2286160bc612181507cb44a3ff2cd0e46/imgs/dynamic_results.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import shutil
5 | import models
6 |
7 | from dataloader import get_dataloaders
8 | from args import args
9 | from adaptive_inference import dynamic_evaluate
10 | from op_counter import measure_model
11 |
12 | import torch
13 | import torch.optim
14 | import torch.nn as nn
15 | import torch.backends.cudnn as cudnn
16 |
17 | torch.manual_seed(args.seed)
18 |
19 | if args.gpu:
20 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
21 |
22 |
23 | def main():
24 |
25 | global args
26 | best_prec1, best_epoch = 0.0, 0
27 |
28 | if not os.path.exists(args.save):
29 | os.makedirs(args.save)
30 |
31 | if args.data.startswith('cifar'):
32 | IM_SIZE = 32
33 | else:
34 | IM_SIZE = 224
35 |
36 | print(args.arch)
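37 |     # Build a disposable copy of the model just to count per-exit FLOPs
38 |     # (measure_model patches each module's forward); a clean model is rebuilt below.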
37 | model = getattr(models, args.arch)(args)
38 | args.num_exits = len(model.classifier)
39 | global n_flops
40 |
41 | n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)
42 |
43 | torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
44 |     del model
45 |
46 | print(args)
47 | with open('{}/args.txt'.format(args.save), 'w') as f:
48 | print(args, file=f)
49 |
50 | model = getattr(models, args.arch)(args)
51 | model = torch.nn.DataParallel(model.cuda())
52 | criterion = nn.CrossEntropyLoss().cuda()
53 | optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
54 |
55 | if args.resume:
56 | checkpoint = load_checkpoint(args)
57 | if checkpoint is not None:
58 | args.start_epoch = checkpoint['epoch'] + 1
59 | best_prec1 = checkpoint['best_prec1']
60 | model.load_state_dict(checkpoint['state_dict'])
61 | optimizer.load_state_dict(checkpoint['optimizer'])
62 |
63 | cudnn.benchmark = True
64 |
65 | train_loader, val_loader, test_loader = get_dataloaders(args)
66 |
67 | if args.evalmode is not None:
68 | state_dict = torch.load(args.evaluate_from)['state_dict']
69 | model.load_state_dict(state_dict)
70 |
71 | if args.evalmode == 'anytime':
72 | validate(test_loader, model, criterion)
73 | elif args.evalmode == 'dynamic':
74 | dynamic_evaluate(model, test_loader, val_loader, args)
75 | else:
76 | validate(test_loader, model, criterion)
77 | dynamic_evaluate(model, test_loader, val_loader, args)
78 | return
79 |
80 | scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
81 | '\tval_prec1\ttrain_prec5\tval_prec5']
82 |
83 | for epoch in range(args.start_epoch, args.epochs):
84 |
85 | train_loss, train_prec1, train_prec5, lr = train(train_loader, model, criterion, optimizer, epoch)
86 |
87 | val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)
88 |
89 | scores.append(('{}\t{:.3f}' + '\t{:.4f}' * 6)
90 | .format(epoch, lr, train_loss, val_loss,
91 | train_prec1, val_prec1, train_prec5, val_prec5))
92 |
93 | is_best = val_prec1 > best_prec1
94 | if is_best:
95 | best_prec1 = val_prec1
96 | best_epoch = epoch
97 |             print('Best val_prec1 {}'.format(best_prec1))
98 |
99 | model_filename = 'checkpoint_%03d.pth.tar' % epoch
100 | save_checkpoint({
101 | 'epoch': epoch,
102 | 'arch': args.arch,
103 | 'state_dict': model.state_dict(),
104 | 'best_prec1': best_prec1,
105 | 'optimizer': optimizer.state_dict(),
106 | }, args, is_best, model_filename, scores)
107 |
108 | model_path = '%s/save_models/checkpoint_%03d.pth.tar' % (args.save, epoch-1)
109 | if os.path.exists(model_path):
110 | os.remove(model_path)
111 |
112 | print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))
113 |
114 | ### Test the final model
115 | print('********** Final prediction results **********')
116 | validate(test_loader, model, criterion)
117 |
118 | return
119 |
120 | def train(train_loader, model, criterion, optimizer, epoch):
121 | batch_time = AverageMeter()
122 | data_time = AverageMeter()
123 | losses = AverageMeter()
124 | top1, top5 = [], []
125 | for i in range(args.num_exits):
126 | top1.append(AverageMeter())
127 | top5.append(AverageMeter())
128 |
129 | # switch to train mode
130 | model.train()
131 |
132 | end = time.time()
133 |
134 | running_lr = None
135 | for i, (input, target) in enumerate(train_loader):
136 | lr = adjust_learning_rate(optimizer, epoch, args, batch=i,
137 | nBatch=len(train_loader), method=args.lr_type)
138 |
139 | if running_lr is None:
140 | running_lr = lr
141 |
142 | data_time.update(time.time() - end)
143 |
144 | target = target.cuda(non_blocking=True)
145 | input_var = torch.autograd.Variable(input)
146 | target_var = torch.autograd.Variable(target)
147 |
148 | output = model(input_var)
149 | if not isinstance(output, list):
150 | output = [output]
151 |
152 | loss = 0.0
153 | for j in range(len(output)):
154 | loss += criterion(output[j], target_var)
155 |
156 | losses.update(loss.item(), input.size(0))
157 |
158 | for j in range(len(output)):
159 | prec1, prec5 = accuracy(output[j].data, target, topk=(1, 5))
160 | top1[j].update(prec1.item(), input.size(0))
161 | top5[j].update(prec5.item(), input.size(0))
162 |
163 | # compute gradient and do SGD step
164 | optimizer.zero_grad()
165 | loss.backward()
166 | optimizer.step()
167 |
168 | # measure elapsed time
169 | batch_time.update(time.time() - end)
170 | end = time.time()
171 |
172 | if i % args.print_freq == 0:
173 | print('Epoch: [{0}][{1}/{2}]\t'
174 | 'Time {batch_time.avg:.3f}\t'
175 | 'Data {data_time.avg:.3f}\t'
176 | 'Loss {loss.val:.4f}\t'
177 | 'Acc@1 {top1.val:.4f}\t'
178 | 'Acc@5 {top5.val:.4f}'.format(
179 | epoch, i + 1, len(train_loader),
180 | batch_time=batch_time, data_time=data_time,
181 | loss=losses, top1=top1[-1], top5=top5[-1]))
182 |
183 | return losses.avg, top1[-1].avg, top5[-1].avg, running_lr
184 |
185 | def validate(val_loader, model, criterion):
186 | batch_time = AverageMeter()
187 | losses = AverageMeter()
188 | data_time = AverageMeter()
189 | top1, top5 = [], []
190 | for i in range(args.num_exits):
191 | top1.append(AverageMeter())
192 | top5.append(AverageMeter())
193 |
194 | model.eval()
195 |
196 | end = time.time()
197 | with torch.no_grad():
198 | for i, (input, target) in enumerate(val_loader):
199 | target = target.cuda(non_blocking=True)
200 | input = input.cuda()
201 |
202 | input_var = torch.autograd.Variable(input)
203 | target_var = torch.autograd.Variable(target)
204 |
205 | data_time.update(time.time() - end)
206 |
207 | output = model(input_var)
208 | if not isinstance(output, list):
209 | output = [output]
210 |
211 | loss = 0.0
212 | for j in range(len(output)):
213 | loss += criterion(output[j], target_var)
214 |
215 | losses.update(loss.item(), input.size(0))
216 |
217 | for j in range(len(output)):
218 | prec1, prec5 = accuracy(output[j].data, target, topk=(1, 5))
219 | top1[j].update(prec1.item(), input.size(0))
220 | top5[j].update(prec5.item(), input.size(0))
221 |
222 | # measure elapsed time
223 | batch_time.update(time.time() - end)
224 | end = time.time()
225 |
226 | if i % args.print_freq == 0:
227 |                 print('Test: [{0}/{1}]\t'
228 | 'Time {batch_time.avg:.3f}\t'
229 | 'Data {data_time.avg:.3f}\t'
230 | 'Loss {loss.val:.4f}\t'
231 | 'Acc@1 {top1.val:.4f}\t'
232 | 'Acc@5 {top5.val:.4f}'.format(
233 | i + 1, len(val_loader),
234 | batch_time=batch_time, data_time=data_time,
235 | loss=losses, top1=top1[-1], top5=top5[-1]))
236 |
237 | result_file = os.path.join(args.save, 'AnytimeResults.txt')
238 |
239 |     # write the per-exit results once, with a single file handle
240 |     with open(result_file, 'w') as fd:
241 |         fd.write('AnytimeResults' + '\n')
242 |         for j in range(args.num_exits):
243 |             test_str = (' @{ext}** flops {flops:.2f}M prec@1 {top1.avg:.3f} prec@5 {top5.avg:.3f}'.format(ext=j + 1, flops=n_flops[j] / 1e6, top1=top1[j], top5=top5[j]))
244 |             print(test_str)
245 |             fd.write(test_str + '\n')
246 |
247 |     torch.save([e.avg for e in top1], os.path.join(args.save, 'acc.pth'))
248 | return losses.avg, top1[-1].avg, top5[-1].avg
249 |
250 | def save_checkpoint(state, args, is_best, filename, result):
251 | print(args)
252 | result_filename = os.path.join(args.save, 'scores.tsv')
253 | model_dir = os.path.join(args.save, 'save_models')
254 | latest_filename = os.path.join(model_dir, 'latest.txt')
255 | model_filename = os.path.join(model_dir, filename)
256 | best_filename = os.path.join(model_dir, 'model_best.pth.tar')
257 | os.makedirs(args.save, exist_ok=True)
258 | os.makedirs(model_dir, exist_ok=True)
259 | print("=> saving checkpoint '{}'".format(model_filename))
260 |
261 | torch.save(state, model_filename)
262 |
263 | with open(result_filename, 'w') as f:
264 | print('\n'.join(result), file=f)
265 |
266 | with open(latest_filename, 'w') as fout:
267 | fout.write(model_filename)
268 |
269 | if is_best:
270 | shutil.copyfile(model_filename, best_filename)
271 | best_filename_epoch = os.path.join(model_dir, 'best_model_epoch.txt')
272 | with open(best_filename_epoch, 'w') as fout:
273 | fout.write(model_filename)
274 |
275 | print("=> saved checkpoint '{}'".format(model_filename))
276 | return
277 |
278 | def load_checkpoint(args):
279 | model_dir = os.path.join(args.save, 'save_models')
280 | latest_filename = os.path.join(model_dir, 'latest.txt')
281 | if os.path.exists(latest_filename):
282 | with open(latest_filename, 'r') as fin:
283 | model_filename = fin.readlines()[0].strip()
284 | else:
285 | return None
286 | print("=> loading checkpoint '{}'".format(model_filename))
287 | state = torch.load(model_filename)
288 | print("=> loaded checkpoint '{}'".format(model_filename))
289 | return state
290 |
291 | class AverageMeter(object):
292 | """Computes and stores the average and current value"""
293 |
294 | def __init__(self):
295 | self.reset()
296 |
297 | def reset(self):
298 | self.val = 0
299 | self.avg = 0
300 | self.sum = 0
301 | self.count = 0
302 |
303 | def update(self, val, n=1):
304 | self.val = val
305 | self.sum += val * n
306 | self.count += n
307 | self.avg = self.sum / self.count
308 |
309 | def accuracy(output, target, topk=(1,)):
310 | """Computes the precor@k for the specified values of k"""
311 | maxk = max(topk)
312 | batch_size = target.size(0)
313 |
314 | _, pred = output.topk(maxk, 1, True, True)
315 | pred = pred.t()
316 | correct = pred.eq(target.view(1, -1).expand_as(pred))
317 |
318 | res = []
319 | for k in topk:
320 |         correct_k = correct[:k].reshape(-1).float().sum(0)
321 | res.append(correct_k.mul_(100.0 / batch_size))
322 | return res
323 |
324 | def adjust_learning_rate(optimizer, epoch, args, batch=None,
325 | nBatch=None, method='multistep'):
326 | if method == 'cosine':
327 | T_total = args.epochs * nBatch
328 | T_cur = (epoch % args.epochs) * nBatch + batch
329 | lr = 0.5 * args.lr * (1 + math.cos(math.pi * T_cur / T_total))
330 | elif method == 'multistep':
331 | if args.data.startswith('cifar'):
332 | lr, decay_rate = args.lr, 0.1
333 | if epoch >= args.epochs * 0.75:
334 | lr *= decay_rate ** 2
335 | elif epoch >= args.epochs * 0.5:
336 | lr *= decay_rate
337 | else:
338 | lr = args.lr * (0.1 ** (epoch // 30))
339 | for param_group in optimizer.param_groups:
340 | param_group['lr'] = lr
341 | return lr
342 |
343 | if __name__ == '__main__':
344 | main()
345 |
--------------------------------------------------------------------------------
/models/RANet.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import os
3 | import copy
4 | import math
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class ConvBasic(nn.Module):
13 | def __init__(self, nIn, nOut, kernel=3, stride=1, padding=1):
14 | super(ConvBasic, self).__init__()
15 | self.net = nn.Sequential(
16 | nn.Conv2d(nIn, nOut, kernel_size=kernel, stride=stride,
17 | padding=padding, bias=False),
18 | nn.BatchNorm2d(nOut),
19 | nn.ReLU(True)
20 | )
21 |
22 | def forward(self, x):
23 | return self.net(x)
24 |
25 |
26 | class ConvBN(nn.Module):
27 | def __init__(self, nIn, nOut, type: str, bnAfter, bnWidth):
28 | """
29 | a basic conv in RANet, two type
30 | :param nIn:
31 | :param nOut:
32 | :param type: normal or down
33 | :param bnAfter: the location of batch Norm
34 | :param bnWidth: bottleneck factor
35 | """
36 | super(ConvBN, self).__init__()
37 | layer = []
38 | nInner = nIn
39 | if bnAfter is True:
40 | nInner = min(nInner, bnWidth * nOut)
41 | layer.append(nn.Conv2d(
42 | nIn, nInner, kernel_size=1, stride=1, padding=0, bias=False))
43 | layer.append(nn.BatchNorm2d(nInner))
44 | layer.append(nn.ReLU(True))
45 | if type == 'normal':
46 | layer.append(nn.Conv2d(nInner, nOut, kernel_size=3,
47 | stride=1, padding=1, bias=False))
48 | elif type == 'down':
49 | layer.append(nn.Conv2d(nInner, nOut, kernel_size=3,
50 | stride=2, padding=1, bias=False))
51 | else:
52 | raise ValueError
53 | layer.append(nn.BatchNorm2d(nOut))
54 | layer.append(nn.ReLU(True))
55 |
56 | else:
57 | nInner = min(nInner, bnWidth * nOut)
58 | layer.append(nn.BatchNorm2d(nIn))
59 | layer.append(nn.ReLU(True))
60 | layer.append(nn.Conv2d(
61 | nIn, nInner, kernel_size=1, stride=1, padding=0, bias=False))
62 | layer.append(nn.BatchNorm2d(nInner))
63 | layer.append(nn.ReLU(True))
64 | if type == 'normal':
65 | layer.append(nn.Conv2d(nInner, nOut, kernel_size=3,
66 | stride=1, padding=1, bias=False))
67 | elif type == 'down':
68 | layer.append(nn.Conv2d(nInner, nOut, kernel_size=3,
69 | stride=2, padding=1, bias=False))
70 | else:
71 | raise ValueError
72 |
73 | self.net = nn.Sequential(*layer)
74 |
75 | def forward(self, x):
76 | return self.net(x)
77 |
78 |
79 | class ConvUpNormal(nn.Module):
80 | def __init__(self, nIn1, nIn2, nOut, bottleneck, bnWidth1, bnWidth2, compress_factor, down_sample):
81 | '''
82 | The convolution with normal and up-sampling connection.
83 | '''
84 | super(ConvUpNormal, self).__init__()
85 | self.conv_up = ConvBN(nIn2, math.floor(nOut*compress_factor), 'normal',
86 | bottleneck, bnWidth2)
87 | if down_sample:
88 | self.conv_normal = ConvBN(nIn1, nOut-math.floor(nOut*compress_factor), 'down',
89 | bottleneck, bnWidth1)
90 | else:
91 | self.conv_normal = ConvBN(nIn1, nOut-math.floor(nOut*compress_factor), 'normal',
92 | bottleneck, bnWidth1)
93 |
94 | def forward(self, x):
95 | res = self.conv_normal(x[1])
96 | _,_,h,w = res.size()
97 | res = [F.interpolate(x[1], size=(h,w), mode = 'bilinear', align_corners=True),
98 | F.interpolate(self.conv_up(x[0]), size=(h,w), mode = 'bilinear', align_corners=True),
99 | res]
100 | return torch.cat(res, dim=1)
101 |
102 |
103 | class ConvNormal(nn.Module):
104 | def __init__(self, nIn, nOut, bottleneck, bnWidth):
105 | '''
106 | The convolution with normal connection.
107 | '''
108 | super(ConvNormal, self).__init__()
109 | self.conv_normal = ConvBN(nIn, nOut, 'normal',
110 | bottleneck, bnWidth)
111 |
112 | def forward(self, x):
113 | if not isinstance(x, list):
114 | x = [x]
115 | res = [x[0], self.conv_normal(x[0])]
116 | return torch.cat(res, dim=1)
117 |
118 |
119 | class _BlockNormal(nn.Module):
120 | def __init__(self, num_layers, nIn, growth_rate, reduction_rate, trans, bnFactor):
121 | '''
122 | The basic computational block in RANet with num_layers layers.
123 |         trans: If True, the block will add a transition layer at the end of the block
124 | with reduction_rate.
125 | '''
126 | super(_BlockNormal, self).__init__()
127 | self.layers = nn.ModuleList()
128 | self.num_layers = num_layers
129 | for i in range(num_layers):
130 | self.layers.append(ConvNormal(nIn + i*growth_rate, growth_rate, True, bnFactor))
131 | nOut = nIn + num_layers*growth_rate
132 | self.trans_flag = trans
133 | if trans:
134 | self.trans = ConvBasic(nOut, math.floor(1.0 * reduction_rate * nOut), kernel=1, stride=1, padding=0)
135 |
136 | def forward(self, x):
137 | output = [x]
138 | for i in range(self.num_layers):
139 | x = self.layers[i](x)
140 | # print(x.size())
141 | output.append(x)
142 | x = output[-1]
143 | if self.trans_flag:
144 | x = self.trans(x)
145 | return x, output
146 |
147 | def _blockType(self):
148 | return 'norm'
149 |
150 |
151 | class _BlockUpNormal(nn.Module):
152 | def __init__(self, num_layers, nIn, nIn_lowFtrs, growth_rate, reduction_rate, trans, down, compress_factor, bnFactor1, bnFactor2):
153 | '''
154 | The basic fusion block in RANet with num_layers layers.
155 |         trans: If True, the block will add a transition layer at the end of the block
156 | with reduction_rate.
157 |         compress_factor: compress_factor*100% of the fused channels are taken from
158 |         the previous (lower-resolution) sub-network.
159 | '''
160 | super(_BlockUpNormal, self).__init__()
161 |
162 | self.layers = nn.ModuleList()
163 | self.num_layers = num_layers
164 | for i in range(num_layers-1):
165 | self.layers.append(ConvUpNormal(nIn + i*growth_rate, nIn_lowFtrs[i], growth_rate, True, bnFactor1, bnFactor2, compress_factor, False))
166 |
167 | self.layers.append(ConvUpNormal(nIn + (i+1)*growth_rate, nIn_lowFtrs[i+1], growth_rate, True, bnFactor1, bnFactor2, compress_factor, down))
168 | nOut = nIn + num_layers*growth_rate
169 |
170 | self.conv_last = ConvBasic(nIn_lowFtrs[num_layers], math.floor(nOut*compress_factor), kernel=1, stride=1, padding=0)
171 | nOut = nOut + math.floor(nOut*compress_factor)
172 | self.trans_flag = trans
173 | if trans:
174 | self.trans = ConvBasic(nOut, math.floor(1.0 * reduction_rate * nOut), kernel=1, stride=1, padding=0)
175 |
176 | def forward(self, x, low_feat):
177 | output = [x]
178 | for i in range(self.num_layers):
179 | inp = [low_feat[i]]
180 | inp.append(x)
181 | x = self.layers[i](inp)
182 | output.append(x)
183 | x = output[-1]
184 | _,_,h,w = x.size()
185 | x = [x]
186 | x.append(F.interpolate(self.conv_last(low_feat[self.num_layers]), size=(h,w), mode = 'bilinear', align_corners=True))
187 | x = torch.cat(x, dim = 1)
188 | if self.trans_flag:
189 | x = self.trans(x)
190 | return x, output
191 |
192 | def _blockType(self):
193 | return 'up'
194 |
195 |
196 | class RAFirstLayer(nn.Module):
197 | def __init__(self, nIn, nOut, args):
198 | '''
199 |         RAFirstLayer generates the base features for RANet.
200 |         Scale 1 denotes the lowest resolution in the network.
201 | '''
202 | super(RAFirstLayer, self).__init__()
203 | _grFactor = args.grFactor[::-1] # 1-2-4
204 | _scale_list = args.scale_list[::-1] # 3-2-1
205 | self.layers = nn.ModuleList()
206 | if args.data.startswith('cifar'):
207 | self.layers.append(ConvBasic(nIn, nOut * _grFactor[0],
208 | kernel=3, stride=1, padding=1))
209 | elif args.data == 'ImageNet':
210 | conv = nn.Sequential(
211 | nn.Conv2d(nIn, nOut * _grFactor[0], 7, 2, 3),
212 | nn.BatchNorm2d(nOut * _grFactor[0]),
213 | nn.ReLU(inplace=True),
214 | nn.MaxPool2d(3, 2, 1))
215 | self.layers.append(conv)
216 |
217 | nIn = nOut * _grFactor[0]
218 |
219 | s = _scale_list[0]
220 | for i in range(1, args.nScales):
221 | if s == _scale_list[i]:
222 | self.layers.append(ConvBasic(nIn, nOut * _grFactor[i],
223 | kernel=3, stride=1, padding=1))
224 | else:
225 | self.layers.append(ConvBasic(nIn, nOut * _grFactor[i],
226 | kernel=3, stride=2, padding=1))
227 | s = _scale_list[i]
228 | nIn = nOut * _grFactor[i]
229 |
230 | def forward(self, x):
231 | # res[0] with the smallest resolutions
232 | res = []
233 | for i in range(len(self.layers)):
234 | x = self.layers[i](x)
235 | res.append(x)
236 | return res[::-1]
237 |
238 |
239 | class RANet(nn.Module):
240 | def __init__(self, args):
241 | super(RANet, self).__init__()
242 | self.scale_flows = nn.ModuleList()
243 | self.classifier = nn.ModuleList()
244 |
245 | # self.args = args
246 | self.compress_factor = args.compress_factor
247 | self.bnFactor = copy.copy(args.bnFactor)
248 |
249 | scale_list = args.scale_list # 1-2-3
250 | self.nScales = len(args.scale_list) # 3
251 |
252 | # The number of blocks in each scale flow
253 | self.nBlocks = [0]
254 | for i in range(self.nScales):
255 | self.nBlocks.append(args.block_step*i + args.nBlocks) # [0, 2, 4, 6]
256 |
257 | # The number of layers in each block
258 | self.steps = args.step
259 |
260 | self.FirstLayer = RAFirstLayer(3, args.nChannels, args)
261 |
262 | steps = [args.step]
263 | for ii in range(self.nScales):
264 |
265 | scale_flow = nn.ModuleList()
266 |
267 | n_block_curr = 1
268 | nIn = args.nChannels*args.grFactor[ii] # grFactor = [4,2,1]
269 | _nIn_lowFtrs = []
270 |
271 | for i in range(self.nBlocks[ii+1]):
272 | growth_rate = args.growthRate*args.grFactor[ii]
273 |
274 | # If transiation
275 | trans = self._trans_flag(n_block_curr, n_block_all = self.nBlocks[ii+1], inScale = scale_list[ii])
276 |
277 | if n_block_curr > self.nBlocks[ii]:
278 | m, nOuts = self._build_norm_block(nIn, steps[n_block_curr-1], growth_rate, args.reduction, trans, bnFactor=self.bnFactor[ii])
279 | if args.stepmode == 'even':
280 | steps.append(args.step)
281 | elif args.stepmode == 'lg':
282 | steps.append(steps[-1]+args.step)
283 | else:
284 | raise NotImplementedError
285 | else:
286 | if n_block_curr in self.nBlocks[:ii+1][-(scale_list[ii]-1):]:
287 | m, nOuts = self._build_upNorm_block(nIn, nIn_lowFtrs[i], steps[n_block_curr-1], growth_rate, args.reduction, trans, down=True, bnFactor1=self.bnFactor[ii], bnFactor2=self.bnFactor[ii-1])
288 | else:
289 | m, nOuts = self._build_upNorm_block(nIn, nIn_lowFtrs[i], steps[n_block_curr-1], growth_rate, args.reduction, trans, down=False, bnFactor1=self.bnFactor[ii], bnFactor2=self.bnFactor[ii-1])
290 |
291 | nIn = nOuts[-1]
292 | scale_flow.append(m)
293 |
294 | if n_block_curr > self.nBlocks[ii]:
295 | if args.data.startswith('cifar100'):
296 | self.classifier.append(
297 | self._build_classifier_cifar(nIn, 100))
298 | elif args.data.startswith('cifar10'):
299 | self.classifier.append(self._build_classifier_cifar(nIn, 10))
300 | elif args.data == 'ImageNet':
301 | self.classifier.append(
302 | self._build_classifier_imagenet(nIn, 1000))
303 | else:
304 | raise NotImplementedError
305 |
306 | _nIn_lowFtrs.append(nOuts[:-1])
307 | n_block_curr += 1
308 |
309 | nIn_lowFtrs = _nIn_lowFtrs
310 | self.scale_flows.append(scale_flow)
311 |
312 | args.num_exits = len(self.classifier)
313 |
314 | for m in self.scale_flows:
315 | for _m in m.modules():
316 | self._init_weights(_m)
317 |
318 | for m in self.classifier:
319 | for _m in m.modules():
320 | self._init_weights(_m)
321 |
322 | def _init_weights(self, m):
323 | if isinstance(m, nn.Conv2d):
324 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
325 | m.weight.data.normal_(0, math.sqrt(2. / n))
326 | elif isinstance(m, nn.BatchNorm2d):
327 | m.weight.data.fill_(1)
328 | m.bias.data.zero_()
329 | elif isinstance(m, nn.Linear):
330 | m.bias.data.zero_()
331 |
332 | def _build_norm_block(self, nIn, step, growth_rate, reduction_rate, trans, bnFactor=2):
333 |
334 | block = _BlockNormal(step, nIn, growth_rate, reduction_rate, trans, bnFactor=bnFactor)
335 | nOuts = []
336 | for i in range(step+1):
337 | nOut = (nIn + i * growth_rate)
338 | nOuts.append(nOut)
339 | if trans:
340 | nOut = math.floor(1.0 * reduction_rate * nOut)
341 | nOuts.append(nOut)
342 |
343 | return block, nOuts
344 |
345 | def _build_upNorm_block(self, nIn, nIn_lowFtr, step, growth_rate, reduction_rate, trans, down, bnFactor1=1, bnFactor2=2):
346 | compress_factor = self.compress_factor
347 |
348 | block = _BlockUpNormal(step, nIn, nIn_lowFtr, growth_rate, reduction_rate, trans, down, compress_factor, bnFactor1=bnFactor1, bnFactor2=bnFactor2)
349 | nOuts = []
350 | for i in range(step+1):
351 | nOut = (nIn + i * growth_rate)
352 | nOuts.append(nOut)
353 | nOut = nOut + math.floor(nOut*compress_factor)
354 | if trans:
355 | nOut = math.floor(1.0 * reduction_rate * nOut)
356 | nOuts.append(nOut)
357 |
358 | return block, nOuts
359 |
360 | def _trans_flag(self, n_block_curr, n_block_all, inScale):
361 | flag = False
362 | for i in range(inScale-1):
363 | if n_block_curr == math.floor((i+1)*n_block_all /inScale):
364 | flag = True
365 | return flag
366 |
367 | def forward(self, x):
368 | inp = self.FirstLayer(x)
369 | res, low_ftrs = [], []
370 | classifier_idx = 0
371 | for ii in range(self.nScales):
372 | _x = inp[ii]
373 | _low_ftrs = []
374 | n_block_curr = 0
375 | for i in range(self.nBlocks[ii+1]):
376 | if self.scale_flows[ii][i]._blockType() == 'norm':
377 | _x, _low_ftr = self.scale_flows[ii][i](_x)
378 | _low_ftrs.append(_low_ftr)
379 | else:
380 | _x, _low_ftr = self.scale_flows[ii][i](_x, low_ftrs[i])
381 | _low_ftrs.append(_low_ftr)
382 | n_block_curr += 1
383 |
384 | if n_block_curr > self.nBlocks[ii]:
385 | res.append(self.classifier[classifier_idx](_x))
386 | classifier_idx += 1
387 |
388 | low_ftrs = _low_ftrs
389 | return res
390 |
391 | def _build_classifier_cifar(self, nIn, num_classes):
392 | interChannels1, interChannels2 = 128, 128
393 | conv = nn.Sequential(
394 | ConvBasic(nIn, interChannels1, kernel=3, stride=2, padding=1),
395 | ConvBasic(interChannels1, interChannels2, kernel=3, stride=2, padding=1),
396 | nn.AvgPool2d(2),
397 | )
398 | return ClassifierModule(conv, interChannels2, num_classes)
399 |
400 | def _build_classifier_imagenet(self, nIn, num_classes):
401 | conv = nn.Sequential(
402 | ConvBasic(nIn, nIn, kernel=3, stride=2, padding=1),
403 | ConvBasic(nIn, nIn, kernel=3, stride=2, padding=1),
404 | nn.AvgPool2d(2)
405 | )
406 | return ClassifierModule(conv, nIn, num_classes)
407 |
408 | class ClassifierModule(nn.Module):
409 | def __init__(self, m, channel, num_classes):
410 | super(ClassifierModule, self).__init__()
411 | self.m = m
412 | self.linear = nn.Linear(channel, num_classes)
413 | def forward(self, x):
414 | res = self.m(x)
415 | res = res.view(res.size(0), -1)
416 | return self.linear(res)
417 |
418 |
419 | if __name__ == '__main__':
420 |     from args import arg_parser
421 | from op_counter import measure_model
422 |
423 | args = arg_parser.parse_args()
424 | # if args.gpu:
425 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
426 |
427 | args.nBlocks = 2
428 | args.Block_base = 2
429 | args.step = 8
430 |     args.stepmode = 'even'
431 | args.compress_factor = 0.25
432 | args.nChannels = 64
433 | args.data = 'ImageNet'
434 | args.growthRate = 16
435 |
436 | args.grFactor = '4-2-2-1'
437 | args.bnFactor = '4-2-2-1'
438 | args.scale_list = '1-2-3-4'
439 |
440 | args.reduction = 0.5
441 |
442 | args.grFactor = list(map(int, args.grFactor.split('-')))
443 | args.bnFactor = list(map(int, args.bnFactor.split('-')))
444 | args.scale_list = list(map(int, args.scale_list.split('-')))
445 | args.nScales = len(args.grFactor)
446 | # print(args.grFactor)
447 | if args.use_valid:
448 | args.splits = ['train', 'val', 'test']
449 | else:
450 | args.splits = ['train', 'val']
451 |
452 | if args.data == 'cifar10':
453 | args.num_classes = 10
454 | elif args.data == 'cifar100':
455 | args.num_classes = 100
456 | else:
457 | args.num_classes = 1000
458 |
459 | inp_c = torch.rand(16,3,224,224)
460 |
461 |     model = RANet(args)
462 | # output = model(inp_c)
463 | # oup = net_head(inp_c)
464 | # print(len(oup))
465 |
466 | n_flops, n_params = measure_model(model, 224, 224)
467 | # net = _BlockNormal(num_layers = 4, nIn = 64, growth_rate = 24, reduction_rate = 0.5, trans_down = True)
468 |
469 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .RANet import RANet
2 |
--------------------------------------------------------------------------------
/op_counter.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import unicode_literals
3 | from __future__ import print_function
4 | from __future__ import division
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.autograd import Variable
9 | from functools import reduce
10 | import operator
11 |
12 | '''
13 | Calculate the FLOPs of each exit (without lazy prediction pruning).
14 | '''
15 |
16 | count_ops = 0
17 | count_params = 0
18 | cls_ops = []
19 | cls_params = []
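20 | # cls_ops[k] / cls_params[k] record the cumulative FLOPs / params needed to reach
21 | # classifier k; an entry is appended each time a Linear layer is measured.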
20 |
21 | def get_num_gen(gen):
22 | return sum(1 for x in gen)
23 |
24 |
25 | def is_leaf(model):
26 | return get_num_gen(model.children()) == 0
27 |
28 |
29 | def get_layer_info(layer):
30 | layer_str = str(layer)
31 | type_name = layer_str[:layer_str.find('(')].strip()
32 | return type_name
33 |
34 |
35 | def get_layer_param(model):
36 | return sum([reduce(operator.mul, i.size(), 1) for i in model.parameters()])
37 |
38 |
39 | ### The input batch size should be 1 to call this function
40 | def measure_layer(layer, x):
41 | global count_ops, count_params, cls_ops, cls_params
42 | delta_ops = 0
43 | delta_params = 0
44 | multi_add = 1
45 | type_name = get_layer_info(layer)
46 |
47 | ### ops_conv
48 | if type_name in ['Conv2d']:
49 | out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /
50 | layer.stride[0] + 1)
51 | out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /
52 | layer.stride[1] + 1)
53 | delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
54 | layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
55 | delta_params = get_layer_param(layer)
56 |
57 | ### ops_nonlinearity
58 | elif type_name in ['ReLU']:
59 | delta_ops = x.numel()
60 | delta_params = get_layer_param(layer)
61 |
62 | ### ops_pooling
63 | elif type_name in ['AvgPool2d', 'MaxPool2d']:
64 |         in_w = x.size()[2]  # assumes square feature maps (H == W)
65 |         kernel_ops = layer.kernel_size * layer.kernel_size
66 |         out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
67 |         out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
68 | delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops
69 | delta_params = get_layer_param(layer)
70 |
71 | elif type_name in ['AdaptiveAvgPool2d']:
72 | delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3]
73 | delta_params = get_layer_param(layer)
74 |
75 | ### ops_linear
76 | elif type_name in ['Linear']:
77 | weight_ops = layer.weight.numel() * multi_add
78 | bias_ops = layer.bias.numel()
79 | delta_ops = x.size()[0] * (weight_ops + bias_ops)
80 | delta_params = get_layer_param(layer)
81 |
82 | ### ops_nothing
83 | elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout',
84 | 'MSDNFirstLayer', 'ConvBasic', 'ConvBN',
85 | 'ParallelModule', 'MSDNet', 'Sequential',
86 | 'MSDNLayer', 'ConvDownNormal', 'ConvNormal', 'ClassifierModule']:
87 | delta_params = get_layer_param(layer)
88 |
89 |
90 | ### unknown layer type
91 | else:
92 | raise TypeError('unknown layer type: %s' % type_name)
93 |
94 | count_ops += delta_ops
95 | count_params += delta_params
96 | if type_name == 'Linear':
97 | print('---------------------')
98 | print('FLOPs: %.2fM, Params: %.2fM' % (count_ops / 1e6, count_params / 1e6))
99 | cls_ops.append(count_ops)
100 | cls_params.append(count_params)
101 |
102 | return
103 |
104 |
105 | def measure_model(model, H, W):
106 | global count_ops, count_params, cls_ops, cls_params
107 | count_ops = 0
108 | count_params = 0
109 | data = Variable(torch.zeros(1, 3, H, W))
110 |
111 | def should_measure(x):
112 | return is_leaf(x)
113 |
114 | def modify_forward(model):
115 | for child in model.children():
116 | if should_measure(child):
117 | def new_forward(m):
118 | def lambda_forward(x):
119 | measure_layer(m, x)
120 | return m.old_forward(x)
121 | return lambda_forward
122 | child.old_forward = child.forward
123 | child.forward = new_forward(child)
124 | else:
125 | modify_forward(child)
126 |
127 | def restore_forward(model):
128 | for child in model.children():
129 | # leaf node
130 | if is_leaf(child) and hasattr(child, 'old_forward'):
131 | child.forward = child.old_forward
132 | child.old_forward = None
133 | else:
134 | restore_forward(child)
135 |
136 | model.eval()
137 | modify_forward(model)
138 | model.forward(data)
139 | restore_forward(model)
140 | return cls_ops, cls_params
141 |
--------------------------------------------------------------------------------
/train_cifar.sh:
--------------------------------------------------------------------------------
1 | python main.py --arch RANet --gpu '0' --data-root {your data root} --data 'cifar10' --step 4 --stepmode 'even' --scale-list '1-2-3-3' --grFactor '4-2-1-1' --bnFactor '4-2-1-1'
--------------------------------------------------------------------------------
/train_imagenet.sh:
--------------------------------------------------------------------------------
1 | python main.py --arch RANet --gpu '0,1,2,3' --data-root {your data root} --data 'ImageNet' --growthRate 16 --step 8 --stepmode 'even' --scale-list '1-2-3-4' --grFactor '4-2-1-1' --bnFactor '4-2-1-1'
--------------------------------------------------------------------------------